[XLA] When checking async copy ordering, also check for pending copies.
PiperOrigin-RevId: 291286705
Change-Id: I6fb6303c6ad92c0f7de8a2aa198196f376d96f00
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 8d9510f..21b2226 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -859,8 +859,8 @@
VLOG(4) << "This would violate the outstanding async copy limit.";
continue;
}
- if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start,
- alternate_mem_interval.end)) {
+ if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
+ alternate_mem_interval.end)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
continue;
}
@@ -937,6 +937,23 @@
return num_async_copies + 1 > options_.max_outstanding_async_copies;
}
+bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
+ int64 start_time, int64 end_time) const {
+ if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
+ return true;
+ }
+
+ // Also a violation: overlapping a pending async copy to alternate memory.
+ for (const auto& async_copy : pending_async_copies_) {
+ if (async_copy.destination == MemorySpace::kAlternate &&
+ async_copy.start_time <= end_time &&
+ start_time <= async_copy.end_time) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use,
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 9bf04a0..53fe2e4 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -688,6 +688,9 @@
bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
int64 end_time) const;
+ // Returns true if the asynchronous copy would violate the pipelining order.
+ bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
+
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time,