[XLA] Memory space assignment improvements

- It's no longer necessary to set schedules in the heap simulator; memory space
  assignment can now get the flattened instruction sequence from the
  HloLiveRange object (see the first sketch below).
- MemorySpaceAssignment::FixSchedule now uses the flattened sequence. It
  previously used the per-computation sequence, which was wrong.
- Instead of inserting CopyStart/CopyDone ops and coloring at the same time, we
  now first insert all CopyStart/CopyDone ops and then color the necessary
  HLOs.
- Separating CopyStart/CopyDone insertion from coloring lets us re-run
  HloAliasAnalysis once the graph is finalized. We now use alias analysis to
  propagate the color to all HLOs in the same buffer, instead of special-casing
  bitcasts, tuples, etc. (see the coloring sketch below).
- Support for embedded computations: a CopyStart in one computation is not
  allowed to have its corresponding CopyDone in another (see the third sketch
  below).
- Rely on the defining position, rather than the producing instruction, to
  disambiguate previous allocations that might point to the same tensor. Using
  the producing instruction can be wrong when, e.g., two separate
  GetTupleElement instructions with the same index point to the same defining
  position but have different operand instructions.
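
A minimal sketch, not part of the patch, of how a heap algorithm can now look
up schedule information through HloLiveRange instead of SetSchedules(). The
helper names are hypothetical; the HloLiveRange calls mirror the ones used in
memory_space_assignment.cc:

  HloInstruction* InstructionAtTime(const HloLiveRange& hlo_live_range,
                                    int64 time) {
    // The flattened sequence indexes every scheduled instruction in the
    // module with a single global time step.
    return hlo_live_range.flattened_instruction_sequence()
        .instructions()[time];
  }

  int64 TimeOfInstruction(const HloLiveRange& hlo_live_range,
                          const HloInstruction* instruction) {
    // instruction_schedule() maps each instruction to its flattened time step.
    return hlo_live_range.instruction_schedule().at(instruction);
  }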
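
A minimal sketch, not part of the patch, of the new coloring step: once the
graph is finalized, alias analysis is re-run and the alternate memory space is
propagated to every position that aliases a colored defining position. The
function name is hypothetical; the loop mirrors the one added to
memory_space_assignment.cc:

  Status ColorAliasedPositions(HloModule* module,
                               const std::vector<HloPosition>& colored_positions,
                               int64 alternate_memory_space) {
    TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
    for (const HloPosition& defining_position : colored_positions) {
      for (const HloBuffer* buffer : alias_analysis->ComputeBuffersAt(
               defining_position.instruction, defining_position.index)) {
        for (const HloValue* value : buffer->values()) {
          for (const HloPosition& position : value->positions()) {
            // Color the exact subshape at this position; aliasing covers
            // bitcasts, tuples, etc. without special cases.
            Shape* shape = ShapeUtil::GetMutableSubshape(
                position.instruction->mutable_shape(), position.index);
            shape->mutable_layout()->set_memory_space(alternate_memory_space);
          }
        }
      }
    }
    return Status::OK();
  }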
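
A minimal sketch, not part of the patch, of the embedded-computation guard: an
asynchronous copy is only attempted when the defining position and the use are
in the same computation and the use does not call into a subcomputation, so a
CopyStart and its corresponding CopyDone always end up in the same computation.
The helper name is hypothetical; the condition mirrors the check added to
memory_space_assignment.cc:

  bool CanUseAsynchronousCopy(const HloPosition& defining_position,
                              const HloUse& use) {
    return defining_position.instruction->parent() ==
               use.instruction->parent() &&
           use.instruction->called_computations().empty();
  }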

PiperOrigin-RevId: 269930471
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 660a098..fc7f36e 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -151,9 +151,6 @@
 
   HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();
 
-  algorithm_->SetSchedules(&hlo_live_range->flattened_instruction_sequence(),
-                           &hlo_live_range->instruction_schedule());
-
   // Record the buffer define/free event for each time step. We free all
   // remaining buffers (entry parameter, etc) after the program has finished
   // running, so we set the size of to program_end_time + 1.
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 00a748f..d8f5996 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -257,22 +257,6 @@
   // Finish collects the buffer offset assignment results.  Free may only be
   // called once, after the Alloc and Free calls.
   virtual Result Finish() = 0;
-
-  // Heap algorithms can optionally make use of the instruction/computation
-  // schedule. These data structures are guaranteed to be valid while Finish()
-  // is being called.
-  virtual void SetSchedules(
-      const HloInstructionSequence* flattened_instruction_sequence,
-      const absl::flat_hash_map<const HloInstruction*, int64>*
-          instruction_schedule) {
-    flattened_instruction_sequence_ = flattened_instruction_sequence;
-    instruction_schedule_ = instruction_schedule;
-  }
-
- protected:
-  const HloInstructionSequence* flattened_instruction_sequence_;
-  const absl::flat_hash_map<const HloInstruction*, int64>*
-      instruction_schedule_;
 };
 
 // NoFragmentationStatsHeap computes the heap size assuming no fragmentation;
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 8c968ff..692b100 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -98,19 +98,19 @@
         break;
       }
       const HloValue* value = colocated_interval->buffer;
+      const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
       int64 definition_time =
-          instruction_schedule_->at(value->defining_instruction());
+          instruction_schedule.at(value->defining_instruction());
       // Sort the uses by the use time.
       std::vector<HloUse> uses = value->uses();
       absl::c_sort(uses, [&](HloUse use1, HloUse use2) {
-        return instruction_schedule_->at(use1.instruction) <
-               instruction_schedule_->at(use2.instruction);
+        return instruction_schedule.at(use1.instruction) <
+               instruction_schedule.at(use2.instruction);
       });
       // Iterate over the uses.
       for (HloUse use : uses) {
-        int64 use_time = instruction_schedule_->at(use.instruction);
-        int64 last_use_time =
-            instruction_schedule_->at(uses.back().instruction);
+        int64 use_time = instruction_schedule.at(use.instruction);
+        int64 last_use_time = instruction_schedule.at(uses.back().instruction);
 
         // Bitcasts don't define buffers and don't directly consume buffers.
         // Skip allocating buffers for bitcast uses. The uses that feed from
@@ -157,10 +157,6 @@
   return result_;
 }
 
-HloInstruction* AlternateMemoryBestFitHeap::GetInstructionAt(int64 time) const {
-  return flattened_instruction_sequence_->instructions()[time];
-}
-
 void AlternateMemoryBestFitHeap::CommitPendingChunks() {
   for (auto interval_and_chunk : pending_chunks_) {
     VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
@@ -209,11 +205,12 @@
 
   VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
           << start_time << ", " << end_time << ") last use = " << last_use_time
-          << ". Size = " << size
+          << " use = " << use.ToString() << ". Size = " << size
           << ", def pos = " << defining_position.ToString()
-          << ", operand = " << operand->ToString()
+          << ", operand = " << operand->ToShortString()
           << (non_bitcast_operand != operand
-                  ? ", non_bitcast_operand = " + non_bitcast_operand->ToString()
+                  ? ", non_bitcast_operand = " +
+                        non_bitcast_operand->ToShortString()
                   : "");
   CHECK_LE(start_time, end_time);
 
@@ -224,6 +221,14 @@
     return true;
   }
 
+  if (defining_position.instruction->parent() != use.instruction->parent() ||
+      !use.instruction->called_computations().empty()) {
+    VLOG(3) << "Use is in a different computation or calls a computation.";
+    // Fail because we do not allow asynchronous copies while in the bodies of
+    // other computations.
+    return false;
+  }
+
   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
   if (!allocations->empty()) {
     prev_allocation = allocations->back().get();
@@ -233,7 +238,7 @@
   // memory space.
   if (prev_allocation != nullptr &&
       prev_allocation->memory_space() == MemorySpace::kAlternate &&
-      prev_allocation->instruction() == non_bitcast_operand) {
+      prev_allocation->defining_position() == defining_position) {
     // If there was an allocation for this HloValue that was in the alternate
     // memory space, we also need to perform an eviction.
     // TODO(berkin): For now evictions happen relative to the most recent
@@ -273,15 +278,17 @@
         VLOG(3) << "Bailing: Could not evict " << use.ToString()
                 << " because we hit the limit of maximum asynchronous copies "
                 << "between "
-                << GetInstructionAt(prev_allocation->start_time())->ToString()
+                << hlo_live_range_.flattened_instruction_sequence()
+                       .instructions()[prev_allocation->start_time()]
+                       ->ToString()
                 << " and "
-                << GetInstructionAt(prev_allocation->end_time())->ToString();
+                << hlo_live_range_.flattened_instruction_sequence()
+                       .instructions()[prev_allocation->end_time()]
+                       ->ToString()
         return false;
       }
     }
   } else if (prev_allocation != nullptr &&
              prev_allocation->memory_space() == MemorySpace::kDefault &&
-             prev_allocation->instruction() == non_bitcast_operand) {
+             prev_allocation->defining_position() == defining_position) {
     // If the previous allocation was in the default memory space and was
     // defined by the same instruction, extend that.  Otherwise, create a new
     // allocation.
@@ -439,11 +446,11 @@
   //           --------------------------+-----------+------
   //
   // Because we allocate buffers greedily, Producer to Use1 segment first, and
-  // then Use1 to Use2 segment, it is possible to allocate the the first segment
-  // at an offset that is available for the first segment (e.g. offset 0) but
-  // not for the entire live range. This can result in unnecessary copies. By
-  // using the last use time, we try to find an allocation that is available for
-  // the entire Producer to Use2 range.
+  // then Use1 to Use2 segment, it is possible to allocate the first segment at
+  // an offset that is available for the first segment (e.g. offset 0) but not
+  // for the entire live range. This can result in unnecessary copies. By using
+  // the last use time, we try to find an allocation that is available for the
+  // entire Producer to Use2 range.
   alternate_mem_interval.end = last_use_time;
   ChunkCandidate chunk_candidate =
       FindChunkCandidate(alternate_mem_interval, preferred_offset);
@@ -462,7 +469,8 @@
     // If there was a previous allocation, the buffer location is the
     // same as the previous. Otherwise, it is the operand.
     if (prev_allocation != nullptr &&
-        prev_allocation->instruction() == non_bitcast_operand) {
+        (prev_allocation->is_copy_allocation() ||
+         prev_allocation->defining_position() == defining_position)) {
       prev_allocation->Extend(end_time);
     } else {
       allocations->push_back(
@@ -508,10 +516,15 @@
   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
 
   MemorySpaceAssignment memory_space_assignment(module, alternate_memory_space);
+  const HloComputation* entry_computation = module->entry_computation();
+  TF_ASSIGN_OR_RETURN(memory_space_assignment.hlo_live_range_,
+                      HloLiveRange::Run(module->schedule(), *alias_analysis,
+                                        entry_computation));
   // TODO(berkin): Explore heap algorithms other than kSpatial.
   auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
       &memory_space_assignment.allocation_map_, max_size_in_bytes,
       min_prefetch_interval, max_prefetch_interval, *alias_analysis,
+      *memory_space_assignment.hlo_live_range_,
       alternate_memory_space_alignment_in_bytes,
       GlobalDecreasingSizeBestFitHeap::Type::kSpatial,
       is_allowed_in_alternate_mem, max_outstanding_async_copies);
@@ -546,27 +559,14 @@
   }
 }
 
-Status MemorySpaceAssignment::Allocation::PropagateMemorySpaceToBitcasts(
-    const MemorySpaceAssignment& memory_space_assignment) {
-  for (HloInstruction* bitcast : bitcasts_) {
-    if (memory_space_ == MemorySpace::kAlternate) {
-      Layout* bitcast_layout = bitcast->mutable_shape()->mutable_layout();
-      bitcast_layout->set_memory_space(
-          memory_space_assignment.alternate_memory_space_);
-    }
-  }
-  return Status::OK();
-}
-
 Status MemorySpaceAssignment::Allocation::Process(
     MemorySpaceAssignment* memory_space_assignment) {
   // For non-copy allocations, all we need to do is to update the output memory
   // space if placed in the alternate memory.
   if (memory_space_ == MemorySpace::kAlternate) {
-    Layout* layout = instruction_->mutable_shape()->mutable_layout();
-    layout->set_memory_space(memory_space_assignment->alternate_memory_space_);
+    memory_space_assignment->AddPositionInAlternateMemorySpace(
+        defining_position_);
   }
-  TF_RETURN_IF_ERROR(PropagateMemorySpaceToBitcasts(*memory_space_assignment));
   return Status::OK();
 }
 
@@ -579,14 +579,6 @@
   Shape shape = producing_instruction->shape();
   HloComputation* computation = producing_instruction->parent();
 
-  // Set the layout to include the memory space.
-  Layout* layout = shape.mutable_layout();
-  if (memory_space_ == MemorySpace::kAlternate) {
-    layout->set_memory_space(memory_space_assignment->alternate_memory_space_);
-  } else {
-    layout->set_memory_space(0);
-  }
-
   copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
       HloOpcode::kCopyStart, producing_instruction));
@@ -626,11 +618,12 @@
   // ReplaceOperandWithDifferentShape.
   for (HloInstruction* bitcast : bitcasts_) {
     TF_RETURN_IF_ERROR(bitcast->ReplaceOperandWithDifferentShape(
-        /*operand_num=*/0, instruction_));
+        /*operand_num=*/0, copy_done_));
   }
 
-  // Propagate the memory space to all bitcasts.
-  TF_RETURN_IF_ERROR(PropagateMemorySpaceToBitcasts(*memory_space_assignment));
+  if (memory_space_ == MemorySpace::kAlternate) {
+    memory_space_assignment->AddPositionInAlternateMemorySpace({copy_done_});
+  }
 
   return Status::OK();
 }
@@ -670,6 +663,25 @@
       VLOG(3) << "  space: " << pair.first << ", size: " << pair.second;
     }
   }
+
+  // Color the pending positions and all of their aliased buffers.
+  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
+  for (HloPosition defining_position : pending_positions_in_alternate_mem_) {
+    for (auto& buffer : alias_analysis->ComputeBuffersAt(
+             defining_position.instruction, defining_position.index)) {
+      for (auto& value : buffer->values()) {
+        for (auto& position : value->positions()) {
+          VLOG(3) << "Coloring " << position.ToString();
+          Shape* shape = ShapeUtil::GetMutableSubshape(
+              position.instruction->mutable_shape(), position.index);
+          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
+                                  << position.ToString();
+          shape->mutable_layout()->set_memory_space(alternate_memory_space_);
+        }
+      }
+    }
+  }
+
   return Status::OK();
 }
 
@@ -683,11 +695,16 @@
     EnsureInstructionAndOperandsInserted(operand, new_sequence,
                                          inserted_instructions);
   }
-  VLOG(4) << "inserting: " << new_instruction->ToString();
+  VLOG(4) << "inserting: " << new_instruction->ToShortString();
   new_sequence->push_back(new_instruction);
   inserted_instructions->insert(new_instruction);
 }
 
+void MemorySpaceAssignment::AddPositionInAlternateMemorySpace(
+    HloPosition position) {
+  pending_positions_in_alternate_mem_.push_back(position);
+}
+
 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
   // For asynchronous copies of both directions (default to alternate and vice
   // versa), sort them by their completion time. Then, if in the sorted order we
@@ -746,6 +763,23 @@
         copy_allocation->set_copy_start_schedule_after(
             prev_copy_allocation->copy_start_schedule_after());
       }
+
+      // If the copy start doesn't happen to be scheduled in the correct
+      // computation, delay it until that computation starts.
+      const auto& flattened_instructions =
+          hlo_live_range_->flattened_instruction_sequence().instructions();
+      int64 copy_start_schedule_after =
+          copy_allocation->copy_start_schedule_after();
+      while (copy_allocation->instruction()->parent() !=
+             flattened_instructions[copy_start_schedule_after]->parent()) {
+        VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
+                << (copy_start_schedule_after + 1) << ") for "
+                << copy_allocation->copy_start()->ToString()
+                << " because it is not in the correct computation.";
+        copy_allocation->set_copy_start_schedule_after(
+            ++copy_start_schedule_after);
+      }
+
       schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
           copy_allocation->copy_start());
       schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
@@ -761,15 +795,30 @@
   for (const HloComputation* computation :
        module_->MakeNonfusionComputations()) {
     CHECK(schedule.is_computation_scheduled(computation));
-    const HloInstructionSequence& sequence = schedule.sequence(computation);
+    const HloInstructionSequence& sequence =
+        hlo_live_range_->flattened_instruction_sequence();
     HloInstructionSequence new_sequence;
 
     absl::flat_hash_set<HloInstruction*> inserted_instructions;
 
+    // Schedule the computations only if needed (if there are unscheduled
+    // instructions in the computation).
+    if (computation->instruction_count() ==
+        schedule.sequence(computation).size()) {
+      VLOG(4) << "Skip scheduling " << computation->name()
+              << " because it is already scheduled.";
+      continue;
+    }
+
+    VLOG(4) << "Scheduling: " << computation->ToString();
+
     for (int64 instruction_index = 0;
          instruction_index < sequence.instructions().size();
          ++instruction_index) {
       HloInstruction* instruction = sequence.instructions()[instruction_index];
+      if (instruction->parent() != computation) {
+        continue;
+      }
       auto insts_before_iter = schedule_before_.find(instruction_index);
       if (insts_before_iter != schedule_before_.end()) {
         for (HloInstruction* new_instruction : insts_before_iter->second) {
@@ -790,6 +839,10 @@
         }
       }
     }
+    CHECK_EQ(new_sequence.size(), computation->instruction_count())
+        << "New sequence for computation " << computation->name() << " has "
+        << new_sequence.size() << " instructions, expects "
+        << computation->instruction_count() << ".";
     schedule.set_sequence(computation, new_sequence);
   }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 1f8ee0e..f4c0a43 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -139,11 +139,6 @@
     int64 end_time() const { return end_time_; }
 
    protected:
-    // Bitcasts are treated specially because they do not define buffers.  This
-    // method propagates the memory space for the bitcasts of this allocation.
-    Status PropagateMemorySpaceToBitcasts(
-        const MemorySpaceAssignment& memory_space_assignment);
-
     HloInstruction* instruction_;
     HloPosition defining_position_;
     std::vector<HloUse> uses_;
@@ -160,7 +155,7 @@
     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                    Chunk chunk, int64 start_time, int64 end_time)
         : Allocation(/*instruction=*/nullptr,
-                     /*defining_position=*/{nullptr, {}}, memory_space, chunk,
+                     prev_allocation.defining_position(), memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
           copy_start_schedule_after_(start_time),
@@ -259,8 +254,13 @@
   // corresponding CopyDones follow the same order.
   void ScheduleAsynchronousCopies();
 
+  // Add the position to the pending positions that will be colored as alternate
+  // memory.
+  void AddPositionInAlternateMemorySpace(HloPosition position);
+
   HloModule* module_;
   int64 alternate_memory_space_;
+  std::unique_ptr<HloLiveRange> hlo_live_range_;
   AllocationMap allocation_map_;
   std::unique_ptr<PresetAssignments> preset_assignments_;
 
@@ -269,6 +269,7 @@
   // to modify and fix the schedule.
   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
+  std::vector<HloPosition> pending_positions_in_alternate_mem_;
 };
 
 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
@@ -283,7 +284,8 @@
       MemorySpaceAssignment::AllocationMap* allocation_map,
       int64 max_size_in_bytes, int64 min_prefetch_interval,
       int64 max_prefetch_interval, const HloAliasAnalysis& alias_analysis,
-      int64 alignment, GlobalDecreasingSizeBestFitHeap::Type type,
+      const HloLiveRange& hlo_live_range, int64 alignment,
+      GlobalDecreasingSizeBestFitHeap::Type type,
       IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem,
       int64 max_outstanding_async_copies)
       : GlobalDecreasingSizeBestFitHeap(alignment, type),
@@ -292,6 +294,7 @@
         min_prefetch_interval_(min_prefetch_interval),
         max_prefetch_interval_(max_prefetch_interval),
         alias_analysis_(alias_analysis),
+        hlo_live_range_(hlo_live_range),
         is_allowed_in_alternate_mem_(is_allowed_in_alternate_mem),
         max_outstanding_async_copies_(max_outstanding_async_copies) {}
 
@@ -317,10 +320,6 @@
       HloInstruction* non_bitcast_operand,
       MemorySpaceAssignment::AllocationSequence* allocations);
 
-  // Returns the instruction at a particular time in the flattened instruction
-  // schedule.
-  HloInstruction* GetInstructionAt(int64 time) const;
-
   // Given a buffer interval, returns the colocated intervals. Unlike the
   // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
   // returns the colocated intervals sorted by scheduled time.
@@ -366,6 +365,7 @@
   int64 min_prefetch_interval_;
   int64 max_prefetch_interval_;
   const HloAliasAnalysis& alias_analysis_;
+  const HloLiveRange& hlo_live_range_;
   IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_;
   // We use a interval tree to keep track of the number of outstanding
   // asynchronous copies.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 2c30e29..bd1a4fb 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -589,7 +589,7 @@
   //              mul1===>add2
   //
   // Without the last use optimization, the mul1 buffer will be assigned first
-  // (becase it is larger) to offset 0. Then, add1 will be scheduled for the
+  // (because it is larger) to offset 0. Then, add1 will be scheduled for the
   // add1 to sub1 segment. Because offset 0 is available, it will get that
   // offset. But because offset 0 is not available in the sub1 to mul2 offset,
   // it will end up in unnecessary copies. With the last use optimization, these
@@ -751,5 +751,422 @@
   }
 }
 
+TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
+  // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_limit = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
+  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
+                                    cond_limit, ComparisonDirection::kLt));
+  HloComputation* cond_computation =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloInstruction* body_iter = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
+  HloInstruction* body_data = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
+  HloInstruction* body_iter_increment = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
+  HloInstruction* body_iter_next =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
+  HloInstruction* body_data_increment =
+      body_builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
+  HloInstruction* body_data_mul =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kMultiply, body_data, body_data));
+  HloInstruction* body_data_add =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, body_data, body_data_increment));
+  HloInstruction* body_data_next =
+      body_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
+  HloInstruction* body_out = body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
+  HloComputation* body_computation =
+      module->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* data = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param_iter"));
+  HloInstruction* iter = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
+  HloInstruction* p2 =
+      builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2"));
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
+  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, cond_computation, body_computation, tuple));
+  HloInstruction* while_data = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(cond_computation,
+                        {cond_param, cond_iter, cond_limit, cond_lt});
+  schedule.set_sequence(body_computation,
+                        {body_param, body_iter, body_data, body_iter_increment,
+                         body_iter_next, body_data_increment, body_data_mul,
+                         body_data_add, body_data_next, body_out});
+  schedule.set_sequence(entry_computation,
+                        {iter, data, p2, tuple, while_op, while_data, add});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get(), -1, 50);
+}
+
+TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
+
+  auto call_builder = HloComputation::Builder("Call");
+  HloInstruction* call_param = call_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "call_param"));
+  HloInstruction* call_param2 = call_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape2, "call_param2"));
+  HloInstruction* slice = call_builder.AddInstruction(
+      HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
+  HloInstruction* mul =
+      call_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kMultiply, call_param, slice));
+  HloInstruction* negate0 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
+  HloInstruction* negate1 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* negate7 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+  HloInstruction* add0 =
+      call_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, call_param, negate7));
+  HloComputation* call_computation =
+      module->AddEmbeddedComputation(call_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* p1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
+  HloInstruction* negate8 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
+  HloInstruction* call = builder.AddInstruction(
+      HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
+  HloInstruction* add3 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
+  HloInstruction* add4 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
+  HloInstruction* add5 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      call_computation,
+      {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
+       negate4, negate5, negate6, negate7, add0});
+  schedule.set_sequence(entry_computation,
+                        {p0, p1, add1, add2, negate8, call, add3, add4, add5});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get(), -1, 5);
+}
+
+TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
+
+  auto call_builder = HloComputation::Builder("Call");
+  HloInstruction* call_param = call_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "call_param"));
+  // Use shape2 here which is larger (scheduled earlier) to occupy alternate
+  // memory at the beginning. This should cause a situation where the prefetch
+  // of add1 later in the function body gets the wrong offset which cannot be
+  // communicated to the outside of the function.
+  HloInstruction* iota =
+      call_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
+  HloInstruction* slice = call_builder.AddInstruction(
+      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
+  HloInstruction* mul =
+      call_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kMultiply, call_param, slice));
+  HloInstruction* negate0 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
+  HloInstruction* negate1 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* negate7 = call_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+  HloInstruction* add0 =
+      call_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, call_param, negate7));
+  HloComputation* call_computation =
+      module->AddEmbeddedComputation(call_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
+  HloInstruction* call = builder.AddInstruction(
+      HloInstruction::CreateCall(shape, {add1}, call_computation));
+  HloInstruction* add3 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add1));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      call_computation,
+      {call_param, iota, slice, mul, negate0, negate1, negate2, negate3,
+       negate4, negate5, negate6, negate7, add0});
+  schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get(), -1, 5);
+}
+
+TEST_F(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) {
+  auto module = CreateNewVerifiedModule();
+  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
+  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
+
+  auto true_builder = HloComputation::Builder("True");
+  HloInstruction* true_param = true_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "true_param"));
+  HloInstruction* iota =
+      true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
+  HloInstruction* slice = true_builder.AddInstruction(
+      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
+  HloInstruction* mul =
+      true_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kMultiply, true_param, slice));
+  HloInstruction* negate0 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
+  HloInstruction* negate1 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* negate7 = true_builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+  HloInstruction* add0 =
+      true_builder.AddInstruction(HloInstruction::CreateBinary(
+          shape, HloOpcode::kAdd, true_param, negate7));
+  HloComputation* true_computation =
+      module->AddEmbeddedComputation(true_builder.Build());
+
+  auto false_builder = HloComputation::Builder("False");
+  HloInstruction* false_param = false_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "false_param"));
+  HloComputation* false_computation =
+      module->AddEmbeddedComputation(false_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
+  HloInstruction* pred = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  HloInstruction* conditional =
+      builder.AddInstruction(HloInstruction::CreateConditional(
+          shape, pred, add1, true_computation, add2, false_computation));
+  HloInstruction* add3 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      true_computation,
+      {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
+       negate4, negate5, negate6, negate7, add0});
+  schedule.set_sequence(false_computation, {false_param});
+  schedule.set_sequence(entry_computation,
+                        {p0, add1, add2, pred, conditional, add3});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get(), -1, 5);
+}
+
+TEST_F(MemorySpaceAssignmentTest, DanglingCopy) {
+  // This situation was encountered in vss, where there is a mismatch in the
+  // memory space between the preset assignments and the output graph.
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+
+  HloInstruction* p = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p"));
+  HloInstruction* p0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 0));
+  HloInstruction* p1a = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 1));
+  HloInstruction* copy = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, p1a));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+  HloInstruction* negate2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+  HloInstruction* negate3 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+  HloInstruction* negate4 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+  HloInstruction* negate5 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+  HloInstruction* negate6 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+  HloInstruction* p1b = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, p, 1));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1b));
+
+  auto module = CreateNewVerifiedModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(
+      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
+                    negate6, p1a, copy, p1b, add});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get());
+}
+
+TEST_F(MemorySpaceAssignmentTest, MultiOutputFusion) {
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+  auto module = CreateNewVerifiedModule();
+
+  HloComputation::Builder fusion_builder("fusion");
+  HloInstruction* fusion_param0 = fusion_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* fusion_param1 = fusion_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape, "p1"));
+  fusion_builder.AddInstruction(
+      HloInstruction::CreateTuple({fusion_param0, fusion_param1}));
+  HloComputation* fusion_computation =
+      module->AddEmbeddedComputation(fusion_builder.Build());
+
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
+      fusion_computation));
+  HloInstruction* element0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, fusion, 0));
+  HloInstruction* element1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, fusion, 1));
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
+
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation, {p0, fusion, element0, element1, add});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get());
+}
+
+TEST_F(MemorySpaceAssignmentTest, TupleInput) {
+  HloComputation::Builder builder(TestName());
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+  auto module = CreateNewVerifiedModule();
+
+  HloComputation::Builder fusion_builder("fusion");
+  HloInstruction* fusion_param = fusion_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p"));
+  HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
+  HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
+  fusion_builder.AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
+  HloComputation* fusion_computation =
+      module->AddEmbeddedComputation(fusion_builder.Build());
+
+  HloInstruction* p0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+  HloInstruction* p1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
+  HloInstruction* negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+  HloInstruction* negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
+  HloInstruction* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
+
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation, {p0, p1, negate0, negate1, tuple, fusion});
+  TF_CHECK_OK(module->set_schedule(schedule));
+
+  AssignMemorySpace(module.get());
+}
+
 }  // namespace
 }  // namespace xla