[XLA] Don't issue redundant evictions when there is an earlier one before a while op
Consider the scenario:
a = foo() // a in alternate mem
a' = evict(a) // a' in default mem
a'' = while(a) // a'' in alternate mem
// No need to evict a'' to default mem when the while leaves the buffer unchanged, because a' is already in default mem.
PiperOrigin-RevId: 408965154
Change-Id: Ic837a918e83d7417aead3a246fefe6c8d826de73
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index c5136c6..fee95a6 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -1697,6 +1697,22 @@
<< body_allocation_value_it->allocation_sequence()
->back()
->ToString();
+
+ auto after_while_allocation_value_it = absl::c_find_if(
+ allocation_values, [&](const AllocationValue& value) {
+ return value.defining_instruction() == hlo_use.instruction;
+ });
+ CHECK_NE(after_while_allocation_value_it, allocation_values.end());
+ VLOG(3) << "After while allocation value: "
+ << after_while_allocation_value_it->ToShortString();
+ int64_t while_time = instruction_schedule.at(hlo_use.instruction);
+ after_while_allocation_value_it->allocation_sequence()->push_back(
+ absl::make_unique<MemorySpaceAssignment::MirroredAllocation>(
+ **prev_allocation_in_default_mem_it, while_time));
+ VLOG(3) << "Created: "
+ << after_while_allocation_value_it->allocation_sequence()
+ ->back()
+ ->ToString();
}
}
// Special case for while loops since the root offset must agree with
@@ -2414,12 +2430,11 @@
auto prev_allocation_it = allocation_sequence->rbegin();
// Find a previous allocation that is in the default memory space (not
// necessarily the very last allocation).
- auto prev_allocation_in_default_mem_it = std::find_if(
- allocation_sequence->rbegin(), allocation_sequence->rend(),
- [&](const auto& allocation) {
- return allocation->memory_space() == MemorySpace::kDefault &&
- allocation->defining_position() == defining_position;
- });
+ auto prev_allocation_in_default_mem_it =
+ std::find_if(allocation_sequence->rbegin(), allocation_sequence->rend(),
+ [&](const auto& allocation) {
+ return allocation->memory_space() == MemorySpace::kDefault;
+ });
if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
prev_allocation_it != allocation_sequence->rend() &&
@@ -3358,6 +3373,11 @@
prev_allocation_.ToString());
}
+std::string MemorySpaceAssignment::MirroredAllocation::ToString() const {
+ // Describe this allocation by prefixing the description of the allocation it
+ // mirrors.
+ const std::string mirrored_description = original_allocation_.ToString();
+ return absl::StrCat("Mirrored Allocation for ", mirrored_description);
+}
+
std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
return absl::StrCat("Parent Allocation mirrored at ",
defining_position_.ToString(), ", originally ",
@@ -3408,6 +3428,11 @@
return Status::OK();
}
+// Copies the defining position of the mirrored allocation into this allocation
+// and then defers to the base Allocation::Process().
+Status MemorySpaceAssignment::MirroredAllocation::Process() {
+ // Re-read the defining position at processing time instead of relying on the
+ // value captured by the constructor. NOTE(review): presumably the original
+ // allocation's defining position can change between construction and
+ // processing -- confirm.
+ defining_position_ = original_allocation_.defining_position();
+ return Allocation::Process();
+}
+
Status MemorySpaceAssignment::ParentAllocation::Process() {
// Add an additional parameter to the while HLO with a reference to the buffer
// in the default memory space.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 168cd64..adb0a3d 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -656,6 +656,27 @@
HloInstruction* copy_done_;
};
+ // An allocation in the default memory space that mirrors another Allocation
+ // object. This is useful to model an eviction that happens before a while op
+ // so that we don't need to redundantly evict the buffer after the while op as
+ // well.
+ //
+ // The mirrored allocation is held by reference, not copied, so it must outlive
+ // this object.
+ class MirroredAllocation : public Allocation {
+ public:
+ // Builds a default-memory-space allocation that takes its defining position
+ // and chunk from `original_allocation`. `time` is used as both the start and
+ // the end time, and the allocation is marked as not scoped.
+ MirroredAllocation(const Allocation& original_allocation, int64_t time)
+ : Allocation(original_allocation.defining_position(),
+ MemorySpace::kDefault, original_allocation.chunk(),
+ /*start_time=*/time,
+ /*end_time=*/time, /*is_scoped_allocation=*/false),
+ original_allocation_(original_allocation) {}
+
+ // Refreshes the defining position from the mirrored allocation and then runs
+ // Allocation::Process().
+ Status Process() override;
+
+ // Returns "Mirrored Allocation for " followed by the mirrored allocation's
+ // own ToString().
+ std::string ToString() const override;
+
+ private:
+ // The allocation being mirrored; not owned.
+ const Allocation& original_allocation_;
+ };
+
// An allocation in default memory space that is defined in the parent
// computation. If a value has a copy in the default memory space in the
// parent computation, we don't need to evict this buffer in a while loop.
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index ec1152d..df80988 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -4687,6 +4687,243 @@
}
}
+TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile) {
+ // Element 0 of the while tuple (`copy`) is passed through the loop body
+ // unchanged (the body's root tuple forwards gte0) and is only used again by
+ // the entry root, after a long chain of negates that follows the while. A
+ // similar chain of negates separates `copy` from the while itself.
+ absl::string_view hlo_string = R"(
+ HloModule module, is_scheduled=true
+
+ while_cond {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ ROOT gte = pred[] get-tuple-element(p0), index=2
+ }
+
+ while_body {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ gte0 = f32[3]{0} get-tuple-element(p0), index=0
+ gte1 = f32[3]{0} get-tuple-element(p0), index=1
+ gte2 = pred[] get-tuple-element(p0), index=2
+ add = f32[3]{0} add(gte0, gte1)
+ ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+ }
+
+ ENTRY entry {
+ p0 = f32[3]{0} parameter(0)
+ p1 = pred[] parameter(1)
+ copy = f32[3]{0} copy(p0)
+ negate0 = f32[3]{0} negate(p0)
+ negate1 = f32[3]{0} negate(negate0)
+ negate2 = f32[3]{0} negate(negate1)
+ negate3 = f32[3]{0} negate(negate2)
+ negate4 = f32[3]{0} negate(negate3)
+ negate5 = f32[3]{0} negate(negate4)
+ negate6 = f32[3]{0} negate(negate5)
+ negate7 = f32[3]{0} negate(negate6)
+ negate8 = f32[3]{0} negate(negate7)
+ negate9 = f32[3]{0} negate(negate8)
+ negate10 = f32[3]{0} negate(negate9)
+ negate11 = f32[3]{0} negate(negate10)
+ negate12 = f32[3]{0} negate(negate11)
+ negate13 = f32[3]{0} negate(negate12)
+ negate14 = f32[3]{0} negate(negate13)
+ tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
+ while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+ gte0 = f32[3]{0} get-tuple-element(while), index=0
+ gte1 = f32[3]{0} get-tuple-element(while), index=1
+ negate20 = f32[3]{0} negate(gte1)
+ negate21 = f32[3]{0} negate(negate20)
+ negate22 = f32[3]{0} negate(negate21)
+ negate23 = f32[3]{0} negate(negate22)
+ negate24 = f32[3]{0} negate(negate23)
+ negate25 = f32[3]{0} negate(negate24)
+ negate26 = f32[3]{0} negate(negate25)
+ negate27 = f32[3]{0} negate(negate26)
+ negate28 = f32[3]{0} negate(negate27)
+ negate29 = f32[3]{0} negate(negate28)
+ negate30 = f32[3]{0} negate(negate29)
+ negate31 = f32[3]{0} negate(negate30)
+ negate32 = f32[3]{0} negate(negate31)
+ negate33 = f32[3]{0} negate(negate32)
+ negate34 = f32[3]{0} negate(negate33)
+ ROOT add = f32[3]{0} add(negate34, gte0)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ AssignMemorySpace(module.get());
+
+ // Only checked when the test parameter is true. The second operand of the
+ // entry root must be an async copy whose source is the `copy` instruction
+ // itself, not a get-tuple-element of the while, i.e. the value is fetched
+ // from the buffer that existed before the while.
+ // NOTE(review): the AsyncCopy matcher arguments look like (destination
+ // space, source space, operand) -- confirm against the matcher definition.
+ if (GetParam()) {
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction()->operand(1),
+ op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, op::Copy()));
+ }
+}
+
+TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile2) {
+ // Variant with two consecutive while loops. `copy` feeds element 0 of while1
+ // directly (no long gap before the first loop), both loop bodies forward
+ // element 0 unchanged, and element 0 of while1's result becomes element 0 of
+ // while2's operand. The value is finally used by the entry root, after a long
+ // chain of negates that follows while2.
+ absl::string_view hlo_string = R"(
+ HloModule module, is_scheduled=true
+
+ while_cond1 {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ ROOT gte = pred[] get-tuple-element(p0), index=2
+ }
+
+ while_body1 {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ gte0 = f32[3]{0} get-tuple-element(p0), index=0
+ gte1 = f32[3]{0} get-tuple-element(p0), index=1
+ gte2 = pred[] get-tuple-element(p0), index=2
+ add = f32[3]{0} add(gte0, gte1)
+ ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+ }
+
+ while_cond2 {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ ROOT gte = pred[] get-tuple-element(p0), index=2
+ }
+
+ while_body2 {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ gte0 = f32[3]{0} get-tuple-element(p0), index=0
+ gte1 = f32[3]{0} get-tuple-element(p0), index=1
+ gte2 = pred[] get-tuple-element(p0), index=2
+ add = f32[3]{0} add(gte0, gte1)
+ ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
+ }
+
+ ENTRY entry {
+ p0 = f32[3]{0} parameter(0)
+ p1 = pred[] parameter(1)
+ copy = f32[3]{0} copy(p0)
+ tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
+ while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond1, body=while_body1
+ gte0 = f32[3]{0} get-tuple-element(while1), index=0
+ gte1 = f32[3]{0} get-tuple-element(while1), index=1
+ negate0 = f32[3]{0} negate(gte1)
+ negate1 = f32[3]{0} negate(negate0)
+ negate2 = f32[3]{0} negate(negate1)
+ negate3 = f32[3]{0} negate(negate2)
+ negate4 = f32[3]{0} negate(negate3)
+ negate5 = f32[3]{0} negate(negate4)
+ negate6 = f32[3]{0} negate(negate5)
+ negate7 = f32[3]{0} negate(negate6)
+ negate8 = f32[3]{0} negate(negate7)
+ negate9 = f32[3]{0} negate(negate8)
+ negate10 = f32[3]{0} negate(negate9)
+ negate11 = f32[3]{0} negate(negate10)
+ negate12 = f32[3]{0} negate(negate11)
+ negate13 = f32[3]{0} negate(negate12)
+ negate14 = f32[3]{0} negate(negate13)
+ tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, negate14, p1)
+ while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond2, body=while_body2
+ gte2 = f32[3]{0} get-tuple-element(while2), index=0
+ gte3 = f32[3]{0} get-tuple-element(while2), index=1
+ negate20 = f32[3]{0} negate(gte3)
+ negate21 = f32[3]{0} negate(negate20)
+ negate22 = f32[3]{0} negate(negate21)
+ negate23 = f32[3]{0} negate(negate22)
+ negate24 = f32[3]{0} negate(negate23)
+ negate25 = f32[3]{0} negate(negate24)
+ negate26 = f32[3]{0} negate(negate25)
+ negate27 = f32[3]{0} negate(negate26)
+ negate28 = f32[3]{0} negate(negate27)
+ negate29 = f32[3]{0} negate(negate28)
+ negate30 = f32[3]{0} negate(negate29)
+ negate31 = f32[3]{0} negate(negate30)
+ negate32 = f32[3]{0} negate(negate31)
+ negate33 = f32[3]{0} negate(negate32)
+ negate34 = f32[3]{0} negate(negate33)
+ ROOT add = f32[3]{0} add(negate34, gte2)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ AssignMemorySpace(module.get());
+
+ // Only checked when the test parameter is true. The second operand of the
+ // entry root must be an async copy of another async copy (with the two
+ // memory-space arguments swapped) whose source is a get-tuple-element of a
+ // while instruction; the matcher does not pin down which while.
+ // NOTE(review): the AsyncCopy matcher arguments look like (destination
+ // space, source space, operand) -- confirm against the matcher definition.
+ if (GetParam()) {
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction()->operand(1),
+ op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
+ op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
+ op::GetTupleElement(op::While()))));
+ }
+}
+
+TEST_P(MemorySpaceAssignmentTest,
+ AfterWhileRedundantEarlierEvictionModifiedBuffer) {
+ // Same entry computation as AvoidRedundantEvictionAfterWhile, but the loop
+ // body replaces element 0 with negate(gte0) instead of forwarding it. The
+ // value that leaves the while therefore differs from `copy`, so the root's
+ // operand has to come from the while output rather than from `copy`.
+ absl::string_view hlo_string = R"(
+ HloModule module, is_scheduled=true
+
+ while_cond {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ ROOT gte = pred[] get-tuple-element(p0), index=2
+ }
+
+ while_body {
+ p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+ gte0 = f32[3]{0} get-tuple-element(p0), index=0
+ gte1 = f32[3]{0} get-tuple-element(p0), index=1
+ gte2 = pred[] get-tuple-element(p0), index=2
+ add = f32[3]{0} add(gte0, gte1)
+ negate = f32[3]{0} negate(gte0)
+ ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(negate, add, gte2)
+ }
+
+ ENTRY entry {
+ p0 = f32[3]{0} parameter(0)
+ p1 = pred[] parameter(1)
+ copy = f32[3]{0} copy(p0)
+ negate0 = f32[3]{0} negate(p0)
+ negate1 = f32[3]{0} negate(negate0)
+ negate2 = f32[3]{0} negate(negate1)
+ negate3 = f32[3]{0} negate(negate2)
+ negate4 = f32[3]{0} negate(negate3)
+ negate5 = f32[3]{0} negate(negate4)
+ negate6 = f32[3]{0} negate(negate5)
+ negate7 = f32[3]{0} negate(negate6)
+ negate8 = f32[3]{0} negate(negate7)
+ negate9 = f32[3]{0} negate(negate8)
+ negate10 = f32[3]{0} negate(negate9)
+ negate11 = f32[3]{0} negate(negate10)
+ negate12 = f32[3]{0} negate(negate11)
+ negate13 = f32[3]{0} negate(negate12)
+ negate14 = f32[3]{0} negate(negate13)
+ tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
+ while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+ gte0 = f32[3]{0} get-tuple-element(while), index=0
+ gte1 = f32[3]{0} get-tuple-element(while), index=1
+ negate20 = f32[3]{0} negate(gte1)
+ negate21 = f32[3]{0} negate(negate20)
+ negate22 = f32[3]{0} negate(negate21)
+ negate23 = f32[3]{0} negate(negate22)
+ negate24 = f32[3]{0} negate(negate23)
+ negate25 = f32[3]{0} negate(negate24)
+ negate26 = f32[3]{0} negate(negate25)
+ negate27 = f32[3]{0} negate(negate26)
+ negate28 = f32[3]{0} negate(negate27)
+ negate29 = f32[3]{0} negate(negate28)
+ negate30 = f32[3]{0} negate(negate29)
+ negate31 = f32[3]{0} negate(negate30)
+ negate32 = f32[3]{0} negate(negate31)
+ negate33 = f32[3]{0} negate(negate32)
+ negate34 = f32[3]{0} negate(negate33)
+ ROOT add = f32[3]{0} add(negate34, gte0)
+ }
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ AssignMemorySpace(module.get());
+
+ // Only checked when the test parameter is true. The second operand of the
+ // entry root must be an async copy of another async copy (with the two
+ // memory-space arguments swapped) whose source is a get-tuple-element of the
+ // while, i.e. the copy made before the while is not reused here.
+ // NOTE(review): the AsyncCopy matcher arguments look like (destination
+ // space, source space, operand) -- confirm against the matcher definition.
+ if (GetParam()) {
+ EXPECT_THAT(
+ module->entry_computation()->root_instruction()->operand(1),
+ op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
+ op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
+ op::GetTupleElement(op::While()))));
+ }
+}
+
TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
// Tests against a bug where the root of entry computation is a bitcast
// instruction and it ends up getting an allocation in the alternate memory.