[XLA] Implement a more sophisticated prefetch interval ordering.

We previously used latest-to-earliest order in the prefetch interval
picker. This could cause prefetches to start unnecessarily late. With
this CL, a preferred prefetch interval is specified, and the prefetch
interval picker starts at this preferred interval and returns
alternating ascending and descending indices. This can help cases where
we spend a long time in copy-dones before a while loop (3% improvement
in b/161249728).
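
To illustrate the new ordering, below is a minimal standalone C++
sketch of the alternating iteration (not the picker itself: it omits
the while-nest-level filtering and the cost-threshold checks, and all
names are illustrative). With the values exercised by the new
PrefetchIntervalOrder test (earliest=9, latest=18, preferred=15), it
produces the same order the test asserts.

    // Minimal sketch, assuming a fixed [earliest, latest] window and a
    // preferred starting index. The real picker additionally skips
    // indices whose while-nest level differs from the end time's.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<int64_t> AlternatingOrder(int64_t earliest, int64_t latest,
                                          int64_t preferred) {
      std::vector<int64_t> order = {preferred};
      int64_t increasing = preferred + 1;  // walks toward `latest`
      int64_t decreasing = preferred - 1;  // walks toward `earliest`
      bool use_increasing = true;
      while (increasing <= latest || decreasing >= earliest) {
        // Alternate directions; fall back to the other iterator once
        // one of them runs past its bound.
        if ((use_increasing && increasing <= latest) ||
            decreasing < earliest) {
          order.push_back(increasing++);
        } else {
          order.push_back(decreasing--);
        }
        use_increasing = !use_increasing;
      }
      return order;
    }

    int main() {
      // Prints: 15 16 14 17 13 18 12 11 10 9
      for (int64_t t : AlternatingOrder(9, 18, 15)) std::cout << t << " ";
      std::cout << "\n";
    }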

PiperOrigin-RevId: 321477583
Change-Id: Icd88c306c5f2bf7cd55e693fc7d040d93f05b70d
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 6610035..874200e 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -199,6 +199,12 @@
 }
 
 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
+    const HloInstruction& instruction) const {
+  return std::max(GetInstructionElapsedDueToCompute(instruction),
+                  GetInstructionElapsedDueToMemory(instruction));
+}
+
+float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
     const HloInstruction& instruction,
     absl::optional<int64> operand_in_alternate_mem,
     bool output_in_alternate_mem) const {
@@ -258,12 +264,15 @@
 CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
     float min_async_copy_to_overlap_ratio,
-    float max_async_copy_to_overlap_ratio)
+    float max_async_copy_to_overlap_ratio,
+    float preferred_async_copy_to_overlap_ratio)
     : while_nest_level_(
           cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
       cost_analysis_(cost_analysis),
       min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
-      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
+      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
+      preferred_async_copy_to_overlap_ratio_(
+          preferred_async_copy_to_overlap_ratio) {
   instruction_schedule_ =
       &cost_analysis_.hlo_live_range().instruction_schedule();
 
@@ -281,7 +290,7 @@
         instruction->opcode() == HloOpcode::kConditional) {
       continue;
     }
-    float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
+    float elapsed_time = cost_analysis_.GetInstructionElapsed(
         *instruction_and_logical_time.first);
     int64 logical_time = instruction_and_logical_time.second;
     if (logical_time >= instructions_elapsed_time.size()) {
@@ -355,52 +364,107 @@
   async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
   // Estimate the time we would save by having this op in alternate memory.
   float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
-  float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed(
-      *use.instruction, use.operand_number);
+  float elapsed_time_in_alternate_mem =
+      cost_analysis_.GetInstructionElapsedInAlternateMemory(
+          *use.instruction, use.operand_number,
+          /*output_in_alternate_mem=*/false);
   inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
   end_logical_time_ = end_time;
-  earliest_start_logical_time_ = start_time;
-  int end_nest_level = while_nest_level_[end_time];
-  // Find the latest time we're allowed to start prefetching. If the start and
-  // end nest levels differe look for an earlier prefetch start.
-  for (current_logical_prefetch_time_ = end_time - 1;
-       current_logical_prefetch_time_ > start_time &&
-       (while_nest_level_[current_logical_prefetch_time_] != end_nest_level ||
-        min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
-            GetLogicalIntervalElapsed(current_logical_prefetch_time_,
-                                      end_logical_time_) +
-                inst_elapsed_reduction_);
-       --current_logical_prefetch_time_) {
+  int end_nest_level = while_nest_level_[end_logical_time_];
+
+  // Find the latest time we're allowed to start prefetching.
+  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
+  for (latest_prefetch_time_ = end_logical_time_ - 1;
+       latest_prefetch_time_ >= start_time &&
+       (while_nest_level_[latest_prefetch_time_] != end_nest_level ||
+        min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
+                                                 end_logical_time_) +
+                           inst_elapsed_reduction_);
+       --latest_prefetch_time_) {
   }
+
+  // Find the earliest time we're allowed to start prefetching.
+  float max_interval = max_async_copy_to_overlap_ratio_ *
+                       max_overlap_multiplier_ * async_copy_elapsed_;
+  for (earliest_prefetch_time_ = start_time;
+       earliest_prefetch_time_ <= end_logical_time_ &&
+       (while_nest_level_[earliest_prefetch_time_] != end_nest_level ||
+        max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
+                                                 end_logical_time_));
+       ++earliest_prefetch_time_) {
+  }
+  if (earliest_prefetch_time_ > latest_prefetch_time_) {
+    // There is no available prefetch interval for the given start and end
+    // times. Set the iterators accordingly to ensure Done() returns true.
+    increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
+    decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
+    CHECK(Done());
+    return;
+  }
+
+  // Between the earliest and latest prefetch interval, find the interval
+  // closest to the preferred interval and start iterating from there.
+  int64 starting_prefetch_time = earliest_prefetch_time_;
+  float preferred_interval =
+      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
+  float best_interval =
+      GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_);
+  for (int64 prefetch_time = earliest_prefetch_time_ + 1;
+       prefetch_time <= latest_prefetch_time_; ++prefetch_time) {
+    float interval =
+        GetLogicalIntervalElapsed(prefetch_time, end_logical_time_);
+    if (while_nest_level_[prefetch_time] == end_nest_level &&
+        std::abs(preferred_interval - interval) <
+            std::abs(preferred_interval - best_interval)) {
+      best_interval = interval;
+      starting_prefetch_time = prefetch_time;
+    }
+  }
+  VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
+          << max_interval << " " << preferred_interval
+          << " prefetch time earliest/latest/starting = "
+          << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
+          << starting_prefetch_time;
+
+  increasing_prefetch_time_iterator_ = starting_prefetch_time;
+  decreasing_prefetch_time_iterator_ = starting_prefetch_time;
+  using_increasing_prefetch_time_iterator_ = true;
+  // Since both iterators start at the same position, call Next() once to
+  // advance one of the iterators.
+  Next();
 }
 
 int64 CostAnalysisPrefetchIntervalPicker::Next() {
   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                     "Done() is true";
-  int64 prefetch_time = current_logical_prefetch_time_;
-  if (!Done()) {
-    --current_logical_prefetch_time_;
+  if (using_increasing_prefetch_time_iterator_) {
+    int64 prefetch_time = increasing_prefetch_time_iterator_++;
+    while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
+           while_nest_level_[increasing_prefetch_time_iterator_] !=
+               while_nest_level_[end_logical_time_]) {
+      ++increasing_prefetch_time_iterator_;
+    }
+    if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
+      using_increasing_prefetch_time_iterator_ = false;
+    }
+    return prefetch_time;
+  } else {
+    int64 prefetch_time = decreasing_prefetch_time_iterator_--;
+    while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
+           while_nest_level_[decreasing_prefetch_time_iterator_] !=
+               while_nest_level_[end_logical_time_]) {
+      --decreasing_prefetch_time_iterator_;
+    }
+    if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
+      using_increasing_prefetch_time_iterator_ = true;
+    }
+    return prefetch_time;
   }
-  // If the prefetch start and end times differ, look for an earlier prefetch
-  // start.
-  while (!Done() && while_nest_level_[current_logical_prefetch_time_] !=
-                        while_nest_level_[end_logical_time_]) {
-    --current_logical_prefetch_time_;
-  }
-  return prefetch_time;
 }
 
 bool CostAnalysisPrefetchIntervalPicker::Done() const {
-  if (current_logical_prefetch_time_ < earliest_start_logical_time_) {
-    return true;
-  }
-  float logical_interval_elapsed = GetLogicalIntervalElapsed(
-      current_logical_prefetch_time_, end_logical_time_);
-  return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
-              async_copy_elapsed_ <
-          logical_interval_elapsed) ||
-         (min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
-          logical_interval_elapsed + inst_elapsed_reduction_);
+  return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
+         decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
 }
 
 void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
@@ -440,13 +504,16 @@
 }
 
 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
+  int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
+                                          ? increasing_prefetch_time_iterator_
+                                          : decreasing_prefetch_time_iterator_;
   float logical_interval_elapsed = GetLogicalIntervalElapsed(
-      current_logical_prefetch_time_, end_logical_time_);
+      current_logical_prefetch_time, end_logical_time_);
   return absl::StrCat(
       "Async copy elapsed (s) = ", async_copy_elapsed_,
       ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
       ", logical interval elapsed (s) = ", logical_interval_elapsed,
-      ", interval = (", current_logical_prefetch_time_, ", ", end_logical_time_,
+      ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
       ")");
 }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 8f20020..d1b508a 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -84,6 +84,8 @@
     absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
   };
 
+  virtual ~MemorySpaceAssignmentCostAnalysis() = default;
+
   static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
       const HloCostAnalysis& cost_analysis,
       float async_copy_bandwidth_bytes_per_second,
@@ -128,16 +130,21 @@
 
   // Returns the estimated elapsed duration of the instruction in seconds.  It
   // assumes all operands and outputs of the instruction are in the default
+  // memory.
+  virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
+
+  // Returns the estimated elapsed duration of the instruction in seconds.  It
+  // assumes all operands and outputs of the instruction are in the default
   // memory, except for the operand number that is in the alternate memory, if
   // provided, or output if output_in_alternate_mem is true.
-  float GetInstructionElapsed(
+  virtual float GetInstructionElapsedInAlternateMemory(
       const HloInstruction& instruction,
-      absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
-      bool output_in_alternate_mem = false) const;
+      absl::optional<int64> operand_in_alternate_mem,
+      bool output_in_alternate_mem) const;
 
   // Returns the elapsed time it would take to asynchronously copy the shape
   // from default to alternate memory space (or vice versa).
-  float GetAsyncCopyElapsed(const Shape& shape) const;
+  virtual float GetAsyncCopyElapsed(const Shape& shape) const;
 
   int64 GetScheduleEndTime() const;
 
@@ -147,7 +154,7 @@
 
   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
 
- private:
+ protected:
   MemorySpaceAssignmentCostAnalysis(
       const HloCostAnalysis& cost_analysis,
       float async_copy_bandwidth_bytes_per_second,
@@ -164,6 +171,7 @@
         hlo_live_range_(std::move(hlo_live_range)),
         call_graph_(std::move(call_graph)) {}
 
+ private:
   const HloCostAnalysis& cost_analysis_;
   float async_copy_bandwidth_bytes_per_second_;
   float alternate_mem_bandwidth_bytes_per_second_;
@@ -267,16 +275,16 @@
 // Prefetch interval picker that uses cost analysis to overlap asynchronous
 // copies with independent computation. It uses min/max (asynchronous copy
 // duration) / (independent computation duration) ratios to guide whether the
-// prefetch is within those bounds. It starts with the maximum allowed ratio
-// (earliest prefetch) in Begin() and works its way for later and later prefetch
-// with each Next() call until hitting the minimum ratio, in order not to hurt
-// the critical path.
+// prefetch is within those bounds. It starts with the preferred ratio in
+// Begin() and alternates between earlier and later prefetches with each Next()
+// call until hitting the min and max ratios.
 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
  public:
   CostAnalysisPrefetchIntervalPicker(
       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
       float min_async_copy_to_overlap_ratio,
-      float max_async_copy_to_overlap_ratio);
+      float max_async_copy_to_overlap_ratio,
+      float preferred_async_copy_to_overlap_ratio);
 
   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                           int64 end_time) const override;
@@ -319,13 +327,17 @@
   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
   float min_async_copy_to_overlap_ratio_;
   float max_async_copy_to_overlap_ratio_;
+  float preferred_async_copy_to_overlap_ratio_;
   float max_overlap_multiplier_ = 1.0;
 
   float async_copy_elapsed_;
   float inst_elapsed_reduction_;
   int64 end_logical_time_;
-  int64 earliest_start_logical_time_;
-  int64 current_logical_prefetch_time_;
+  int64 earliest_prefetch_time_;
+  int64 latest_prefetch_time_;
+  bool using_increasing_prefetch_time_iterator_;
+  int64 increasing_prefetch_time_iterator_;
+  int64 decreasing_prefetch_time_iterator_;
 };
 
 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 10e11e5..a92b73c 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -60,7 +60,8 @@
     CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
         CostAnalysisPrefetchIntervalPicker(
             *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
-            /*max_async_copy_to_overlap_ratio=*/10.0));
+            /*max_async_copy_to_overlap_ratio=*/10.0,
+            /*preferred_async_copy_to_overlap_ratio=*/1.5));
     return AssignMemorySpace(
         module, /*max_outstanding_async_copies=*/-1,
         MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
@@ -4045,5 +4046,218 @@
   EXPECT_EQ(cross_program_prefetches.size(), 0);
 }
 
+// For testing purposes, we define a cost analysis where we can control the
+// elapsed times of each HLO and asynchronous copy.
+class FakeMemorySpaceAssignmentCostAnalysis
+    : public MemorySpaceAssignmentCostAnalysis {
+ public:
+  static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
+  Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
+    TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
+    TF_ASSIGN_OR_RETURN(auto hlo_live_range,
+                        HloLiveRange::Run(module.schedule(), *alias_analysis,
+                                          module.entry_computation()));
+    auto call_graph = CallGraph::Build(&module);
+    return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
+        cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
+        /*alternate_mem_bandwidth_bytes_per_second=*/1,
+        std::move(alias_analysis), std::move(hlo_live_range),
+        std::move(call_graph)));
+  }
+
+  float GetInstructionElapsed(
+      const HloInstruction& instruction) const override {
+    return 1.0;
+  }
+
+  float GetInstructionElapsedInAlternateMemory(
+      const HloInstruction& instruction,
+      absl::optional<int64> operand_in_alternate_mem,
+      bool output_in_alternate_mem) const override {
+    if (operand_in_alternate_mem) {
+      return 0.5;
+    } else {
+      return 1.0;
+    }
+  }
+
+  float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }
+
+ protected:
+  FakeMemorySpaceAssignmentCostAnalysis(
+      const HloCostAnalysis& cost_analysis,
+      float async_copy_bandwidth_bytes_per_second,
+      float alternate_mem_bandwidth_bytes_per_second,
+      std::unique_ptr<HloAliasAnalysis> alias_analysis,
+      std::unique_ptr<HloLiveRange> hlo_live_range,
+      std::unique_ptr<CallGraph> call_graph)
+      : MemorySpaceAssignmentCostAnalysis(
+            cost_analysis, async_copy_bandwidth_bytes_per_second,
+            alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
+            std::move(hlo_live_range), std::move(call_graph)) {}
+};
+
+using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
+
+TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
+  absl::string_view hlo_string = R"(
+  HloModule bug, is_scheduled=true
+
+  ENTRY Entry {
+    param0 = f32[2,4] parameter(0)
+    a = f32[2,4] negate(param0)
+    b = f32[2,4] negate(a)
+    c = f32[2,4] negate(b)
+    d = f32[2,4] negate(c)
+    e = f32[2,4] negate(d)
+    f = f32[2,4] negate(e)
+    g = f32[2,4] negate(f)
+    h = f32[2,4] negate(g)
+    i = f32[2,4] negate(h)
+    j = f32[2,4] negate(i)
+    k = f32[2,4] negate(j)
+    l = f32[2,4] negate(k)
+    m = f32[2,4] negate(l)
+    n = f32[2,4] negate(m)
+    o = f32[2,4] negate(n)
+    p = f32[2,4] negate(o)
+    q = f32[2,4] negate(p)
+    r = f32[2,4] negate(q)
+    s = f32[2,4] negate(r)
+    t = f32[2,4] negate(s)
+    u = f32[2,4] negate(t)
+    ROOT v = f32[2,4] add(u, param0)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  HloCostAnalysis hlo_cost_analysis(ShapeSize);
+  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
+                          FakeMemorySpaceAssignmentCostAnalysis::Create(
+                              hlo_cost_analysis, *module));
+  CostAnalysisPrefetchIntervalPicker interval_picker(
+      *cost_analysis,
+      /*min_async_copy_to_overlap_ratio=*/1.0,
+      /*max_async_copy_to_overlap_ratio=*/4.0,
+      /*preferred_async_copy_to_overlap_ratio=*/2.0);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
+  interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);
+
+  // Expect that the first interval is (15, 22), which has an elapsed time of
+  // 6.0, twice the async copy elapsed time (3.0). Then we expect the intervals
+  // to be visited in alternating increasing and decreasing order until hitting
+  // the min and max async copy overlap ratios, which correspond to the
+  // intervals (18, 22) and (9, 22), respectively.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 15);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 16);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 14);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 17);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 13);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 18);  // Min async overlap ratio reached.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 12);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 11);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 10);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 9);  // Max async overlap ratio reached.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_TRUE(interval_picker.Done());
+
+  // Expect that if the time between start_time and end_time is too short, there
+  // won't be any available intervals.
+  interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_TRUE(interval_picker.Done());
+}
+
+TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
+  absl::string_view hlo_string = R"(
+  HloModule bug, is_scheduled=true
+
+  while_condition {
+    param1 = (f32[2,4]) parameter(0)    // 19
+    ROOT cond = pred[] constant(true)   // 20
+  }
+
+  while_body {
+    param2 = (f32[2,4]) parameter(0)    // 21
+    gte2 = f32[2,4] get-tuple-element(param2), index=0  // 22
+    add = f32[2,4] add(gte2, gte2)      // 23
+    ROOT tuple2 = (f32[2,4]) tuple(add) // 24
+  }
+
+  ENTRY Entry {
+    param0 = f32[2,4] parameter(0)  // 0
+    a = f32[2,4] negate(param0)     // 1
+    b = f32[2,4] negate(a)          // 2
+    c = f32[2,4] negate(b)          // 3
+    d = f32[2,4] negate(c)          // 4
+    e = f32[2,4] negate(d)          // 5
+    f = f32[2,4] negate(e)          // 6
+    g = f32[2,4] negate(f)          // 7
+    h = f32[2,4] negate(g)          // 8
+    i = f32[2,4] negate(h)          // 9
+    j = f32[2,4] negate(i)          // 10
+    k = f32[2,4] negate(j)          // 11
+    l = f32[2,4] negate(k)          // 12
+    m = f32[2,4] negate(l)          // 13
+    n = f32[2,4] negate(m)          // 14
+    o = f32[2,4] negate(n)          // 15
+    p = f32[2,4] negate(o)          // 16
+    q = f32[2,4] negate(p)          // 17
+    tuple = (f32[2,4]) tuple(q)     // 18
+    while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body  // 25
+    gte1 = f32[2,4] get-tuple-element(while), index=0  // 26
+    r = f32[2,4] negate(gte1)       // 27
+    s = f32[2,4] negate(r)          // 28
+    t = f32[2,4] negate(s)          // 29
+    u = f32[2,4] negate(t)          // 30
+    ROOT v = f32[2,4] add(u, param0)  // 31
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  HloCostAnalysis hlo_cost_analysis(ShapeSize);
+  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
+                          FakeMemorySpaceAssignmentCostAnalysis::Create(
+                              hlo_cost_analysis, *module));
+  CostAnalysisPrefetchIntervalPicker interval_picker(
+      *cost_analysis,
+      /*min_async_copy_to_overlap_ratio=*/1.0,
+      /*max_async_copy_to_overlap_ratio=*/12.0,
+      /*preferred_async_copy_to_overlap_ratio=*/2.0);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
+  interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31);
+
+  // Because there are while loop computations in the logical time interval
+  // [19, 24], we ensure that the interval picker avoids this interval.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 25);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 26);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 18);
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 27);  // Min async overlap ratio reached.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_EQ(interval_picker.Next(), 17);  // Max async overlap ratio reached.
+  LOG(INFO) << interval_picker.ToDebugString();
+  EXPECT_TRUE(interval_picker.Done());
+}
+
 }  // namespace
 }  // namespace xla