[XLA] When checking async copy ordering, also check for pending copies.

PiperOrigin-RevId: 291286705
Change-Id: I6fb6303c6ad92c0f7de8a2aa198196f376d96f00
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 8d9510f..21b2226 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -859,8 +859,8 @@
       VLOG(4) << "This would violate the outstanding async copy limit.";
       continue;
     }
-    if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start,
-                                              alternate_mem_interval.end)) {
+    if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
+                                  alternate_mem_interval.end)) {
       VLOG(4) << "This would violate asynchronous copy ordering.";
       continue;
     }
@@ -937,6 +937,23 @@
   return num_async_copies + 1 > options_.max_outstanding_async_copies;
 }
 
+bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
+    int64 start_time, int64 end_time) const {
+  if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
+    return true;
+  }
+
+  // Also check pending async copies.
+  for (const auto& async_copy : pending_async_copies_) {
+    if (async_copy.destination == MemorySpace::kAlternate &&
+        async_copy.start_time <= end_time &&
+        start_time <= async_copy.end_time) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
     int64 start_time, int64 end_time, int64 last_use_time,
     HloPosition defining_position, HloUse use,
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 9bf04a0..53fe2e4 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -688,6 +688,9 @@
   bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
                                              int64 end_time) const;
 
+  // Return true if the asynchronous copy would violate the pipelining order.
+  bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
+
   // Adds an asynchronous copy to the allocations.
   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                     MemorySpace memory_space, Chunk chunk, int64 start_time,