[XLA] Fix a scheduling bug with evictions to default mem.

When simplifying the graph for dead code, we were previously removing the
deleted instruction from the schedule. However, the scheduler, which is run
after SimplifyGraph, relies on the original logical time (index into the
instruction schedule). So, when some instructions have been deleted, we end up
scheduling certain operations later than intended. Most seriously, the evictions
could have been scheduled later than they were supposed to, corrupting the
memory since we might have reused the evicted memory. The solution is to mark
the deleted instructions with a nullptr in the schedule instead of actually
deleting them.

PiperOrigin-RevId: 283410438
Change-Id: Ia671ede14469dd00ac218cfb8d714e171152bb37
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 751d258..28c93fb 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -1130,7 +1130,15 @@
           // Ensure the exported preset assignments don't contain a refence to
           // the removed instruction.
           preset_assignments_->RemoveAssignmentForInstruction(instruction);
-          flattened_instruction_sequence_.remove_instruction(instruction);
+          // Instead of deleting the instruction from the schedule, replace it
+          // with a nullptr. This is needed because FixSchedule relies on the
+          // logical time that is the index into flattened_instructions_ for
+          // scheduling asynchronous copies.
+          auto instruction_it =
+              absl::c_find(flattened_instructions_, instruction);
+          if (instruction_it != flattened_instructions_.end()) {
+            *instruction_it = nullptr;
+          }
           TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
           computation_modified = true;
         } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
@@ -1228,12 +1236,12 @@
 
       // If the copy start doesn't happen to be scheduled at the correct
       // computation, delay it until the correct computation starts.
-      const auto& flattened_instructions =
-          flattened_instruction_sequence_.instructions();
       int64 copy_start_schedule_after =
           copy_allocation->copy_start_schedule_after();
+      // Accessing flattened_instructions_ here without checking if it is
+      // nullptr is safe because this method is called before SimplifyGraph.
       while (copy_allocation->instruction()->parent() !=
-             flattened_instructions[copy_start_schedule_after]->parent()) {
+             flattened_instructions_[copy_start_schedule_after]->parent()) {
         VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                 << (copy_start_schedule_after + 1) << ") for "
                 << copy_allocation->copy_start()->ToString()
@@ -1264,8 +1272,7 @@
     VLOG(4) << "Scheduling: " << computation->ToString();
 
     for (int64 instruction_index = 0;
-         instruction_index <
-         flattened_instruction_sequence_.instructions().size();
+         instruction_index < flattened_instructions_.size();
          ++instruction_index) {
       auto insts_before_iter = schedule_before_.find(instruction_index);
       if (insts_before_iter != schedule_before_.end()) {
@@ -1276,10 +1283,11 @@
           }
         }
       }
-      HloInstruction* instruction =
-          flattened_instruction_sequence_.instructions()[instruction_index];
-      // Insert only if not previously inserted.
-      if (!inserted_instructions.contains(instruction) &&
+      HloInstruction* instruction = flattened_instructions_[instruction_index];
+      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
+      // it was deleted) and not previously inserted.
+      if (instruction != nullptr &&
+          !inserted_instructions.contains(instruction) &&
           instruction->parent() == computation) {
         EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                              &inserted_instructions);
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index bfc9166..a8b3310 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -450,7 +450,8 @@
       absl::Span<HloInstruction* const> flattened_instructions)
       : module_(module),
         alternate_memory_space_(alternate_memory_space),
-        flattened_instruction_sequence_(flattened_instructions),
+        flattened_instructions_(flattened_instructions.begin(),
+                                flattened_instructions.end()),
         preset_assignments_(absl::make_unique<PresetAssignments>()) {}
 
   // Process calls Process methods of the allocations after the allocations have
@@ -479,7 +480,7 @@
 
   HloModule* module_;
   int64 alternate_memory_space_;
-  HloInstructionSequence flattened_instruction_sequence_;
+  std::vector<HloInstruction*> flattened_instructions_;
   AllocationMap allocation_map_;
   std::unique_ptr<PresetAssignments> preset_assignments_;
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 6041b96..7e3ce7d 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -2032,6 +2032,99 @@
   EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
 }
 
+TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
+  // This test reproduces an eviction scheduling bug where evictions to default
+  // memory can happen later than intended, causing memory corruption. This test
+  // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3]
+  // tensors instead, so at most two tensors should fit in the alternate memory
+  // space at a given time. We have a number of redundant operations
+  // (tanh_redundant ops) that do not have users. The bug was due to
+  // SimplifyGraph removing dead instructions, and removing them from the
+  // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
+  // indexes, so they could be inserted too late.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* tanh0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* tanh_redundant6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
+  HloInstruction* tanh1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* tanh2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* tanh3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({tanh3, negate3, tanh0}));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      computation,
+      {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
+       tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
+       negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpaceUsingCostAnalysis(module.get());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
+                          HloAliasAnalysis::Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
+                          HloLiveRange::Run(module->schedule(), *alias_analysis,
+                                            module->entry_computation()));
+
+  std::vector<int> num_live_buffers_in_alternate_mem(
+      hlo_live_range->flattened_instruction_sequence().size() + 1, 0);
+
+  // Go through each value and for those that are allocated in the alternate
+  // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for
+  // every time step that they are live.
+  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
+    const Shape& shape = value->shape();
+    if (!shape.has_layout() ||
+        shape.layout().memory_space() == kDefaultMemorySpace) {
+      continue;
+    }
+
+    HloLiveRange::TimeBound time_bound =
+        hlo_live_range->buffer_live_ranges().at(value);
+    for (int i = time_bound.start; i <= time_bound.end; ++i) {
+      ++num_live_buffers_in_alternate_mem[i];
+    }
+  }
+
+  // The test memory can at most hold two f32[4,3] buffers at a time. If there
+  // is more than that, it means we have memory corruption.
+  for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) {
+    EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2);
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                          MemorySpaceAssignmentTest,
                          ::testing::Values(false, true));