[XLA] Implement a more sophisticated prefetch interval ordering.
We previously used latest-to-earliest order in the prefetch interval picker.
This can result in prefetches starting unnecessarily late. With this CL, a
preferred prefetch interval is specified, and the prefetch interval picker
starts at this preferred interval and returns alternating ascending and
descending indices. This can help cases where we spend a long time in
copy-dones before a while loop (3% improvement in b/161249728).
PiperOrigin-RevId: 321477583
Change-Id: Icd88c306c5f2bf7cd55e693fc7d040d93f05b70d
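
The core of the change is the order in which Next() returns candidate
prefetch start times. Below is a minimal standalone sketch of that ordering
(illustrative names only; it ignores the while-nest-level filtering and the
cost-model bounds that the real picker applies):

#include <cstdint>
#include <iostream>
#include <vector>

// Starting from the preferred index, emit indices in alternating
// ascending/descending order, falling back to a single direction once the
// other iterator leaves the [earliest, latest] bounds.
std::vector<int64_t> AlternatingOrder(int64_t earliest, int64_t latest,
                                      int64_t preferred) {
  std::vector<int64_t> order;
  int64_t up = preferred + 1;  // next candidate of the ascending pass
  int64_t down = preferred;    // next candidate of the descending pass
  bool use_up = false;         // the first pick is `preferred` itself
  while (up <= latest || down >= earliest) {
    if ((use_up && up <= latest) || down < earliest) {
      order.push_back(up++);
    } else {
      order.push_back(down--);
    }
    use_up = !use_up;
  }
  return order;
}

int main() {
  // With the bounds from the PrefetchIntervalOrder test added below
  // (earliest 9, latest 18, preferred 15), this prints the sequence the
  // test asserts: 15 16 14 17 13 18 12 11 10 9.
  for (int64_t t : AlternatingOrder(/*earliest=*/9, /*latest=*/18,
                                    /*preferred=*/15)) {
    std::cout << t << " ";
  }
  std::cout << "\n";
  return 0;
}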
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 6610035..874200e 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -199,6 +199,12 @@
}
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
+ const HloInstruction& instruction) const {
+ return std::max(GetInstructionElapsedDueToCompute(instruction),
+ GetInstructionElapsedDueToMemory(instruction));
+}
+
+float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
const HloInstruction& instruction,
absl::optional<int64> operand_in_alternate_mem,
bool output_in_alternate_mem) const {
@@ -258,12 +264,15 @@
CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
float min_async_copy_to_overlap_ratio,
- float max_async_copy_to_overlap_ratio)
+ float max_async_copy_to_overlap_ratio,
+ float preferred_async_copy_to_overlap_ratio)
: while_nest_level_(
cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
cost_analysis_(cost_analysis),
min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
- max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
+ max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
+ preferred_async_copy_to_overlap_ratio_(
+ preferred_async_copy_to_overlap_ratio) {
instruction_schedule_ =
&cost_analysis_.hlo_live_range().instruction_schedule();
@@ -281,7 +290,7 @@
instruction->opcode() == HloOpcode::kConditional) {
continue;
}
- float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
+ float elapsed_time = cost_analysis_.GetInstructionElapsed(
*instruction_and_logical_time.first);
int64 logical_time = instruction_and_logical_time.second;
if (logical_time >= instructions_elapsed_time.size()) {
@@ -355,52 +364,107 @@
async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
// Estimate the time we would save by having this op in alternate memory.
float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
- float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed(
- *use.instruction, use.operand_number);
+ float elapsed_time_in_alternate_mem =
+ cost_analysis_.GetInstructionElapsedInAlternateMemory(
+ *use.instruction, use.operand_number,
+ /*output_in_alternate_mem=*/false);
inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
end_logical_time_ = end_time;
- earliest_start_logical_time_ = start_time;
- int end_nest_level = while_nest_level_[end_time];
- // Find the latest time we're allowed to start prefetching. If the start and
- // end nest levels differe look for an earlier prefetch start.
- for (current_logical_prefetch_time_ = end_time - 1;
- current_logical_prefetch_time_ > start_time &&
- (while_nest_level_[current_logical_prefetch_time_] != end_nest_level ||
- min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
- GetLogicalIntervalElapsed(current_logical_prefetch_time_,
- end_logical_time_) +
- inst_elapsed_reduction_);
- --current_logical_prefetch_time_) {
+ int end_nest_level = while_nest_level_[end_logical_time_];
+
+ // Find the latest time we're allowed to start prefetching.
+ float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
+ for (latest_prefetch_time_ = end_logical_time_ - 1;
+ latest_prefetch_time_ >= start_time &&
+ (while_nest_level_[latest_prefetch_time_] != end_nest_level ||
+ min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
+ end_logical_time_) +
+ inst_elapsed_reduction_);
+ --latest_prefetch_time_) {
}
+
+ // Find the earliest time we're allowed to start prefetching.
+ float max_interval = max_async_copy_to_overlap_ratio_ *
+ max_overlap_multiplier_ * async_copy_elapsed_;
+ for (earliest_prefetch_time_ = start_time;
+ earliest_prefetch_time_ <= end_logical_time_ &&
+ (while_nest_level_[earliest_prefetch_time_] != end_nest_level ||
+ max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
+ end_logical_time_));
+ ++earliest_prefetch_time_) {
+ }
+ if (earliest_prefetch_time_ > latest_prefetch_time_) {
+ // There is no available prefetch interval for the given start and end
+ // times. Set the iterators accordingly to ensure Done() returns true.
+ increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
+ decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
+ CHECK(Done());
+ return;
+ }
+
+ // Between the earliest and latest prefetch interval, find the interval
+ // closest to the preferred interval and start iterating from there.
+ int64 starting_prefetch_time = earliest_prefetch_time_;
+ float preferred_interval =
+ preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
+ float best_interval =
+ GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_);
+ for (int64 prefetch_time = earliest_prefetch_time_ + 1;
+ prefetch_time <= latest_prefetch_time_; ++prefetch_time) {
+ float interval =
+ GetLogicalIntervalElapsed(prefetch_time, end_logical_time_);
+ if (while_nest_level_[prefetch_time] == end_nest_level &&
+ std::abs(preferred_interval - interval) <
+ std::abs(preferred_interval - best_interval)) {
+ best_interval = interval;
+ starting_prefetch_time = prefetch_time;
+ }
+ }
+ VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
+ << max_interval << " " << preferred_interval
+ << " prefetch time earliest/latest/starting = "
+ << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
+ << starting_prefetch_time;
+
+ increasing_prefetch_time_iterator_ = starting_prefetch_time;
+ decreasing_prefetch_time_iterator_ = starting_prefetch_time;
+ using_increasing_prefetch_time_iterator_ = true;
+ // Since both iterators start at the same position, call Next() once to
+ // advance one of the iterators.
+ Next();
}
int64 CostAnalysisPrefetchIntervalPicker::Next() {
CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                  "Done() is true";
- int64 prefetch_time = current_logical_prefetch_time_;
- if (!Done()) {
- --current_logical_prefetch_time_;
+ if (using_increasing_prefetch_time_iterator_) {
+ int64 prefetch_time = increasing_prefetch_time_iterator_++;
+ while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
+ while_nest_level_[increasing_prefetch_time_iterator_] !=
+ while_nest_level_[end_logical_time_]) {
+ ++increasing_prefetch_time_iterator_;
+ }
+ if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
+ using_increasing_prefetch_time_iterator_ = false;
+ }
+ return prefetch_time;
+ } else {
+ int64 prefetch_time = decreasing_prefetch_time_iterator_--;
+ while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
+ while_nest_level_[decreasing_prefetch_time_iterator_] !=
+ while_nest_level_[end_logical_time_]) {
+ --decreasing_prefetch_time_iterator_;
+ }
+ if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
+ using_increasing_prefetch_time_iterator_ = true;
+ }
+ return prefetch_time;
}
- // If the prefetch start and end times differ, look for an earlier prefetch
- // start.
- while (!Done() && while_nest_level_[current_logical_prefetch_time_] !=
- while_nest_level_[end_logical_time_]) {
- --current_logical_prefetch_time_;
- }
- return prefetch_time;
}
bool CostAnalysisPrefetchIntervalPicker::Done() const {
- if (current_logical_prefetch_time_ < earliest_start_logical_time_) {
- return true;
- }
- float logical_interval_elapsed = GetLogicalIntervalElapsed(
- current_logical_prefetch_time_, end_logical_time_);
- return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
- async_copy_elapsed_ <
- logical_interval_elapsed) ||
- (min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
- logical_interval_elapsed + inst_elapsed_reduction_);
+ return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
+ decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
}
void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
@@ -440,13 +504,16 @@
}
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
+ int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
+ ? increasing_prefetch_time_iterator_
+ : decreasing_prefetch_time_iterator_;
float logical_interval_elapsed = GetLogicalIntervalElapsed(
- current_logical_prefetch_time_, end_logical_time_);
+ current_logical_prefetch_time, end_logical_time_);
return absl::StrCat(
"Async copy elapsed (s) = ", async_copy_elapsed_,
", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
", logical interval elapsed (s) = ", logical_interval_elapsed,
- ", interval = (", current_logical_prefetch_time_, ", ", end_logical_time_,
+ ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
")");
}
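
Concretely, Begin() now derives three anchor points: the latest start that
still overlaps at least min_async_copy_to_overlap_ratio worth of
computation, the earliest start that does not exceed
max_async_copy_to_overlap_ratio, and the starting point whose interval is
closest to the preferred ratio. Here is a runnable sketch of that
arithmetic under the flat cost model used by the new tests below (1.0 s per
instruction, 3.0 s per async copy); the real code uses cumulative
elapsed-time sums and also skips times at a mismatched while-nest level:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Numbers from the PrefetchIntervalOrder test: a flat 1.0 s per
  // instruction, a 3.0 s async copy, and a 0.5 s saving when the operand
  // is already in alternate memory (1.0 default - 0.5 alternate).
  const float async_copy_elapsed = 3.0f;
  const float inst_elapsed_reduction = 0.5f;
  const int64_t end_time = 22;
  const float min_ratio = 1.0f, max_ratio = 4.0f, preferred_ratio = 2.0f;

  // Under the flat model, the elapsed time of the logical interval
  // (t, end_time) is the number of instructions strictly between them:
  // end_time - t - 1.
  auto elapsed = [&](int64_t t) { return float(end_time - t - 1); };

  // Latest allowed start: keep at least min_ratio worth of overlap.
  int64_t latest = end_time - 1;
  while (elapsed(latest) + inst_elapsed_reduction <
         min_ratio * async_copy_elapsed) {
    --latest;
  }

  // Earliest allowed start: do not exceed max_ratio worth of overlap.
  int64_t earliest = 0;
  while (elapsed(earliest) > max_ratio * async_copy_elapsed) {
    ++earliest;
  }

  // Starting point: the interval closest to the preferred ratio.
  const float preferred_interval = preferred_ratio * async_copy_elapsed;
  int64_t start = earliest;
  for (int64_t t = earliest + 1; t <= latest; ++t) {
    if (std::abs(preferred_interval - elapsed(t)) <
        std::abs(preferred_interval - elapsed(start))) {
      start = t;
    }
  }

  // Prints "earliest=9 latest=18 start=15", matching the unit test.
  std::cout << "earliest=" << earliest << " latest=" << latest
            << " start=" << start << "\n";
  return 0;
}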
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 8f20020..d1b508a 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -84,6 +84,8 @@
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
};
+ virtual ~MemorySpaceAssignmentCostAnalysis() = default;
+
static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
@@ -128,16 +130,21 @@
// Returns the estimated elapsed duration of the instruction in seconds. It
// assumes all operands and outputs of the instruction are in the default
+ // memory.
+ virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
+
+ // Returns the estimated elapsed duration of the instruction in seconds. It
+ // assumes all operands and outputs of the instruction are in the default
// memory, except for the operand number that is in the alternate memory, if
// provided, or output if output_in_alternate_mem is true.
- float GetInstructionElapsed(
+ virtual float GetInstructionElapsedInAlternateMemory(
const HloInstruction& instruction,
- absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
- bool output_in_alternate_mem = false) const;
+ absl::optional<int64> operand_in_alternate_mem,
+ bool output_in_alternate_mem) const;
// Returns the elapsed time it would take to asynchronously copy the shape
// from default to alternate memory space (or vice versa).
- float GetAsyncCopyElapsed(const Shape& shape) const;
+ virtual float GetAsyncCopyElapsed(const Shape& shape) const;
int64 GetScheduleEndTime() const;
@@ -147,7 +154,7 @@
const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
- private:
+ protected:
MemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
@@ -164,6 +171,7 @@
hlo_live_range_(std::move(hlo_live_range)),
call_graph_(std::move(call_graph)) {}
+ private:
const HloCostAnalysis& cost_analysis_;
float async_copy_bandwidth_bytes_per_second_;
float alternate_mem_bandwidth_bytes_per_second_;
@@ -267,16 +275,16 @@
// Prefetch interval picker that uses cost analysis to overlap asynchronous
// copies with independent computation. It uses min/max (asynchronous copy
// duration) / (independent computation duration) ratios to guide whether the
-// prefetch is within those bounds. It starts with the maximum allowed ratio
-// (earliest prefetch) in Begin() and works its way for later and later prefetch
-// with each Next() call until hitting the minimum ratio, in order not to hurt
-// the critical path.
+// prefetch is within those bounds. It starts with the preferred ratio in
+// Begin() and with each Next() call alternates between earlier and later
+// prefetches until it hits the min and max ratios.
class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
public:
CostAnalysisPrefetchIntervalPicker(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
float min_async_copy_to_overlap_ratio,
- float max_async_copy_to_overlap_ratio);
+ float max_async_copy_to_overlap_ratio,
+ float preferred_async_copy_to_overlap_ratio);
bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
int64 end_time) const override;
@@ -319,13 +327,17 @@
const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
float min_async_copy_to_overlap_ratio_;
float max_async_copy_to_overlap_ratio_;
+ float preferred_async_copy_to_overlap_ratio_;
float max_overlap_multiplier_ = 1.0;
float async_copy_elapsed_;
float inst_elapsed_reduction_;
int64 end_logical_time_;
- int64 earliest_start_logical_time_;
- int64 current_logical_prefetch_time_;
+ int64 earliest_prefetch_time_;
+ int64 latest_prefetch_time_;
+ bool using_increasing_prefetch_time_iterator_;
+ int64 increasing_prefetch_time_iterator_;
+ int64 decreasing_prefetch_time_iterator_;
};
// MemorySpaceAssignment assigns memory spaces (default or alternate) to each
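
For reference, the call pattern on the caller side, condensed from the new
unit tests in the next file (not standalone: cost_analysis and use are
built as in those tests):

CostAnalysisPrefetchIntervalPicker interval_picker(
    *cost_analysis,
    /*min_async_copy_to_overlap_ratio=*/1.0,
    /*max_async_copy_to_overlap_ratio=*/4.0,
    /*preferred_async_copy_to_overlap_ratio=*/2.0);  // new parameter

interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);
while (!interval_picker.Done()) {
  // Candidate start times arrive preferred-first, then alternately later
  // and earlier, until both iterators pass the min/max ratio bounds.
  int64 prefetch_time = interval_picker.Next();
}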
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 10e11e5..a92b73c 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -60,7 +60,8 @@
CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
CostAnalysisPrefetchIntervalPicker(
*cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
- /*max_async_copy_to_overlap_ratio=*/10.0));
+ /*max_async_copy_to_overlap_ratio=*/10.0,
+ /*preferred_async_copy_to_overlap_ratio=*/1.5));
return AssignMemorySpace(
module, /*max_outstanding_async_copies=*/-1,
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
@@ -4045,5 +4046,218 @@
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
+// For testing purposes, we define a cost analysis where we can control the
+// elapsed times of each HLO and asynchronous copy.
+class FakeMemorySpaceAssignmentCostAnalysis
+ : public MemorySpaceAssignmentCostAnalysis {
+ public:
+ static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
+ Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
+ TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
+ TF_ASSIGN_OR_RETURN(auto hlo_live_range,
+ HloLiveRange::Run(module.schedule(), *alias_analysis,
+ module.entry_computation()));
+ auto call_graph = CallGraph::Build(&module);
+ return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
+ cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
+ /*alternate_mem_bandwidth_bytes_per_second=*/1,
+ std::move(alias_analysis), std::move(hlo_live_range),
+ std::move(call_graph)));
+ }
+
+ float GetInstructionElapsed(
+ const HloInstruction& instruction) const override {
+ return 1.0;
+ }
+
+ float GetInstructionElapsedInAlternateMemory(
+ const HloInstruction& instruction,
+ absl::optional<int64> operand_in_alternate_mem,
+ bool output_in_alternate_mem) const override {
+ if (operand_in_alternate_mem) {
+ return 0.5;
+ } else {
+ return 1.0;
+ }
+ }
+
+ float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }
+
+ protected:
+ FakeMemorySpaceAssignmentCostAnalysis(
+ const HloCostAnalysis& cost_analysis,
+ float async_copy_bandwidth_bytes_per_second,
+ float alternate_mem_bandwidth_bytes_per_second,
+ std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ std::unique_ptr<HloLiveRange> hlo_live_range,
+ std::unique_ptr<CallGraph> call_graph)
+ : MemorySpaceAssignmentCostAnalysis(
+ cost_analysis, async_copy_bandwidth_bytes_per_second,
+ alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
+ std::move(hlo_live_range), std::move(call_graph)) {}
+};
+
+using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
+
+TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
+ absl::string_view hlo_string = R"(
+ HloModule bug, is_scheduled=true
+
+ ENTRY Entry {
+ param0 = f32[2,4] parameter(0)
+ a = f32[2,4] negate(param0)
+ b = f32[2,4] negate(a)
+ c = f32[2,4] negate(b)
+ d = f32[2,4] negate(c)
+ e = f32[2,4] negate(d)
+ f = f32[2,4] negate(e)
+ g = f32[2,4] negate(f)
+ h = f32[2,4] negate(g)
+ i = f32[2,4] negate(h)
+ j = f32[2,4] negate(i)
+ k = f32[2,4] negate(j)
+ l = f32[2,4] negate(k)
+ m = f32[2,4] negate(l)
+ n = f32[2,4] negate(m)
+ o = f32[2,4] negate(n)
+ p = f32[2,4] negate(o)
+ q = f32[2,4] negate(p)
+ r = f32[2,4] negate(q)
+ s = f32[2,4] negate(r)
+ t = f32[2,4] negate(s)
+ u = f32[2,4] negate(t)
+ ROOT v = f32[2,4] add(u, param0)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ HloCostAnalysis hlo_cost_analysis(ShapeSize);
+ TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
+ FakeMemorySpaceAssignmentCostAnalysis::Create(
+ hlo_cost_analysis, *module));
+ CostAnalysisPrefetchIntervalPicker interval_picker(
+ *cost_analysis,
+ /*min_async_copy_to_overlap_ratio=*/1.0,
+ /*max_async_copy_to_overlap_ratio=*/4.0,
+ /*preferred_async_copy_to_overlap_ratio=*/2.0);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
+ interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);
+
+  // Expect that the first interval is (15, 22), which has an elapsed time of
+  // 6.0, twice the async copy elapsed (3.0). Then we expect the intervals to
+  // be visited in alternating increasing and decreasing order until hitting
+  // the min and max async copy overlap ratios, which correspond to the
+  // intervals (18, 22) and (9, 22), respectively.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 15);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 16);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 14);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 17);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 13);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 12);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 11);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 10);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_TRUE(interval_picker.Done());
+
+ // Expect that if the time between start_time and end_time is too short, there
+ // won't be any available intervals.
+ interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_TRUE(interval_picker.Done());
+}
+
+TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
+ absl::string_view hlo_string = R"(
+ HloModule bug, is_scheduled=true
+
+ while_condition {
+ param1 = (f32[2,4]) parameter(0) // 19
+ ROOT cond = pred[] constant(true) // 20
+ }
+
+ while_body {
+ param2 = (f32[2,4]) parameter(0) // 21
+ gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22
+ add = f32[2,4] add(gte2, gte2) // 23
+ ROOT tuple2 = (f32[2,4]) tuple(add) // 24
+ }
+
+ ENTRY Entry {
+ param0 = f32[2,4] parameter(0) // 0
+ a = f32[2,4] negate(param0) // 1
+ b = f32[2,4] negate(a) // 2
+ c = f32[2,4] negate(b) // 3
+ d = f32[2,4] negate(c) // 4
+ e = f32[2,4] negate(d) // 5
+ f = f32[2,4] negate(e) // 6
+ g = f32[2,4] negate(f) // 7
+ h = f32[2,4] negate(g) // 8
+ i = f32[2,4] negate(h) // 9
+ j = f32[2,4] negate(i) // 10
+ k = f32[2,4] negate(j) // 11
+ l = f32[2,4] negate(k) // 12
+ m = f32[2,4] negate(l) // 13
+ n = f32[2,4] negate(m) // 14
+ o = f32[2,4] negate(n) // 15
+ p = f32[2,4] negate(o) // 16
+ q = f32[2,4] negate(p) // 17
+ tuple = (f32[2,4]) tuple(q) // 18
+ while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25
+ gte1 = f32[2,4] get-tuple-element(while), index=0 // 26
+ r = f32[2,4] negate(gte1) // 27
+ s = f32[2,4] negate(r) // 28
+ t = f32[2,4] negate(s) // 29
+ u = f32[2,4] negate(t) // 30
+ ROOT v = f32[2,4] add(u, param0) // 31
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+
+ HloCostAnalysis hlo_cost_analysis(ShapeSize);
+ TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
+ FakeMemorySpaceAssignmentCostAnalysis::Create(
+ hlo_cost_analysis, *module));
+ CostAnalysisPrefetchIntervalPicker interval_picker(
+ *cost_analysis,
+ /*min_async_copy_to_overlap_ratio=*/1.0,
+ /*max_async_copy_to_overlap_ratio=*/12.0,
+ /*preferred_async_copy_to_overlap_ratio=*/2.0);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
+ interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31);
+
+  // Because while-loop computations occupy logical times [19, 24], we expect
+  // the interval picker to avoid starting prefetches in that range.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 25);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 26);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 18);
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached.
+ LOG(INFO) << interval_picker.ToDebugString();
+ EXPECT_TRUE(interval_picker.Done());
+}
+
} // namespace
} // namespace xla
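
A standalone sketch of the while-nest-level skipping that the
PrefetchIntervalOrderWhile test above exercises: logical times 19..24 sit
inside the while computation, so Next() steps its iterators past them. The
bounds (17 and 27) and preferred time (25) are taken from the test's
expectations; the real picker derives them from the cost model:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Times 19..24 are inside the while computation (nest level 1), as in
  // the PrefetchIntervalOrderWhile test; everything else is level 0.
  std::vector<int> nest_level(32, 0);
  for (int64_t t = 19; t <= 24; ++t) nest_level[t] = 1;
  const int end_level = nest_level[31];

  const int64_t earliest = 17, latest = 27;
  int64_t up = 26, down = 25;  // both iterators start at preferred = 25
  bool use_up = false;         // so the first pick is 25 itself
  auto skip_up = [&] {
    while (up <= latest && nest_level[up] != end_level) ++up;
  };
  auto skip_down = [&] {
    while (down >= earliest && nest_level[down] != end_level) --down;
  };
  skip_up();
  while (up <= latest || down >= earliest) {
    int64_t t;
    if ((use_up && up <= latest) || down < earliest) {
      t = up++;
      skip_up();    // advance past while-body times
    } else {
      t = down--;
      skip_down();  // advance past while-body times
    }
    use_up = !use_up;
    std::cout << t << " ";  // prints: 25 26 18 27 17, as the test expects
  }
  std::cout << "\n";
  return 0;
}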