[XLA] Don't allocate request identifiers to alternate mem.

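For reference, Send and Recv return a tuple of (data, u32[] request
identifier, token[]); the request identifier is the element at shape index
{1}. A minimal sketch, assuming an HloInstruction* `recv` for a Recv op and
the test's kAlternateMemorySpace constant, of how that identifier's subshape
and memory-space annotation can be inspected (illustrative only, not part of
the patch):

    // The request identifier is element {1} of the Recv result tuple, e.g.
    // (f32[3]{0}, u32[], token[]) -> u32[]. After memory space assignment it
    // must not carry the alternate memory space in its layout.
    const Shape& request_id_shape =
        ShapeUtil::GetSubshape(recv->shape(), /*index=*/{1});
    CHECK_NE(request_id_shape.layout().memory_space(), kAlternateMemorySpace);
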
PiperOrigin-RevId: 289733771
Change-Id: Ib1a0324648952a4ea88be91890de568d34456018
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 7fe4913..1000ef0 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1375,8 +1375,8 @@
     const HeapSimulator::Chunk& chunk = position_and_chunk.second;
     auto preset_allocations_iter = preset_allocations.find(value.color());
     CHECK(preset_allocations_iter != preset_allocations.end())
-        << "No preset value allocation for color " << value.color()
-        << " found.";
+        << "No preset value allocation for color " << value.color() << " for "
+        << value.ToShortString() << " found.";
     preset_allocations_iter->second->AddAssignment(value, chunk.offset,
                                                    chunk.size);
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index c721ebc..4b73365 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -258,6 +258,51 @@
   return colocated_intervals;
 }
 
+bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
+    const BufferInterval& interval) const {
+  // If the buffer is a tuple, don't use this algorithm for now. The buffers
+  // that are pointed to by the tuple will still use this algorithm. Because
+  // tuples are cheap to place in the alternate memory (they are just
+  // pointers), we don't need to use prefetch/evict logic.
+  if (interval.buffer->shape().IsTuple()) {
+    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+            << " in default mem because it is a tuple.";
+    return false;
+  }
+
+  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
+  // buffer, but just forwards the buffers from either the left or the right
+  // side. This means that the two different inputs to TupleSelect must not
+  // alias, yet they should be allocated in the same memory space, and both
+  // buffers must be kept alive for the entire live range of TupleSelect.
+  // Instead, just don't allocate TupleSelect in the alternate memory space.
+  // TODO(berkin): Not allocating add-dependencies either since they need to be
+  // treated specially. We should revisit this later.
+  for (const HloPosition& position : interval.buffer->positions()) {
+    if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
+        position.instruction->opcode() == HloOpcode::kAddDependency) {
+      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
+              << " in default mem because it has a tuple-select or "
+              << "add-dependency position.";
+      return false;
+    }
+  }
+
+  // Send and Recv HLOs return a request identifier. These identifiers should
+  // not be allocated in the alternate memory.
+  const HloPosition& defining_position = interval.buffer->defining_position();
+  if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
+       defining_position.instruction->opcode() == HloOpcode::kRecv) &&
+      defining_position.index == ShapeIndex({1})) {
+    VLOG(4)
+        << "Keeping value " << interval.buffer->ToShortString()
+        << " in default mem because it is a request identifier for send/recv.";
+    return false;
+  }
+
+  return true;
+}
+
 HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
   std::vector<BufferInterval> sorted_buffer_intervals =
       GetSortedBufferIntervals();
@@ -279,36 +324,7 @@
       continue;
     }
 
-    // If the buffer is a tuple, don't use this algorithm for now. The buffers
-    // that are pointed to by the tuple will still use this algorithm.  Because
-    // tuples are cheap to place in the alternate memory (they are just
-    // pointers) we don't need to use prefetch/evict logic.
-    if (interval.buffer->shape().IsTuple()) {
-      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-              << " in default mem because it is a tuple.";
-      continue;
-    }
-
-    // The semantics of TupleSelect are weird: TupleSelect doesn't define a
-    // buffer, but just forwards the buffers in the either left or right side.
-    // This means the the two different inputs to TupleSelect must not alias,
-    // yet they should be allocated in the same memory space, and both buffers
-    // must be kept alive for the entire live range of TupleSelect. Instead,
-    // just don't allocate TupleSelect in the alternate memory space.
-    // TODO(berkin): Not allocating add-dependencies either since they need to
-    // be treated specially. We should revisit this later.
-    bool keep_in_default_mem = false;
-    for (const HloPosition& position : interval.buffer->positions()) {
-      if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
-          position.instruction->opcode() == HloOpcode::kAddDependency) {
-        keep_in_default_mem = true;
-        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
-                << " in default mem because it has a tuple-select or "
-                << "add-dependency position.";
-        break;
-      }
-    }
-    if (keep_in_default_mem) {
+    if (!IsIntervalAllowedInAlternateMemory(interval)) {
       continue;
     }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index b1ff0b4..50b1a16 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -621,6 +621,10 @@
   // it is a parameter in default memory or an ouput in default memory.
   bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
 
+  // Returns true if the buffer associated with the given interval is allowed
+  // to be placed in the alternate memory.
+  bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
+
   // Finds an allocation for the given interval. Internally, it will attempt to
   // find a suitable chunk candidate within the heap size and prefetch interval
   // limits, and append the new allocation(s) to allocations. The new
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 8f1c1c3..fd1c804 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -1268,6 +1268,42 @@
   AssignMemorySpace(module.get());
 }
 
+TEST_P(MemorySpaceAssignmentTest,
+       RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
+  // Ensure that request identifiers returned by Send/Recv HLOs are not
+  // allocated in the alternate memory.
+  absl::string_view hlo_string = R"(
+  HloModule SendRecv, is_scheduled=true
+
+  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
+    %p = f32[3]{0} parameter(0)
+    %after-all = token[] after-all()
+    %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
+    %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
+    %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
+    %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
+    %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
+    %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
+    ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  for (const HloInstruction* instruction :
+       module->entry_computation()->instructions()) {
+    if (instruction->opcode() == HloOpcode::kSend ||
+        instruction->opcode() == HloOpcode::kRecv) {
+      const Shape& request_identifier_shape =
+          ShapeUtil::GetSubshape(instruction->shape(), {1});
+      EXPECT_NE(request_identifier_shape.layout().memory_space(),
+                kAlternateMemorySpace);
+    }
+  }
+}
+
 TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
   // Test that checks the last use optimization. It uses two buffers that should
   // be placed in alternate memory.