[XLA] Don't allocate request identifiers to alternate mem.
PiperOrigin-RevId: 289733771
Change-Id: Ib1a0324648952a4ea88be91890de568d34456018
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 7fe4913..1000ef0 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1375,8 +1375,8 @@
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
auto preset_allocations_iter = preset_allocations.find(value.color());
CHECK(preset_allocations_iter != preset_allocations.end())
- << "No preset value allocation for color " << value.color()
- << " found.";
+ << "No preset value allocation for color " << value.color() << " for "
+ << value.ToShortString() << " found.";
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
chunk.size);
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index c721ebc..4b73365 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -258,6 +258,51 @@
return colocated_intervals;
}
+bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
+ const BufferInterval& interval) const {
+ // If the buffer is a tuple, don't use this algorithm for now. The buffers
+ // that are pointed to by the tuple will still use this algorithm. Because
+ // tuples are cheap to place in the alternate memory (they are just pointers)
+ // we don't need to use prefetch/evict logic.
+ if (interval.buffer->shape().IsTuple()) {
+ VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+ << " in default mem because it is a tuple.";
+ return false;
+ }
+
+ // The semantics of TupleSelect are weird: TupleSelect doesn't define a
+  // buffer, but just forwards the buffers in either the left or right side.
+  // This means that the two different inputs to TupleSelect must not alias,
+ // they should be allocated in the same memory space, and both buffers must be
+ // kept alive for the entire live range of TupleSelect. Instead, just don't
+ // allocate TupleSelect in the alternate memory space.
+ // TODO(berkin): Not allocating add-dependencies either since they need to be
+ // treated specially. We should revisit this later.
+ for (const HloPosition& position : interval.buffer->positions()) {
+ if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
+ position.instruction->opcode() == HloOpcode::kAddDependency) {
+ VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+ << " in default mem because it has a tuple-select or "
+ << "add-dependency position.";
+ return false;
+ }
+ }
+
+ // Send and Recv HLOs return a request identifier. These should not be
+ // allocated in the alternate memory.
+ const HloPosition& defining_position = interval.buffer->defining_position();
+ if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
+ defining_position.instruction->opcode() == HloOpcode::kRecv) &&
+ defining_position.index == ShapeIndex({1})) {
+ VLOG(4)
+ << "Keeping value " << interval.buffer->ToShortString()
+ << " in default mem because it is a request identifier for send/recv.";
+ return false;
+ }
+
+ return true;
+}
+
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
std::vector<BufferInterval> sorted_buffer_intervals =
GetSortedBufferIntervals();
@@ -279,36 +324,7 @@
continue;
}
- // If the buffer is a tuple, don't use this algorithm for now. The buffers
- // that are pointed to by the tuple will still use this algorithm. Because
- // tuples are cheap to place in the alternate memory (they are just
- // pointers) we don't need to use prefetch/evict logic.
- if (interval.buffer->shape().IsTuple()) {
- VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
- << " in default mem because it is a tuple.";
- continue;
- }
-
- // The semantics of TupleSelect are weird: TupleSelect doesn't define a
- // buffer, but just forwards the buffers in the either left or right side.
- // This means the the two different inputs to TupleSelect must not alias,
- // yet they should be allocated in the same memory space, and both buffers
- // must be kept alive for the entire live range of TupleSelect. Instead,
- // just don't allocate TupleSelect in the alternate memory space.
- // TODO(berkin): Not allocating add-dependencies either since they need to
- // be treated specially. We should revisit this later.
- bool keep_in_default_mem = false;
- for (const HloPosition& position : interval.buffer->positions()) {
- if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
- position.instruction->opcode() == HloOpcode::kAddDependency) {
- keep_in_default_mem = true;
- VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
- << " in default mem because it has a tuple-select or "
- << "add-dependency position.";
- break;
- }
- }
- if (keep_in_default_mem) {
+ if (!IsIntervalAllowedInAlternateMemory(interval)) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index b1ff0b4..50b1a16 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -621,6 +621,10 @@
// it is a parameter in default memory or an ouput in default memory.
bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
+ // Returns true if this buffer is allowed to be placed in the alternate
+ // memory.
+ bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
+
// Finds an allocation for the given interval. Internally, it will attempt to
// find a suitable chunk candidate within the heap size and prefetch interval
// limits, and append the new allocation(s) to allocations. The new
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 8f1c1c3..fd1c804 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -1268,6 +1268,42 @@
AssignMemorySpace(module.get());
}
+TEST_P(MemorySpaceAssignmentTest,
+ RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
+  // Ensure that request identifiers returned by Send/Recv HLOs are not
+  // allocated in the alternate memory.
+ absl::string_view hlo_string = R"(
+ HloModule SendRecv, is_scheduled=true
+
+ ENTRY %AddDependency (p: f32[3]) -> f32[3] {
+ %p = f32[3]{0} parameter(0)
+ %after-all = token[] after-all()
+ %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
+ %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
+ %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
+ %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
+ %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
+ %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
+ ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ AssignMemorySpace(module.get());
+
+ for (const HloInstruction* instruction :
+ module->entry_computation()->instructions()) {
+ if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kRecv) {
+ const Shape& request_identifier_shape =
+ ShapeUtil::GetSubshape(instruction->shape(), {1});
+ EXPECT_NE(request_identifier_shape.layout().memory_space(),
+ kAlternateMemorySpace);
+ }
+ }
+}
+
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
// Test that checks the last use optimization. It uses two buffers that should
// be placed in alternate memory.