[XLA] Don't issue redundant evictions when there is an earlier one before a while op

Consider the scenario:

a = foo()       // a lives in alternate mem
a' = evict(a)   // a' is the copy of a in default mem
a'' = while(a)  // a'' (the while result) lives in alternate mem
// a'' does not need its own eviction to default mem: a' already holds this value there.
PiperOrigin-RevId: 408965154
Change-Id: Ic837a918e83d7417aead3a246fefe6c8d826de73
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index c5136c6..fee95a6 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -1697,6 +1697,22 @@
                     << body_allocation_value_it->allocation_sequence()
                            ->back()
                            ->ToString();
+
+            auto after_while_allocation_value_it = absl::c_find_if(
+                allocation_values, [&](const AllocationValue& value) {
+                  return value.defining_instruction() == hlo_use.instruction;
+                });
+            CHECK_NE(after_while_allocation_value_it, allocation_values.end());
+            VLOG(3) << "After while allocation value: "
+                    << after_while_allocation_value_it->ToShortString();
+            int64_t while_time = instruction_schedule.at(hlo_use.instruction);
+            after_while_allocation_value_it->allocation_sequence()->push_back(
+                absl::make_unique<MemorySpaceAssignment::MirroredAllocation>(
+                    **prev_allocation_in_default_mem_it, while_time));
+            VLOG(3) << "Created: "
+                    << after_while_allocation_value_it->allocation_sequence()
+                           ->back()
+                           ->ToString();
           }
         }
         // Special case for while loops since the root offset must agree with
@@ -2414,12 +2430,11 @@
   auto prev_allocation_it = allocation_sequence->rbegin();
   // Find a previous allocation that is in the default memory space (not
   // necessarily the very last allocation).
-  auto prev_allocation_in_default_mem_it = std::find_if(
-      allocation_sequence->rbegin(), allocation_sequence->rend(),
-      [&](const auto& allocation) {
-        return allocation->memory_space() == MemorySpace::kDefault &&
-               allocation->defining_position() == defining_position;
-      });
+  auto prev_allocation_in_default_mem_it =
+      std::find_if(allocation_sequence->rbegin(), allocation_sequence->rend(),
+                   [&](const auto& allocation) {
+                     return allocation->memory_space() == MemorySpace::kDefault;
+                   });
 
   if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
       prev_allocation_it != allocation_sequence->rend() &&
@@ -3358,6 +3373,11 @@
                       prev_allocation_.ToString());
 }
 
+std::string MemorySpaceAssignment::MirroredAllocation::ToString() const {
+  return absl::StrCat("Mirrored Allocation for ",
+                      original_allocation_.ToString());
+}
+
 std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
   return absl::StrCat("Parent Allocation mirrored at ",
                       defining_position_.ToString(), ", originally ",
@@ -3408,6 +3428,11 @@
   return Status::OK();
 }
 
+Status MemorySpaceAssignment::MirroredAllocation::Process() {
+  defining_position_ = original_allocation_.defining_position();
+  return Allocation::Process();
+}
+
 Status MemorySpaceAssignment::ParentAllocation::Process() {
   // Add an additional parameter to the while HLO with a reference to the buffer
   // in the default memory space.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 168cd64..adb0a3d 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -656,6 +656,27 @@
     HloInstruction* copy_done_;
   };
 
+  // An allocation in the default memory space that mirrors another Allocation
+  // object. This is useful to model an eviction that happens before a while op
+  // so that we don't need to redundantly evict the buffer after the while op as
+  // well.
+  class MirroredAllocation : public Allocation {
+   public:
+    MirroredAllocation(const Allocation& original_allocation, int64_t time)
+        : Allocation(original_allocation.defining_position(),
+                     MemorySpace::kDefault, original_allocation.chunk(),
+                     /*start_time=*/time,
+                     /*end_time=*/time, /*is_scoped_allocation=*/false),
+          original_allocation_(original_allocation) {}
+
+    Status Process() override;
+
+    std::string ToString() const override;
+
+   private:
+    const Allocation& original_allocation_;
+  };
+
   // An allocation in default memory space that is defined in the parent
   // computation. If a value has a copy in the default memory space in the
   // parent computation, we don't need to evict this buffer in a while loop.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index ec1152d..df80988 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -4687,6 +4687,243 @@
   }
 }
 
+TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile) {
+  absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  while_cond {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=2
+  }
+
+  while_body {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = pred[] get-tuple-element(p0), index=2
+    add = f32[3]{0} add(gte0, gte1)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy = f32[3]{0} copy(p0)
+    negate0 = f32[3]{0} negate(p0)
+    negate1 = f32[3]{0} negate(negate0)
+    negate2 = f32[3]{0} negate(negate1)
+    negate3 = f32[3]{0} negate(negate2)
+    negate4 = f32[3]{0} negate(negate3)
+    negate5 = f32[3]{0} negate(negate4)
+    negate6 = f32[3]{0} negate(negate5)
+    negate7 = f32[3]{0} negate(negate6)
+    negate8 = f32[3]{0} negate(negate7)
+    negate9 = f32[3]{0} negate(negate8)
+    negate10 = f32[3]{0} negate(negate9)
+    negate11 = f32[3]{0} negate(negate10)
+    negate12 = f32[3]{0} negate(negate11)
+    negate13 = f32[3]{0} negate(negate12)
+    negate14 = f32[3]{0} negate(negate13)
+    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
+    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+    gte0 = f32[3]{0} get-tuple-element(while), index=0
+    gte1 = f32[3]{0} get-tuple-element(while), index=1
+    negate20 = f32[3]{0} negate(gte1)
+    negate21 = f32[3]{0} negate(negate20)
+    negate22 = f32[3]{0} negate(negate21)
+    negate23 = f32[3]{0} negate(negate22)
+    negate24 = f32[3]{0} negate(negate23)
+    negate25 = f32[3]{0} negate(negate24)
+    negate26 = f32[3]{0} negate(negate25)
+    negate27 = f32[3]{0} negate(negate26)
+    negate28 = f32[3]{0} negate(negate27)
+    negate29 = f32[3]{0} negate(negate28)
+    negate30 = f32[3]{0} negate(negate29)
+    negate31 = f32[3]{0} negate(negate30)
+    negate32 = f32[3]{0} negate(negate31)
+    negate33 = f32[3]{0} negate(negate32)
+    negate34 = f32[3]{0} negate(negate33)
+    ROOT add = f32[3]{0} add(negate34, gte0)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  if (GetParam()) {
+    EXPECT_THAT(
+        module->entry_computation()->root_instruction()->operand(1),
+        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, op::Copy()));
+  }
+}
+
+TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile2) {
+  absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  while_cond1 {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=2
+  }
+
+  while_body1 {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = pred[] get-tuple-element(p0), index=2
+    add = f32[3]{0} add(gte0, gte1)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+  }
+
+  while_cond2 {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=2
+  }
+
+  while_body2 {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = pred[] get-tuple-element(p0), index=2
+    add = f32[3]{0} add(gte0, gte1)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy = f32[3]{0} copy(p0)
+    tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
+    while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond1, body=while_body1
+    gte0 = f32[3]{0} get-tuple-element(while1), index=0
+    gte1 = f32[3]{0} get-tuple-element(while1), index=1
+    negate0 = f32[3]{0} negate(gte1)
+    negate1 = f32[3]{0} negate(negate0)
+    negate2 = f32[3]{0} negate(negate1)
+    negate3 = f32[3]{0} negate(negate2)
+    negate4 = f32[3]{0} negate(negate3)
+    negate5 = f32[3]{0} negate(negate4)
+    negate6 = f32[3]{0} negate(negate5)
+    negate7 = f32[3]{0} negate(negate6)
+    negate8 = f32[3]{0} negate(negate7)
+    negate9 = f32[3]{0} negate(negate8)
+    negate10 = f32[3]{0} negate(negate9)
+    negate11 = f32[3]{0} negate(negate10)
+    negate12 = f32[3]{0} negate(negate11)
+    negate13 = f32[3]{0} negate(negate12)
+    negate14 = f32[3]{0} negate(negate13)
+    tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, negate14, p1)
+    while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond2, body=while_body2
+    gte2 = f32[3]{0} get-tuple-element(while2), index=0
+    gte3 = f32[3]{0} get-tuple-element(while2), index=1
+    negate20 = f32[3]{0} negate(gte3)
+    negate21 = f32[3]{0} negate(negate20)
+    negate22 = f32[3]{0} negate(negate21)
+    negate23 = f32[3]{0} negate(negate22)
+    negate24 = f32[3]{0} negate(negate23)
+    negate25 = f32[3]{0} negate(negate24)
+    negate26 = f32[3]{0} negate(negate25)
+    negate27 = f32[3]{0} negate(negate26)
+    negate28 = f32[3]{0} negate(negate27)
+    negate29 = f32[3]{0} negate(negate28)
+    negate30 = f32[3]{0} negate(negate29)
+    negate31 = f32[3]{0} negate(negate30)
+    negate32 = f32[3]{0} negate(negate31)
+    negate33 = f32[3]{0} negate(negate32)
+    negate34 = f32[3]{0} negate(negate33)
+    ROOT add = f32[3]{0} add(negate34, gte2)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  if (GetParam()) {
+    EXPECT_THAT(
+        module->entry_computation()->root_instruction()->operand(1),
+        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
+                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
+                                    op::GetTupleElement(op::While()))));
+  }
+}
+
+TEST_P(MemorySpaceAssignmentTest,
+       AfterWhileRedundantEarlierEvictionModifiedBuffer) {
+  absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  while_cond {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=2
+  }
+
+  while_body {
+    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = pred[] get-tuple-element(p0), index=2
+    add = f32[3]{0} add(gte0, gte1)
+    negate = f32[3]{0} negate(gte0)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(negate, add, gte2)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy = f32[3]{0} copy(p0)
+    negate0 = f32[3]{0} negate(p0)
+    negate1 = f32[3]{0} negate(negate0)
+    negate2 = f32[3]{0} negate(negate1)
+    negate3 = f32[3]{0} negate(negate2)
+    negate4 = f32[3]{0} negate(negate3)
+    negate5 = f32[3]{0} negate(negate4)
+    negate6 = f32[3]{0} negate(negate5)
+    negate7 = f32[3]{0} negate(negate6)
+    negate8 = f32[3]{0} negate(negate7)
+    negate9 = f32[3]{0} negate(negate8)
+    negate10 = f32[3]{0} negate(negate9)
+    negate11 = f32[3]{0} negate(negate10)
+    negate12 = f32[3]{0} negate(negate11)
+    negate13 = f32[3]{0} negate(negate12)
+    negate14 = f32[3]{0} negate(negate13)
+    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
+    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+    gte0 = f32[3]{0} get-tuple-element(while), index=0
+    gte1 = f32[3]{0} get-tuple-element(while), index=1
+    negate20 = f32[3]{0} negate(gte1)
+    negate21 = f32[3]{0} negate(negate20)
+    negate22 = f32[3]{0} negate(negate21)
+    negate23 = f32[3]{0} negate(negate22)
+    negate24 = f32[3]{0} negate(negate23)
+    negate25 = f32[3]{0} negate(negate24)
+    negate26 = f32[3]{0} negate(negate25)
+    negate27 = f32[3]{0} negate(negate26)
+    negate28 = f32[3]{0} negate(negate27)
+    negate29 = f32[3]{0} negate(negate28)
+    negate30 = f32[3]{0} negate(negate29)
+    negate31 = f32[3]{0} negate(negate30)
+    negate32 = f32[3]{0} negate(negate31)
+    negate33 = f32[3]{0} negate(negate32)
+    negate34 = f32[3]{0} negate(negate33)
+    ROOT add = f32[3]{0} add(negate34, gte0)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  if (GetParam()) {
+    EXPECT_THAT(
+        module->entry_computation()->root_instruction()->operand(1),
+        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
+                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
+                                    op::GetTupleElement(op::While()))));
+  }
+}
+
 TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
   // Tests against a bug where the root of entry computation is a bitcast
   // instruction and it ends up getting an allocation in the alternate memory.