| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/memory_space_assignment.h" |
| |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_matchers.h" |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h" |
| |
| namespace xla { |
| namespace { |
| |
| namespace op = xla::testing::opcode_matchers; |
| |
| constexpr int64 kPointerSize = 8; |
| constexpr float kAsyncCopyBandwidth = 100; |
| constexpr float kAlternateMemBandwidth = 1000; |
| constexpr float kBytesPerSecond = 100; |
| constexpr float kFlopsPerSecond = 1000; |
| constexpr float kTranscendentalsPerSecond = 10; |
| |
| int64 ShapeSize(const Shape& shape) { |
| return ShapeUtil::ByteSizeOf(shape, kPointerSize); |
| } |
| |
| class MemorySpaceAssignmentTest : public HloTestBase, |
| public ::testing::WithParamInterface<bool> { |
| protected: |
| // We use the following two memory space values to describe the default (slow |
| // and large) and alternate (fast and small) memory spaces. |
| const int64 kDefaultMemorySpace = 0; |
| const int64 kAlternateMemorySpace = 1; |
| |
| std::unique_ptr<PresetAssignments> AssignMemorySpaceUsingCostAnalysis( |
| HloModule* module) { |
| HloCostAnalysis hlo_cost_analysis(ShapeSize); |
| hlo_cost_analysis.set_flops_per_second(kFlopsPerSecond); |
| hlo_cost_analysis.set_bytes_per_second(kBytesPerSecond); |
| hlo_cost_analysis.set_transcendentals_per_second(kTranscendentalsPerSecond); |
| for (HloComputation* computation : module->MakeNonfusionComputations()) { |
| TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); |
| } |
| auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); |
| auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create( |
| hlo_cost_analysis, kAsyncCopyBandwidth, |
| kAlternateMemBandwidth, *module) |
| .ValueOrDie(); |
| CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( |
| CostAnalysisPrefetchIntervalPicker( |
| *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, |
| /*max_async_copy_to_overlap_ratio=*/10.0, |
| /*preferred_async_copy_to_overlap_ratio=*/1.5)); |
| return AssignMemorySpace( |
| module, /*max_outstanding_async_copies=*/-1, |
| MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( |
| *cost_analysis, &cache_), |
| &prefetch_interval_picker); |
| } |
| |
| std::unique_ptr<PresetAssignments> AssignMemorySpace( |
| HloModule* module, int64 max_outstanding_async_copies = -1, |
| int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2) { |
| InstructionCountPrefetchIntervalPicker prefetch_interval_picker( |
| min_prefetch_interval, max_prefetch_interval); |
| return AssignMemorySpace(module, max_outstanding_async_copies, |
| /*buffer_interval_compare=*/{}, |
| &prefetch_interval_picker); |
| } |
| |
| std::unique_ptr<PresetAssignments> AssignMemorySpace( |
| HloModule* module, int64 max_outstanding_async_copies, |
| absl::optional<MemorySpaceAssignment::BufferIntervalCompare> |
| buffer_interval_compare, |
| PrefetchIntervalPicker* prefetch_interval_picker) { |
| auto size_fn = [](const BufferValue& buffer) { |
| return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); |
| }; |
| |
| auto is_allowed_in_alternate_mem = [](const HloValue& value) { |
| // Check if the value belongs to the entry computation. |
| HloInstruction* instruction = value.instruction(); |
| HloComputation* computation = instruction->parent(); |
| bool in_entry_computation = |
| (computation == computation->parent()->entry_computation()); |
| if (in_entry_computation && |
| instruction->opcode() == HloOpcode::kParameter) { |
| return false; |
| } |
| return true; |
| }; |
| |
| // Only check parameters in default memory if the original module didn't |
| // have the parameters in alternate memory. |
| bool check_parameters_in_default_memory = true; |
| for (const HloInstruction* parameter : |
| module->entry_computation()->parameter_instructions()) { |
| ShapeUtil::ForEachSubshape( |
| parameter->shape(), |
| [&](const Shape& subshape, const ShapeIndex& /*index*/) { |
| if (subshape.has_layout() && |
| subshape.layout().memory_space() == kAlternateMemorySpace) { |
| check_parameters_in_default_memory = false; |
| } |
| }); |
| } |
| |
| MemorySpaceAssignment::Options options; |
| options.alternate_memory_space = kAlternateMemorySpace; |
| options.max_size_in_bytes = 128; |
| options.alignment_in_bytes = 8; |
| options.buffer_interval_compare = buffer_interval_compare; |
| options.prefetch_interval_picker = prefetch_interval_picker; |
| options.size_fn = size_fn; |
| options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; |
| options.max_outstanding_prefetches = max_outstanding_async_copies; |
| options.max_outstanding_evictions = max_outstanding_async_copies; |
| options.allocate_across_sequential_calls = GetParam(); |
| options.verify = true; |
| |
| auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); |
| std::unique_ptr<HloLiveRange> hlo_live_range = |
| HloLiveRange::Run(module->schedule(), *alias_analysis, |
| module->entry_computation()) |
| .ValueOrDie(); |
| |
| std::unique_ptr<PresetAssignments> preset_assignments = |
| MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis, |
| options) |
| .ValueOrDie(); |
| if (check_parameters_in_default_memory) { |
| CheckParametersInDefaultMemory(module); |
| } |
| CheckPresetAssignments(preset_assignments.get()); |
| return preset_assignments; |
| } |
| |
| void CheckPresetAssignments(const PresetAssignments* preset_assignments) { |
| // Ensure that the exported preset assignments point to layouts in the |
| // alternate memory. Also ensure that the positions are unique. Note that |
| // we're using a std::set instead of absl::flat_hash_set because we can make |
| // use of HloPosition's comparator logic instead of providing a hasher. |
| std::set<HloPosition> positions_in_preset_assignments; |
| for (auto& position_and_chunk : preset_assignments->chunks()) { |
| HloPosition position = position_and_chunk.first; |
| EXPECT_EQ(positions_in_preset_assignments.find(position), |
| positions_in_preset_assignments.end()); |
| positions_in_preset_assignments.insert(position); |
| const Shape& subshape = |
| ShapeUtil::GetSubshape(position.instruction->shape(), position.index); |
| EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace) |
| << "Exported position is not in alternate mem: " |
| << position.ToString(); |
| } |
| } |
| |
| void CheckParametersInDefaultMemory(const HloModule* module) { |
| // Check that all the entry parameter subshapes are placed in default |
| // memory. |
| const HloComputation* entry_computation = module->entry_computation(); |
| for (const HloInstruction* parameter : |
| entry_computation->parameter_instructions()) { |
| ShapeUtil::ForEachSubshape( |
| parameter->shape(), |
| [&](const Shape& subshape, const ShapeIndex& /*index*/) { |
| if (subshape.has_layout()) { |
| EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace) |
| << "Parameter not in default memory: " |
| << parameter->ToString(); |
| } |
| }); |
| } |
| } |
| |
| struct OutstandingAsyncCopies { |
| int64 max_copies; |
| int64 max_prefetches; |
| int64 max_evictions; |
| }; |
| |
| /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies( |
| const HloModule& module) { |
| OutstandingAsyncCopies copies{0, 0, 0}; |
| int64 current_copies = 0; |
| int64 current_prefetches = 0; |
| int64 current_evictions = 0; |
| for (HloInstruction* instruction : module.schedule() |
| .sequence(module.entry_computation()) |
| .instructions()) { |
| if (instruction->opcode() == HloOpcode::kCopyStart) { |
| current_copies++; |
| if (ShapeUtil::GetSubshape(instruction->shape(), {0}) |
| .layout() |
| .memory_space() == kAlternateMemorySpace) { |
| current_prefetches++; |
| } else { |
| current_evictions++; |
| } |
| } else if (instruction->opcode() == HloOpcode::kCopyDone) { |
| current_copies--; |
| if (instruction->shape().layout().memory_space() == |
| kAlternateMemorySpace) { |
| current_prefetches--; |
| } else { |
| current_evictions--; |
| } |
| } |
| copies.max_copies = std::max(copies.max_copies, current_copies); |
| copies.max_prefetches = |
| std::max(copies.max_prefetches, current_prefetches); |
| copies.max_prefetches = std::max(copies.max_evictions, current_evictions); |
| } |
| return copies; |
| } |
| |
| std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = |
| builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); |
| HloInstruction* tanh = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| // tanh should be placed in the alternate memory since there isn't much |
| // contention in the beginning. However, tanh has another consumer at the |
| // end. So it should be kicked out to default memory and prefetched back in. |
| // The graph below is meant to increase the contention to force |
| // eviction/prefetch behavior. |
| HloInstruction* a = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh)); |
| HloInstruction* b = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); |
| HloInstruction* c = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); |
| HloInstruction* d = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); |
| HloInstruction* e = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b)); |
| HloInstruction* f = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c)); |
| HloInstruction* g = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d)); |
| HloInstruction* h = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c)); |
| HloInstruction* i = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d)); |
| HloInstruction* j = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d)); |
| HloInstruction* k = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f)); |
| HloInstruction* l = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h)); |
| HloInstruction* m = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j)); |
| HloInstruction* n = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l)); |
| HloInstruction* o = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m)); |
| // tanh is being used at the root instruction, and this should be |
| // prefetched. |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i, |
| j, k, l, m, n, o, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| return module; |
| } |
| |
| MemorySpaceAssignmentCostAnalysis::Cache cache_; |
| }; |
| |
| TEST_P(MemorySpaceAssignmentTest, ParameterOnly) { |
| // A module consisting of a single parameter. Inputs/outputs are currently |
| // excluded from memory space assignment. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| EXPECT_THAT(p0, op::ShapeWithLayout(shape)); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, Simple) {
  // A simple module with a few simple instructions. Expect this to be
  // transformed with CopyStart and CopyDone instructions inserted after inputs
  // and before outputs.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  // mul is the root; its operands add and sub are intermediate values.
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, add, sub, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Inputs and outputs are currently placed in the default memory. Everything
  // else should be in the alternate memory.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  EXPECT_THAT(mul, op::ShapeWithLayout(shape));
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));

  // Make sure the preset assignments is sane.
  // NOTE(review): three chunks are expected here — presumably add, sub, and
  // one more alternate-memory position; verify against the pass output.
  EXPECT_EQ(preset_assignments->chunks().size(), 3);
  EXPECT_EQ(preset_assignments->assignment_informations().size(), 1);
  // Ensure the offset assigned to add and sub are different.
  EXPECT_NE(preset_assignments->chunks()[0].second.offset,
            preset_assignments->chunks()[1].second.offset);
}
| |
TEST_P(MemorySpaceAssignmentTest, NegateChain) {
  // The negate chain is long enough for asynchronous copy to be inserted
  // between p1 and add.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Seven chained negates of p0; p1 is only consumed at the very end by add,
  // giving the pass a long window to prefetch it.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, negate5, negate6, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // add's second operand should be p1 prefetched (default -> alternate copy).
  EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::Parameter(1))));
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
  // Ensure the CopyStart/CopyDone schedules.
  // CopyStart should be issued right after the parameters (index 2) and
  // CopyDone right before its use (index 10), overlapping the negate chain.
  const HloInstructionSequence& sequence =
      module->schedule().sequence(computation);
  EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
  EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
  EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
  EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}
| |
| TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) { |
| std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule(); |
| |
| AssignMemorySpace(module.get()); |
| |
| EXPECT_THAT( |
| module->entry_computation()->root_instruction(), |
| op::Add(op::Add(), |
| op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, |
| op::AsyncCopy(kDefaultMemorySpace, |
| kAlternateMemorySpace, op::Tanh())))); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) { |
| std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule(); |
| |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0); |
| |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 0); |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 0); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { |
| std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule(); |
| |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1); |
| |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 1); |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 1); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { |
| std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule(); |
| |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2); |
| |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 2); |
| EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 2); |
| } |
| |
// TODO(berkin): This test is broken with some prefetch timing improvements.
TEST_P(MemorySpaceAssignmentTest,
       DISABLED_DontEvictWhenThereIsDefaultMemAllocation) {
  // This test is the same as EvictAndPrefetchLimitAsyncCopies1, except we check
  // that there is no eviction if not necessary (due to an existing allocation
  // in default memory).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // tanh should be placed in the alternate memory since there isn't much
  // contention in the beginning. However, tanh has another consumer at the end.
  // So it should be kicked out to default memory and prefetched back in. The
  // graph below is meant to increase the contention to force eviction/prefetch
  // behavior.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  // e..j consume pairwise products of a..d, keeping many values live at once.
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // tanh is being used at the root instruction, and this should be
  // prefetched.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
                                      j, k, l, m, n, o, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // We expect the second argument to multiply is prefetched c.
  EXPECT_THAT(f, op::Multiply(op::Add(), op::CopyDone()));
  // We make sure that the second argument to this multiply is not evicted
  // CopyDone but is the original c.
  EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
}
| |
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
  // Test for a memory corruption bug involving evict/prefetch/prefetch pattern,
  // where the last prefetch copied from the original buffer in alternate buffer
  // instead of evicted buffer.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // a..o build up contention (same structure as CreateEvictAndPrefetchModule)
  // so tanh is evicted from the alternate memory.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // First use of tanh after the eviction: requires the first prefetch.
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
  // The negate chain creates schedule distance before the second use of tanh
  // (add1), forcing a second, independent prefetch.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
  HloInstruction* negate9 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, p1, tanh, a, b, c, d, e,
       f, g, h, i, j, k, l, m,
       n, o, add0, negate0, negate1, negate2, negate3, negate4,
       negate5, negate6, negate7, negate8, negate9, add1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Check that both prefetches (add0 and add1) prefetch from the eviction
  // instead of tanh, which will be placed in the alternate memory directly.
  EXPECT_THAT(
      add0,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
  EXPECT_THAT(
      add1,
      op::Add(op::Negate(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
| |
| TEST_P(MemorySpaceAssignmentTest, While) { |
| auto module = CreateNewVerifiedModule(); |
| Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); |
| Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape}); |
| |
| auto cond_builder = HloComputation::Builder("WhileCond"); |
| // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) |
| HloInstruction* cond_param = cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); |
| HloInstruction* cond_iter = cond_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); |
| HloInstruction* cond_limit = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f))); |
| // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) |
| HloInstruction* cond_lt = cond_builder.AddInstruction( |
| HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, |
| cond_limit, ComparisonDirection::kLt)); |
| HloComputation* cond_computation = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto body_builder = HloComputation::Builder("WhileBody"); |
| // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) |
| HloInstruction* body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "body_param")); |
| HloInstruction* body_iter = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1)); |
| HloInstruction* body_data = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, body_param, 0)); |
| HloInstruction* body_iter_increment = body_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f))); |
| HloInstruction* body_iter_next = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment)); |
| HloInstruction* body_data_increment = |
| body_builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}}))); |
| HloInstruction* body_data_mul = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kMultiply, body_data, body_data)); |
| HloInstruction* body_data_add = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, body_data, body_data_increment)); |
| HloInstruction* body_data_next = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, body_data_add, body_data_mul)); |
| HloInstruction* body_out = body_builder.AddInstruction( |
| HloInstruction::CreateTuple({body_data_next, body_iter_next})); |
| HloComputation* body_computation = |
| module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* data = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "param_iter")); |
| HloInstruction* iter = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, scalar_shape, "param_data")); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({data, iter})); |
| HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( |
| tuple_shape, cond_computation, body_computation, tuple)); |
| HloComputation* entry_computation = |
| module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(cond_computation, |
| {cond_param, cond_iter, cond_limit, cond_lt}); |
| schedule.set_sequence(body_computation, |
| {body_param, body_iter, body_data, body_iter_increment, |
| body_iter_next, body_data_increment, body_data_mul, |
| body_data_add, body_data_next, body_out}); |
| schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| // Ensure the tuple value and buffers used in the while instruction are |
| // exempted from using the alternate memory when allocating across sequential |
| // calls is disabled. However, body_data_mul is independent and can be safely |
| // be placed in the alternate memory. |
| const bool allocate_across_sequential_calls = GetParam(); |
| if (!allocate_across_sequential_calls) { |
| EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape)); |
| EXPECT_THAT(data, op::ShapeWithLayout(shape)); |
| EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape)); |
| EXPECT_THAT(body_data, op::ShapeWithLayout(shape)); |
| EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape)); |
| EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape)); |
| } |
| Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( |
| F32, {2, 3}, |
| /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kAlternateMemorySpace); |
| EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem)); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, Tuple) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape}); |
| Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape}); |
| HloInstruction* p = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* p0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p, 0)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* negate5 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); |
| HloInstruction* negate6 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); |
| HloInstruction* p1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p, 1)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1)); |
| HloInstruction* p2 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2)); |
| HloInstruction* p2_0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p2, 0)); |
| HloInstruction* mul = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence( |
| computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5, |
| negate6, p1, add, p2, p2_0, mul}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| EXPECT_THAT( |
| mul, |
| op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, |
| kDefaultMemorySpace, |
| op::GetTupleElement())), |
| op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, |
| op::GetTupleElement(op::GetTupleElement())))); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, Bitcast) { |
| // Bitcasts can cause the position in the alternate memory to appear multiple |
| // times in the preset assignments. This test ensure the preset assignments |
| // refer to unique positions. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, param_shape, "p1")); |
| HloInstruction* negate = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* bitcast = builder.AddInstruction( |
| HloInstruction::CreateBitcast(param_shape, negate)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bitcast, p1)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, p1, negate, bitcast, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| bitcast = add->mutable_operand(0); |
| EXPECT_EQ(bitcast->opcode(), HloOpcode::kBitcast); |
| EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, Bitcast2) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, param_shape, "p1")); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* bitcast = |
| builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2, |
| negate3, negate4, bitcast, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| EXPECT_EQ(add->operand(0)->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
  // Exercises bitcast chains feeding both operands of a multiply, including a
  // bitcast-of-bitcast that the pass is expected to collapse into a single
  // bitcast, and checks that the rewritten operands live in alternate memory.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // The negate chain delays p1's first use in the schedule, giving the pass
  // an opportunity to prefetch p1 into the alternate memory.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast1 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4));
  // bitcast3(bitcast2(p1)) forms the bitcast-of-bitcast pattern on the LHS of
  // the multiply; bitcast4 reinterprets the add result for the RHS.
  HloInstruction* bitcast2 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1));
  HloInstruction* bitcast3 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2));
  HloInstruction* bitcast4 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      shape2, HloOpcode::kMultiply, bitcast3, bitcast4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast1, add, bitcast2, bitcast3, bitcast4, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is
  // converted to bitcast(foo).
  EXPECT_THAT(
      mul,
      op::Multiply(
          op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                    op::Parameter(1))),
          op::Bitcast(op::Add(
              op::Bitcast(op::AsyncCopy(kAlternateMemorySpace,
                                        kDefaultMemorySpace, op::Parameter(1))),
              op::Negate()))));
  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace);
  // bitcast2 will no longer have a consumer and should get DCE'd, so we don't
  // care about its memory space.
  EXPECT_EQ(mul->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(mul->operand(1)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
| |
| TEST_P(MemorySpaceAssignmentTest, BitcastTuple) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation::Builder fusion_builder("fusion"); |
| HloInstruction* fusion_param = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* fusion_element0 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 0)); |
| HloInstruction* fusion_element1 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 1)); |
| fusion_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, fusion_element0, fusion_element1)); |
| HloComputation* fusion_computation = |
| module->AddEmbeddedComputation(fusion_builder.Build()); |
| |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, param_shape, "p1")); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* bitcast = |
| builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1)); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({bitcast, p0})); |
| HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( |
| shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, |
| {p0, p1, negate0, negate1, negate2, negate3, negate4, |
| bitcast, tuple, fusion}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, BitcastGetTupleElementTuple) {
  // This test pattern was encountered in
  // //third_party/tensorflow/compiler/xla/tests:slice_test and was causing a
  // breakage when there is a GetTupleElement(Tuple(Bitcast())) pattern. Also
  // added a GetTupleElement(GetTupleElement(Tuple(Tuple(Bitcast())))) pattern.
  absl::string_view hlo_string = R"(
  HloModule DoIt_S64_10_0_5_1.3, is_scheduled=true

  ENTRY %DoIt_S64_10_0_5_1.3 (p0.1: (u32[10], u32[10])) -> (u32[5], u32[5]) {
    %p0.1 = (u32[10]{0:T(128)}, u32[10]{0:T(128)}) parameter(0)
    %get-tuple-element.1 = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=1
    %bitcast.1 = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element.1)
    %get-tuple-element = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=0
    %bitcast = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element)
    %tuple.1 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %bitcast, u32[5]{0:T(128)} %bitcast.1)
    %tuple.3 = ((u32[5]{0:T(128)}, u32[5]{0:T(128)}), (u32[5]{0:T(128)}, u32[5]{0:T(128)})) tuple(%tuple.1, %tuple.1)
    %get-tuple-element.4 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) get-tuple-element(%tuple.3), index=0
    %get-tuple-element.6 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %get-tuple-element.5), index=1
    %copy.2 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.4)
    %copy.3 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.6)
    ROOT %tuple.2 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %copy.2, u32[5]{0:T(128)} %copy.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Regression test: succeeds if assignment completes without crashing; no
  // further expectations on the resulting layout.
  AssignMemorySpace(module.get());
}
| |
TEST_P(MemorySpaceAssignmentTest, GetSimplifiedOperandBug) {
  // Test case for a bug finding Bitcasts in GTE(Tuple(...)) pattern.
  absl::string_view hlo_string = R"(
  HloModule sort.16, is_scheduled=true

  ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
    %param.3.4 = s32[1]{0:T(128)} parameter(3)
    %param.2.3 = u32[1]{0:T(128)} parameter(2)
    %param.1.2 = f32[1]{0:T(128)} parameter(1)
    %param.0.1 = s32[1]{0:T(128)} parameter(0)
    %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4)
    %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
    %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
    %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
    %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
    %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
    %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
    %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
    ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Regression test: succeeds if assignment completes without crashing.
  AssignMemorySpace(module.get());
}
| |
| TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) { |
| // When there is a pattern where a bitcast has multiple uses (negate0 and add) |
| // and one is in the default memory and the other is in alternate memory, they |
| // both need their own bitcast. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| HloInstruction* p0 = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, param_shape, "p1")); |
| HloInstruction* bitcast = |
| builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2, |
| negate3, negate4, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( |
| F32, {2, 3}, |
| /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kAlternateMemorySpace); |
| EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape)); |
| EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem)); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) { |
| // Same as BitcastMultUse but the second use is a tuple. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation::Builder fusion_builder("fusion"); |
| HloInstruction* fusion_param = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* fusion_element0 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 0)); |
| HloInstruction* fusion_element1 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 1)); |
| fusion_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, fusion_element0, fusion_element1)); |
| HloComputation* fusion_computation = |
| module->AddEmbeddedComputation(fusion_builder.Build()); |
| |
| HloInstruction* p0 = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, param_shape, "p1")); |
| HloInstruction* bitcast = |
| builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({bitcast, negate4})); |
| HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( |
| shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2, |
| negate3, negate4, tuple, fusion}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( |
| F32, {2, 3}, |
| /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kAlternateMemorySpace); |
| EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape)); |
| EXPECT_THAT(fusion->operand(0)->operand(0), |
| op::ShapeWithLayout(shape_in_alternate_mem)); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) { |
| // Bitcasts can force asynchronous copies to be scheduled too early, possibly |
| // leading to memory corruption. |
| // Bug: |
| // p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add |
| // / |
| // p1->cs->cd->bitcast-----------------------------------------+ |
| // |
| // Expected: |
| // p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add |
| // / |
| // p1--------------------->cs----------------->cd->bitcast-+ |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape param_shape = ShapeUtil::MakeShape(F32, {6}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, param_shape, "p1")); |
| HloInstruction* bitcast = |
| builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* negate5 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); |
| HloInstruction* negate6 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); |
| HloInstruction* negate7 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6)); |
| HloInstruction* negate8 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7)); |
| HloInstruction* negate9 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence( |
| computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3, |
| negate4, negate5, negate6, negate7, negate8, negate9, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, |
| /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4); |
| |
| EXPECT_EQ(add->operand(0)->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| const auto& instructions = |
| module->schedule().sequence(module->entry_computation()).instructions(); |
| for (int i = 0; i < instructions.size(); ++i) { |
| // Expect that there is a negate before and after the CopyStart and there is |
| // a negate before CopyDone. |
| if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) { |
| EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate); |
| EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate); |
| } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) { |
| EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate); |
| } |
| } |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, TupleSelect) {
  // Make sure tuple-select is not optimized away.
  absl::string_view hlo_string = R"(
  HloModule tuple, is_scheduled=true

  ENTRY %main (a: f32[2], b: f32[2], c: f32[2], d: f32[2], cond: pred[]) -> f32[2] {
    %cond = pred[]{:T(128)E(32)} parameter(4)
    %token0 = token[] after-all()
    %d = f32[2]{0:T(128)} parameter(3)
    %c = f32[2]{0:T(128)} parameter(2)
    %b = f32[2]{0:T(128)} parameter(1)
    %a = f32[2]{0:T(128)} parameter(0)
    %tup0 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %a, f32[2]{0:T(128)} %b)
    %tup1 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %c, f32[2]{0:T(128)} %d)
    %s = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple-select(pred[]{:T(128)E(32)} %cond, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup0, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup1)
    %gte = f32[2]{0:T(128)} get-tuple-element((f32[2]{0:T(128)}, f32[2]{0:T(128)}) %s), index=0
    ROOT %negate = f32[2]{0:T(128)} negate(f32[2]{0:T(128)} %gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The root must still read through the tuple-select; the pass must not have
  // replaced or removed it.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Negate(op::GetTupleElement(op::TupleSelect())));
}
| |
TEST_P(MemorySpaceAssignmentTest, AddDependency) {
  // Make sure add-dependency is not optimized away.
  absl::string_view hlo_string = R"(
  HloModule AddDependency, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p = f32[3]{0} parameter(0)
    %neg0 = f32[3]{0} negate(f32[3]{0} %p)
    %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
    %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
    %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
    %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
    %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
    %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
    %token0 = token[] after-all()
    %add_dep = f32[3]{0} add-dependency(f32[3]{0} %p, token[] %token0)
    ROOT %add = f32[3]{0} add(f32[3]{0} %add_dep, f32[3]{0} %neg6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The add-dependency must survive the pass and still feed the root add.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::AddDependency(), op::Negate()));
}
| |
TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
  // This test is carefully crafted to include two multiply ops sized [4,3] in a
  // while body. For testing purposes, we have provided a BufferIntervalCompare
  // such that first multiply, then tanh, then other HloValues will be
  // allocated. The memory is sized just enough to fit two [4,3] buffers.
  // Because the multiplies in the while body are going to be allocated in the
  // alternate memory first, the tanh that is fed inside the while loop should
  // not be placed in the alternate memory. Otherwise, we will corrupt memory.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[])) -> (f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=1
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=0
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.2)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %cond_param), index=1
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_iter: f32[4,3], param_data: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_data = f32[] parameter(1)
    %param_iter = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %tanh = f32[4,3]{1,0} tanh(f32[4,3]{1,0} %param_iter)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %tanh)
    %tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %tanh, f32[] %param_data)
    %while = (f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %while), index=0
    ROOT %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.3, f32[4,3]{1,0} %add.4)
  }
  )";

  // Test-only allocation order: multiplies first, then tanh, then everything
  // else; ties broken by HloValue id for determinism.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        bool a_is_mul =
            a.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
        bool b_is_mul =
            b.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
        if (a_is_mul && !b_is_mul) {
          return true;
        }
        if (!a_is_mul && b_is_mul) {
          return false;
        }
        bool a_is_tanh =
            a.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
        bool b_is_tanh =
            b.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
        if (a_is_tanh && !b_is_tanh) {
          return true;
        }
        if (!a_is_tanh && b_is_tanh) {
          return false;
        }
        return a.buffer->id() < b.buffer->id();
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);

  for (const HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      const Shape& while_subshape =
          ShapeUtil::GetSubshape(instruction->shape(), {0});
      // We expect shape {0} to either be in default memory for the entire while
      // loop or there has to be an eviction within the while loop.
      if (while_subshape.layout().memory_space() == kAlternateMemorySpace) {
        // Find the GTE of tuple index 0 inside the while body...
        const HloInstruction* body_param =
            instruction->while_body()->parameter_instruction(0);
        const HloInstruction* gte = nullptr;
        for (const HloInstruction* user : body_param->users()) {
          if (user->opcode() == HloOpcode::kGetTupleElement &&
              user->tuple_index() == 0) {
            gte = user;
            break;
          }
        }
        EXPECT_NE(gte, nullptr);
        // ...and require that it is evicted (CopyStart targeting a space other
        // than the alternate memory).
        const HloInstruction* copy_start = nullptr;
        for (const HloInstruction* user : gte->users()) {
          if (user->opcode() == HloOpcode::kCopyStart) {
            copy_start = user;
            break;
          }
        }
        EXPECT_NE(copy_start, nullptr);
        const Shape& copy_start_subshape =
            ShapeUtil::GetSubshape(copy_start->shape(), {0});

        EXPECT_NE(copy_start_subshape.layout().memory_space(),
                  kAlternateMemorySpace);
      }
    }
  }
}
| |
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
  // Two back-to-back while loops where the second loop (%while.1) consumes
  // values derived from the first (%while): %add.3 and %get-tuple-element.5
  // flow into %tuple.2. Only checks that memory space assignment completes
  // without errors on this chained-loop configuration; no layout
  // expectations are asserted.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
| |
TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
  // Tests against while live ranges being incorrect and the verifier
  // complaining about a conflict. The long negate chain (neg10..neg20) in the
  // while body stretches live ranges inside the loop, and tuple element 1 of
  // %while (%get-tuple-element.5) is used again after the loop in the entry
  // computation, so the loop buffers stay live past the while itself. Only
  // checks that assignment completes; no layout expectations are asserted.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
| |
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
  // Tests against a bug where, with consecutive while loops that share one
  // buffer whose value never changes inside the loops (tuple index 1 is
  // passed through both %WhileBody and %WhileBody2 unchanged), the parameter
  // could incorrectly be colored in the alternate memory space. Only checks
  // that assignment completes; no layout expectations are asserted.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
| |
| TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) { |
| // While loop is the root of the entry computation. We should ensure the |
| // output of the entry computation remains to be in default memory space. |
| // Test from //third_party/tensorflow/compiler/xla/tests:while_test |
| // WhileTest.WhileWithPrngScalarResult. |
| absl::string_view hlo_string = R"( |
| HloModule WhileWithPrngScalarResult.18, is_scheduled=true |
| |
| %fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] { |
| %param_1.3 = s32[1]{0:T(128)} parameter(1) |
| %constant.2 = s32[]{:T(128)} constant(-2147483648) |
| %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 |
| %param_2.3 = s32[5]{0:T(128)} parameter(2) |
| %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 |
| %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) |
| %param_0.1 = s32[6]{0:T(128)} parameter(0) |
| ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) |
| } |
| |
| %body.3 (prev.4: s32[6]) -> s32[6] { |
| %constant.7 = s32[]{:T(128)} constant(100) |
| %constant.6 = s32[]{:T(128)} constant(0) |
| %constant.5 = s32[1]{0:T(128)} constant({1}) |
| %prev.4 = s32[6]{0:T(128)} parameter(0) |
| %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform |
| %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5) |
| ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation |
| } |
| |
| %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] { |
| %constant.15 = s32[]{:T(128)} constant(1) |
| %prev.12 = s32[6]{0:T(128)} parameter(0) |
| %bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12) |
| %bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1) |
| ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT |
| } |
| |
| ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] { |
| %constant.1 = s32[]{:T(128)} constant(0) |
| %broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={} |
| ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3 |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| // Expect the output to have default memory space. |
| EXPECT_EQ(module->entry_computation() |
| ->root_instruction() |
| ->shape() |
| .layout() |
| .memory_space(), |
| kDefaultMemorySpace); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) { |
| // Ensure that a dynamic update slice within a while loop is able to get an |
| // alternate memory allocation. |
| absl::string_view hlo_string = R"( |
| HloModule Module, is_scheduled=true |
| |
| fused_computation { |
| param0 = f32[2,3] parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast = f32[2,1] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3) |
| } |
| |
| %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) { |
| %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0) |
| %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2 |
| %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0 |
| %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1 |
| %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation |
| %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion) |
| ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1) |
| } |
| |
| %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] { |
| %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0) |
| %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2 |
| %constant = f32[] constant(50) |
| ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT |
| } |
| |
| ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] { |
| %param_iter = f32[] parameter(1) |
| %param_data = f32[2,3]{1,0} parameter(0) |
| %p2 = f32[2,3]{1,0} parameter(2) |
| %copy1 = f32[2,3]{1,0} copy(param_data) |
| %copy2 = f32[2,3]{1,0} copy(p2) |
| %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter) |
| %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody |
| %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0 |
| ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| const HloInstruction* while_op = |
| module->entry_computation()->GetInstructionWithName("while"); |
| if (GetParam()) { |
| EXPECT_EQ( |
| ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
  // Having control_predecessors on an HLO was preventing us from DCEing an op
  // that doesn't have any users (tuple.1). The scheduler assumes the graph is
  // fully DCEed, which causes some instructions not to be scheduled.
  // Note the control-predecessors={%param.0.1} attribute on %tuple.1 below;
  // the test only checks that assignment completes without errors.
  absl::string_view hlo_string = R"(
  HloModule sort.16, is_scheduled=true

  ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
    %param.3.4 = s32[1]{0:T(128)} parameter(3)
    %param.2.3 = u32[1]{0:T(128)} parameter(2)
    %param.1.2 = f32[1]{0:T(128)} parameter(1)
    %param.0.1 = s32[1]{0:T(128)} parameter(0)
    %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4), control-predecessors={%param.0.1}
    %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
    %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
    %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
    %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
    %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
    %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
    %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
    ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
| |
| TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) { |
| // Checks if simple conditionals get alternate memory allocations. |
| absl::string_view hlo_string = R"( |
| HloModule CondAllocation, is_scheduled=true |
| |
| true_computation { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg1 = f32[3]{0} negate(gte) |
| } |
| |
| false_computation { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg2 = f32[3]{0} negate(gte) |
| } |
| |
| ENTRY entry { |
| p0 = f32[3]{0} parameter(0) |
| p1 = pred[] parameter(1) |
| copy = f32[3]{0} copy(p0) |
| tuple = (f32[3]{0}) tuple(copy) |
| ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| |
| if (GetParam()) { |
| // Check that copy and gtes got alternate memory allocations. |
| auto copy = |
| module->GetComputationWithName("entry")->GetInstructionWithName("copy"); |
| EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); |
| auto neg1 = module->GetComputationWithName("true_computation") |
| ->GetInstructionWithName("neg1"); |
| auto neg1_operand = neg1->operand(0); |
| EXPECT_EQ(neg1_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| auto neg2 = module->GetComputationWithName("false_computation") |
| ->GetInstructionWithName("neg2"); |
| auto neg2_operand = neg2->operand(0); |
| EXPECT_EQ(neg2_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) { |
| // Checks if we avoid unnecessary allocation in alternate memory if the input |
| // won't be used in the computation for a long time. |
| absl::string_view hlo_string = R"( |
| HloModule CondAllocation, is_scheduled=true |
| |
| true_computation { |
| p0 = (f32[3]{0}, f32[3]{0}) parameter(0) |
| gte0 = f32[3]{0} get-tuple-element(p0), index=0 |
| neg0 = f32[3]{0} negate(gte0) |
| neg1 = f32[3]{0} negate(neg0) |
| neg2 = f32[3]{0} negate(neg1) |
| neg3 = f32[3]{0} negate(neg2) |
| neg4 = f32[3]{0} negate(neg3) |
| neg5 = f32[3]{0} negate(neg4) |
| neg6 = f32[3]{0} negate(neg5) |
| neg7 = f32[3]{0} negate(neg6) |
| neg8 = f32[3]{0} negate(neg7) |
| neg9 = f32[3]{0} negate(neg8) |
| gte1 = f32[3]{0} get-tuple-element(p0), index=1 |
| ROOT add = f32[3]{0} add(neg9, gte1) |
| } |
| |
| false_computation { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg = f32[3]{0} negate(gte) |
| } |
| |
| ENTRY entry { |
| p0 = f32[3]{0} parameter(0) |
| p1 = pred[] parameter(1) |
| copy0 = f32[3]{0} copy(p0) |
| copy1 = f32[3]{0} copy(p0) |
| tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) |
| tuple1 = (f32[3]{0}) tuple(copy0) |
| ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| |
| if (GetParam()) { |
| // Check that copy1 doesn't get unnecessarily allocated in alternate mem |
| // (due to long negate chain in true_computation) but is prefetched before |
| // add. |
| auto copy0 = |
| module->GetComputationWithName("entry")->GetInstructionWithName( |
| "copy0"); |
| EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); |
| auto copy1 = |
| module->GetComputationWithName("entry")->GetInstructionWithName( |
| "copy1"); |
| EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace); |
| auto add = module->GetComputationWithName("true_computation") |
| ->GetInstructionWithName("add"); |
| auto add_operand = add->operand(1); |
| EXPECT_EQ(add_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) { |
| // Make sure there is an evict when there is a conditional use followed by |
| // another use. |
| absl::string_view hlo_string = R"( |
| HloModule CondAllocation, is_scheduled=true |
| |
| true_computation { |
| p0 = (f32[3]{0}, f32[3]{0}) parameter(0) |
| gte0 = f32[3]{0} get-tuple-element(p0), index=0 |
| gte1 = f32[3]{0} get-tuple-element(p0), index=1 |
| add0 = f32[3]{0} add(gte0, gte1) |
| neg0 = f32[3]{0} negate(add0) |
| neg1 = f32[3]{0} negate(neg0) |
| neg2 = f32[3]{0} negate(neg1) |
| neg3 = f32[3]{0} negate(neg2) |
| neg4 = f32[3]{0} negate(neg3) |
| neg5 = f32[3]{0} negate(neg4) |
| neg6 = f32[3]{0} negate(neg5) |
| neg7 = f32[3]{0} negate(neg6) |
| neg8 = f32[3]{0} negate(neg7) |
| ROOT neg9 = f32[3]{0} negate(neg8) |
| } |
| |
| false_computation { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg = f32[3]{0} negate(gte) |
| } |
| |
| ENTRY entry { |
| p0 = f32[3]{0} parameter(0) |
| p1 = pred[] parameter(1) |
| copy0 = f32[3]{0} copy(p0) |
| copy1 = f32[3]{0} copy(p0) |
| tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) |
| tuple1 = (f32[3]{0}) tuple(copy0) |
| conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation |
| ROOT add1 = f32[3]{0} add(copy1, conditional) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| |
| if (GetParam()) { |
| // Make sure the copy1->add edge is in alternate memory. Before conditional, |
| // this should be evicted to default memory and neg uses the input from |
| // default memory. |
| auto copy1 = |
| module->GetComputationWithName("entry")->GetInstructionWithName( |
| "copy1"); |
| EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace); |
| auto add0 = module->GetComputationWithName("true_computation") |
| ->GetInstructionWithName("add0"); |
| auto add0_operand = add0->operand(1); |
| EXPECT_EQ(add0_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| auto add1 = |
| module->GetComputationWithName("entry")->GetInstructionWithName("add1"); |
| auto add1_operand = add1->operand(0); |
| EXPECT_EQ(add1_operand->shape().layout().memory_space(), |
| kDefaultMemorySpace); |
| EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone); |
| } |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
  // Tests a buffer (copy0) that is used both as a conditional operand (via
  // cond_tuple) inside a while body and as the while-loop carry (tuple
  // element 0): its alternate-memory allocation must be evicted and
  // re-prefetched for the while body root.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    cond_tuple = (f32[3]{0}) tuple(gte0)
    conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
    add = f32[3]{0} add(conditional, gte1)
    neg0 = f32[3]{0} negate(add)
    neg1 = f32[3]{0} negate(neg0)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // Make sure copy0/while{0}/cond_tuple{0} gets alternate memory allocation
    // (the instruction checked below is copy0, the buffer threaded through
    // tuple element 0 — the previous comment's "copy1" was a typo). This will
    // force an eviction and a prefetch for while body root.
    auto copy0 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy0");
    EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
    auto conditional = module->GetComputationWithName("while_body")
                           ->GetInstructionWithName("conditional");
    auto conditional_operand = conditional->operand(1);
    // The tuple element {0} fed to the conditional should also be in the
    // alternate memory space.
    EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0})
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    auto while_root =
        module->GetComputationWithName("while_body")->root_instruction();
    auto while_root_operand = while_root->operand(0);
    // The while body root must see an evict (alternate->default) followed by
    // a prefetch (default->alternate) of the loop parameter's element 0.
    EXPECT_THAT(
        while_root_operand,
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                    op::GetTupleElement(op::Parameter(0)))));
  }
}
| |
| TEST_P(MemorySpaceAssignmentTest, NestedConditional) { |
| absl::string_view hlo_string = R"( |
| HloModule CondAllocation, is_scheduled=true |
| |
| true_computation2 { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg1 = f32[3]{0} negate(gte) |
| } |
| |
| false_computation2 { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg2 = f32[3]{0} negate(gte) |
| } |
| |
| true_computation1 { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| slice = f32[1]{0} slice(gte), slice={[0:1]} |
| bitcast = f32[] bitcast(slice) |
| constant = f32[] constant(0.0) |
| compare = pred[] compare(bitcast, constant), direction=GT |
| ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2 |
| } |
| |
| false_computation1 { |
| p0 = (f32[3]{0}) parameter(0) |
| gte = f32[3]{0} get-tuple-element(p0), index=0 |
| ROOT neg3 = f32[3]{0} negate(gte) |
| } |
| |
| |
| ENTRY entry { |
| p0 = f32[3]{0} parameter(0) |
| p1 = pred[] parameter(1) |
| copy = f32[3]{0} copy(p0) |
| tuple = (f32[3]{0}) tuple(copy) |
| ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1 |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| |
| if (GetParam()) { |
| // Make sure alternate memory allocation gets propagated into both levels of |
| // conditional. |
| auto copy = |
| module->GetComputationWithName("entry")->GetInstructionWithName("copy"); |
| EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); |
| auto neg1_operand = module->GetComputationWithName("true_computation2") |
| ->GetInstructionWithName("neg1") |
| ->operand(0); |
| auto neg2_operand = module->GetComputationWithName("false_computation2") |
| ->GetInstructionWithName("neg2") |
| ->operand(0); |
| auto neg3_operand = module->GetComputationWithName("false_computation1") |
| ->GetInstructionWithName("neg3") |
| ->operand(0); |
| EXPECT_EQ(neg1_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| EXPECT_EQ(neg2_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| EXPECT_EQ(neg3_operand->shape().layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, |
| RequestIdentifierShouldNotBeAllocatedInAlternateMem) { |
| // Ensure that request identifier returned by Send/Recv HLOs are not allocated |
| // in the alternate memory. |
| absl::string_view hlo_string = R"( |
| HloModule SendRecv, is_scheduled=true |
| |
| ENTRY %AddDependency (p: f32[3]) -> f32[3] { |
| %p = f32[3]{0} parameter(0) |
| %after-all = token[] after-all() |
| %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7 |
| %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7 |
| %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1 |
| %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0 |
| %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2 |
| %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2 |
| ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| |
| for (const HloInstruction* instruction : |
| module->entry_computation()->instructions()) { |
| if (instruction->opcode() == HloOpcode::kSend || |
| instruction->opcode() == HloOpcode::kRecv) { |
| const Shape& request_identifier_shape = |
| ShapeUtil::GetSubshape(instruction->shape(), {1}); |
| EXPECT_NE(request_identifier_shape.layout().memory_space(), |
| kAlternateMemorySpace); |
| } |
| } |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) { |
| // Ensure that SendDone has only a Send operand. |
| absl::string_view hlo_string = R"( |
| HloModule SendRecv, is_scheduled=true |
| |
| ENTRY %AddDependency (p: f32[3]) -> f32[3] { |
| %p0 = f32[3]{0} parameter(0) |
| %p1 = f32[3]{0} parameter(1) |
| %neg0 = f32[3]{0} negate(f32[3]{0} %p1) |
| %neg1 = f32[3]{0} negate(f32[3]{0} %neg0) |
| %neg2 = f32[3]{0} negate(f32[3]{0} %neg1) |
| %neg3 = f32[3]{0} negate(f32[3]{0} %neg2) |
| %neg4 = f32[3]{0} negate(f32[3]{0} %neg3) |
| %neg5 = f32[3]{0} negate(f32[3]{0} %neg4) |
| %neg6 = f32[3]{0} negate(f32[3]{0} %neg5) |
| %after-all = token[] after-all() |
| %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2 |
| %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2 |
| ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get()); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) { |
| // Ensure that Send and SendDone have the same allocation. |
| absl::string_view hlo_string = R"( |
| HloModule SendRecv, is_scheduled=true |
| |
| ENTRY %AddDependency (p: f32[3]) -> f32[3] { |
| %p0 = f32[3]{0} parameter(0) |
| %p1 = f32[3]{0} parameter(1) |
| %after-all = token[] after-all() |
| %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2 |
| %neg0 = f32[3]{0} negate(f32[3]{0} %p1) |
| %neg1 = f32[3]{0} negate(f32[3]{0} %neg0) |
| %neg2 = f32[3]{0} negate(f32[3]{0} %neg1) |
| %neg3 = f32[3]{0} negate(f32[3]{0} %neg2) |
| %neg4 = f32[3]{0} negate(f32[3]{0} %neg3) |
| %neg5 = f32[3]{0} negate(f32[3]{0} %neg4) |
| %neg6 = f32[3]{0} negate(f32[3]{0} %neg5) |
| %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2 |
| ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, |
| /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
  // Test that checks the last use optimization. It uses two buffers that should
  // be placed in alternate memory.
  //
  //      +-------+
  //     /         \
  // add1--->sub1   +-------->mul2
  //                     mul1===>add2
  //
  // Without the last use optimization, the mul1 buffer will be assigned first
  // (because it is larger) to offset 0. Then, add1 will be scheduled for the
  // add1 to sub1 segment. Because offset 0 is available, it will get that
  // offset. But because offset 0 is not available in the sub1 to mul2 offset,
  // it will end up in unnecessary copies. With the last use optimization, these
  // copies can be optimized away.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {2, 4});
  PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, p0, p0));
  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kSubtract, p0, add1));
  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kMultiply, p1, p1));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, mul1, p1));
  HloInstruction* mul2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kMultiply, add1, sub1));
  HloInstruction* padding_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* padded_mul2 = builder.AddInstruction(
      HloInstruction::CreatePad(shape2, mul2, padding_value, padding_config));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, add2, padded_mul2));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // The explicit schedule keeps the add1->sub1 and sub1->mul2 live segments
  // exactly as described in the comment above.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, add1, sub1, mul1, add2, mul2,
                                      padding_value, padded_mul2, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // With the optimization, the only async copy feeding mul2's inputs is the
  // prefetch of parameter p0 used by sub1; add1's value reaches both sub1 and
  // mul2 without intervening copies.
  EXPECT_THAT(
      mul2,
      op::Multiply(
          op::Add(op::Parameter(0), op::Parameter(0)),
          op::Subtract(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                     op::Parameter(0)),
                       op::Add(op::Parameter(0), op::Parameter(0)))));
}
| |
TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
  // Test to make sure the CopyStarts follow the same CopyDone order. The shapes
  // are picked in increasing order to exploit the fact that heap simulator
  // processes larger tensors first. This checks the ability of the compiler to
  // reschedule:
  //
  // CS1             CD1
  //  +---------------+
  //       +-----------+
  //      CS2         CD2
  //
  // into:
  //
  // CS1          CD1
  //  +------------+
  //       +-----------+
  //      CS2         CD2
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 1});
  Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape4 = ShapeUtil::MakeShape(F32, {2, 4});
  PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape3, shape4});
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p4 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape4, p0, 1));
  HloInstruction* p3 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape3, p0, 0));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape2, "p2"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape1, "p1"));
  // Chain of negates on the smallest shape creates the long live range that
  // the async copies must span.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p1));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate5));
  HloInstruction* padding_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, negate6, p1));
  HloInstruction* padded_add1 = builder.AddInstruction(
      HloInstruction::CreatePad(shape2, add1, padding_value, padding_config));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, padded_add1, p2));
  HloInstruction* padded_add2 = builder.AddInstruction(
      HloInstruction::CreatePad(shape3, add2, padding_value, padding_config));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape4, HloOpcode::kNegate, p4));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, padded_add2, p3));
  HloInstruction* padded_add3 = builder.AddInstruction(
      HloInstruction::CreatePad(shape4, add3, padding_value, padding_config));
  HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
      shape4, HloOpcode::kAdd, padded_add3, negate7));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0,
                                      p4,
                                      p3,
                                      p2,
                                      p1,
                                      negate0,
                                      negate1,
                                      negate2,
                                      negate3,
                                      negate4,
                                      negate5,
                                      negate6,
                                      padding_value,
                                      add1,
                                      padded_add1,
                                      add2,
                                      padded_add2,
                                      negate7,
                                      add3,
                                      padded_add3,
                                      add4});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Use a large max prefetch interval to force CopyStart/CopyDone right after
  // the parameters.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);

  // Iterate over the schedule to make sure CopyStart order and the
  // corresponding CopyDone order match. `copy_starts` acts as a FIFO of
  // in-flight CopyStarts: each CopyDone must consume the oldest outstanding
  // CopyStart.
  std::list<const HloInstruction*> copy_starts;
  for (HloInstruction* instruction : module->schedule()
                                         .sequence(module->entry_computation())
                                         .instructions()) {
    if (instruction->opcode() == HloOpcode::kCopyStart) {
      copy_starts.push_back(instruction);
    }
    if (instruction->opcode() == HloOpcode::kCopyDone) {
      EXPECT_EQ(copy_starts.front(), instruction->operand(0));
      copy_starts.pop_front();
    }
  }
}
| |
| TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) { |
| // Test to ensure CopyStart/CopyDone is placed only in the entry computation. |
| auto module = CreateNewVerifiedModule(); |
| Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); |
| Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape}); |
| |
| auto cond_builder = HloComputation::Builder("WhileCond"); |
| // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) |
| HloInstruction* cond_param = cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); |
| HloInstruction* cond_iter = cond_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); |
| HloInstruction* cond_limit = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f))); |
| // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) |
| HloInstruction* cond_lt = cond_builder.AddInstruction( |
| HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, |
| cond_limit, ComparisonDirection::kLt)); |
| HloComputation* cond_computation = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto body_builder = HloComputation::Builder("WhileBody"); |
| // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) |
| HloInstruction* body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "body_param")); |
| HloInstruction* body_iter = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1)); |
| HloInstruction* body_data = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, body_param, 0)); |
| HloInstruction* body_iter_increment = body_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f))); |
| HloInstruction* body_iter_next = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment)); |
| HloInstruction* body_data_increment = |
| body_builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}}))); |
| HloInstruction* body_data_mul = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kMultiply, body_data, body_data)); |
| HloInstruction* body_data_add = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, body_data, body_data_increment)); |
| HloInstruction* body_data_next = |
| body_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, body_data_add, body_data_mul)); |
| HloInstruction* body_out = body_builder.AddInstruction( |
| HloInstruction::CreateTuple({body_data_next, body_iter_next})); |
| HloComputation* body_computation = |
| module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* data = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "param_iter")); |
| HloInstruction* iter = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, scalar_shape, "param_data")); |
| HloInstruction* p2 = |
| builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2")); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({data, iter})); |
| HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( |
| tuple_shape, cond_computation, body_computation, tuple)); |
| HloInstruction* while_data = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, while_op, 0)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2)); |
| HloComputation* entry_computation = |
| module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(cond_computation, |
| {cond_param, cond_iter, cond_limit, cond_lt}); |
| schedule.set_sequence(body_computation, |
| {body_param, body_iter, body_data, body_iter_increment, |
| body_iter_next, body_data_increment, body_data_mul, |
| body_data_add, body_data_next, body_out}); |
| schedule.set_sequence(entry_computation, |
| {iter, data, p2, tuple, while_op, while_data, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get(), -1, 50); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
  // Exercises memory space assignment across a kCall boundary: the called
  // computation consumes two entry values (add1 and negate8) and the entry
  // computation consumes the call result. The test only checks that the pass
  // runs successfully on this schedule (no assertions on the resulting
  // assignment).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto call_builder = HloComputation::Builder("Call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "call_param"));
  HloInstruction* call_param2 = call_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape2, "call_param2"));
  HloInstruction* slice = call_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, call_param, slice));
  // Long negate chain keeps call_param alive across many instructions inside
  // the called computation.
  HloInstruction* negate0 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* add0 =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, call_param, negate7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
  HloInstruction* add4 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
  HloInstruction* add5 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      call_computation,
      {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(entry_computation,
                        {p0, p1, add1, add2, negate8, call, add3, add4, add5});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
| |
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
  // Exercises memory space assignment across a kCall boundary where an iota
  // inside the called computation occupies alternate memory first (see the
  // comment below). The test only checks that the pass runs successfully on
  // this schedule.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto call_builder = HloComputation::Builder("Call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "call_param"));
  // Use shape2 here which is larger (scheduled earlier) to occupy alternate
  // memory at the beginning. This should cause a situation where the prefetch
  // of add1 later in the function body gets the wrong offset which cannot be
  // communicated to the outside the function.
  HloInstruction* iota =
      call_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
  HloInstruction* slice = call_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, call_param, slice));
  // Long negate chain keeps call_param alive across many instructions inside
  // the called computation.
  HloInstruction* negate0 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* add0 =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, call_param, negate7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {add1}, call_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      call_computation,
      {call_param, iota, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
| |
// TODO(berkin): This might be an incorrect input graph, investigate.
TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
  // Conditional whose true branch contains a large iota plus a long negate
  // chain and whose false branch is a bare parameter. Disabled pending the
  // TODO above; when enabled it only checks that the pass runs successfully.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto true_builder = HloComputation::Builder("True");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "true_param"));
  // The larger iota occupies alternate memory inside the branch, similar to
  // NonEntryComputationSchedule3.
  HloInstruction* iota =
      true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
  HloInstruction* slice = true_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, true_param, slice));
  HloInstruction* negate0 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* add0 =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, true_param, negate7));
  HloComputation* true_computation =
      module->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder("False");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "false_param"));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(false_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          shape, pred, add1, true_computation, add2, false_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      true_computation,
      {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(false_computation, {false_param});
  schedule.set_sequence(entry_computation,
                        {p0, add1, add2, pred, conditional, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
| |
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
  // This test reproduces the failure in b/143288178. Given a graph like the
  // following:
  //
  // ... = foo(a)
  // tuple = tuple((..., a))
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple((..., a1))
  // }
  //
  // If a copy to alternate memory is inserted before foo, and if the size of
  // the while body is less than max prefetch interval so that the copy-done is
  // kept in the alternate memory, then we end up referring to the copy-done in
  // the root instruction of the while loop body. I.e.,
  //
  // cs = copy-start(a)
  // ...
  // cd = copy-done(cs)
  // ... = foo(cd)
  // tuple = tuple((..., cd))
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple((..., cd))  <-- Error: cd belongs to outside computation.
  // }
  //
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  // While-carried state: (data, iteration counter, scalar payload).
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});

  // Loop condition: continue while the iteration counter (element 1) < 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Loop body: increments the iteration counter; the data (element 0) and the
  // scalar payload (element 2) are passed through to the root unchanged.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data2 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry: a chain of negates produces the while's element 0; `sub` consumes
  // iter and data2 outside of the loop (this is the "foo(a)" of the comment
  // above).
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* data2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kSubtract, iter, data2));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate7, iter, data2}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
  HloInstruction* root =
      builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data2, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, sub, tuple, while_op, while_data, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Set a large max prefetch interval so that the buffer can be kept in
  // alternate memory.
  AssignMemorySpace(module.get(), -1, 20);
}
| |
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
  // Verifies that the layouts (including the memory space) assigned to a while
  // loop and all of its aliased buffers (the while operand, the condition and
  // body parameters, and the body root) are consistent. Element {0} of the
  // loop state passes through the body unchanged, so it may live in the
  // alternate memory; element {1} is a scalar kept in default memory; element
  // {2} is redefined inside the body (by body_negate7), so it stays in the
  // default memory.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});

  // Loop condition: continue while the iteration counter (element 1) < 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Loop body: passes element {0} through unchanged, increments the counter,
  // and overwrites element {2} with a chain of negates of element {0}.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_negate0 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
  HloInstruction* body_negate1 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
  HloInstruction* body_negate2 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
  HloInstruction* body_negate3 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
  HloInstruction* body_negate4 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
  HloInstruction* body_negate5 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
  HloInstruction* body_negate6 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
  HloInstruction* body_negate7 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry: a negate chain computes element {2} of the while's initial value.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({data, iter, negate7}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
  HloInstruction* while_data2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 2));
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, while_data, while_data2));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(
      body_computation,
      {body_param, body_iter, body_data, body_negate0, body_negate1,
       body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
       body_negate7, body_iter_increment, body_iter_next, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, tuple, while_op, while_data, while_data2, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Pick a large max prefetch interval to ensure all the while inputs are
  // allocated in the alternate memory.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/25);

  // Build the expected layout for each tuple element, then compare against the
  // layouts the pass actually assigned to the while and its aliased buffers.
  // Index {0} of the while loop argument is not written inside the while loop,
  // so it can be trivially placed in the alternate memory space.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kAlternateMemorySpace);
  // Index {1} is a scalar, so it is always placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kDefaultMemorySpace);
  // Index {2} of the while loop is placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kDefaultMemorySpace);

  // Expect the layout for the while loop and its aliased buffers.
  EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
}
| |
| TEST_P(MemorySpaceAssignmentTest, DanglingCopy) { |
| // This situation was encountered in vss, where there is a mismatch in the |
| // memory space in preset assignments and the output graph. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| |
| HloInstruction* p = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* p0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p, 0)); |
| HloInstruction* p1a = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p, 1)); |
| HloInstruction* copy = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kCopy, p1a)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* negate4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); |
| HloInstruction* negate5 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); |
| HloInstruction* negate6 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); |
| HloInstruction* p1b = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, p, 1)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1b)); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence( |
| computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5, |
| negate6, p1a, copy, p1b, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| auto module = CreateNewVerifiedModule(); |
| |
| HloComputation::Builder fusion_builder("fusion"); |
| HloInstruction* fusion_param0 = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* fusion_param1 = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(1, shape, "p1")); |
| fusion_builder.AddInstruction( |
| HloInstruction::CreateTuple({fusion_param0, fusion_param1})); |
| HloComputation* fusion_computation = |
| module->AddEmbeddedComputation(fusion_builder.Build()); |
| |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( |
| tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0}, |
| fusion_computation)); |
| HloInstruction* element0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion, 0)); |
| HloInstruction* element1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion, 1)); |
| HloInstruction* add = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, fusion, element0, element1, add}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, TupleInput) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| auto module = CreateNewVerifiedModule(); |
| |
| HloComputation::Builder fusion_builder("fusion"); |
| HloInstruction* fusion_param = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* fusion_element0 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 0)); |
| HloInstruction* fusion_element1 = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion_param, 1)); |
| fusion_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, fusion_element0, fusion_element1)); |
| HloComputation* fusion_computation = |
| module->AddEmbeddedComputation(fusion_builder.Build()); |
| |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* p1 = |
| builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1)); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); |
| HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( |
| shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, p1, negate0, negate1, tuple, fusion}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
  // fusion0 produces a tuple that is consumed both element-wise (element0 and
  // element1 feed add0) and as a whole tuple (the operand of fusion1). With a
  // small max prefetch interval and a long negate chain in between, the test
  // expects fusion1's tuple operand to be rebuilt from per-element async
  // copies (default -> alternate memory).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // Producer fusion: tuple(p0, p1).
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // Consumer fusion: add(gte(p, 0), gte(p, 1)).
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
  HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 1));
  // A negate chain between the producer and the consumer so the tuple is not
  // immediately re-used.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, negate6));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add1, fusion1));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, fusion0, element0, element1, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, add0, add1, fusion1, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);
  // fusion1's operand should be a fresh tuple whose elements are async copies
  // of the producer's elements from default into alternate memory.
  EXPECT_THAT(fusion1,
              op::Fusion(op::Tuple(
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 0)),
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 1)))));
}
| |
TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
  // Like TupleToTuple1, but the producer fusion returns a nested tuple
  // (shape, (shape, shape)) that is consumed by fusion1. The test expects the
  // nested tuple operand to be rebuilt with async copies at every leaf.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({shape, tuple_shape});
  auto module = CreateNewVerifiedModule();

  // Producer fusion: tuple(p0, tuple(p0, p1)).
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* fusion0_tuple = fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_tuple}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // Consumer fusion: add(gte(p, 0), gte(gte(p, 1), 1)).
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, nested_tuple_shape, "p"));
  HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
  HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, fusion1_param, 1));
  HloInstruction* fusion1_element2 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_element1, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion1_element0, fusion1_element2));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      nested_tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  // A negate chain between the producer and the consumer so the tuple is not
  // immediately re-used.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p0, fusion0, negate0, negate1, negate2, negate3, negate4,
                    negate5, negate6, fusion1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);

  // fusion1's operand should be a fresh nested tuple whose leaves are async
  // copies from default into alternate memory.
  EXPECT_THAT(
      fusion1,
      op::Fusion(op::Tuple(
          op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                        op::GetTupleElement(op::Fusion(), 0)),
          op::Tuple(
              op::AsyncCopy(
                  kAlternateMemorySpace, kDefaultMemorySpace,
                  op::GetTupleElement(op::GetTupleElement(op::Fusion(), 1), 0)),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::GetTupleElement(
                                op::GetTupleElement(op::Fusion(), 1), 1))))));
}
| |
| TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) { |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); |
| Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); |
| auto module = CreateNewVerifiedModule(); |
| |
| HloComputation::Builder fusion0_builder("fusion0"); |
| HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction( |
| HloInstruction::CreateParameter(1, shape, "p1")); |
| fusion0_builder.AddInstruction( |
| HloInstruction::CreateTuple({fusion0_param0, fusion0_param1})); |
| HloComputation* fusion0_computation = |
| module->AddEmbeddedComputation(fusion0_builder.Build()); |
| |
| HloComputation::Builder fusion1_builder("fusion1"); |
| HloInstruction* fusion1_param = fusion1_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p")); |
| HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0)); |
| HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1)); |
| fusion1_builder.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1)); |
| HloComputation* fusion1_computation = |
| module->AddEmbeddedComputation(fusion1_builder.Build()); |
| |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion( |
| tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0}, |
| fusion0_computation)); |
| HloInstruction* fusion1 = builder.AddInstruction( |
| HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom, |
| {fusion0}, fusion1_computation)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {p0, fusion0, fusion1}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| EXPECT_THAT(fusion1, op::Fusion(op::Fusion())); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
  // Verifies that parameters that are aliased with outputs keep their default
  // memory space layout even when their values are used late in the
  // computation (p1 is first used by `add` after a long negate chain).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
  // NOTE(review): negate7 has no users (it is not part of the root tuple) —
  // presumably it only extends add's live range; confirm if modifying.
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1, add, negate7, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Make input {0} alias with output {0} and input {1} alias with output {1}.
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));

  AssignMemorySpace(module.get());

  // Make sure the input is in the default memory space.
  EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
            kDefaultMemorySpace);
  EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
            kDefaultMemorySpace);
}
| |
TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
  // This is mostly a smoke test since it's difficult and brittle to work out
  // the cost of the HLO instructions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Seven chained negates feeding an add; the negates are the values the
  // cost-analysis-driven assignment is expected to put in alternate memory.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, negate5, negate6, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
}
| |
TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
  // This test is carefully crafted to force only negates to be allocated to the
  // alternate memory. The graph consists of interleaving negate and tanh
  // operations:
  //
  //        +------+      +-------+      +-----
  //       /        \    /         \    /
  //  negate  tanh  negate  tanh   negate  tanh
  //            \          /         \
  //             +--------+           +---------+
  //
  // The alternate memory is sized to fit only two f32[4,3] tensors at a time.
  // Also, transcendentals are made to be lower bandwidth than FLOPs. So, the
  // MemoryBoundednessBufferIntervalCompare should prioritize the negates, which
  // are more memory bound.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Two independent chains: tanh0..tanh4 from p0, negate0..negate4 from p1.
  HloInstruction* tanh0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
  HloInstruction* tanh1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* tanh2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* tanh3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* tanh4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh3));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({tanh4, negate4}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Interleave the two chains in the schedule so their live ranges compete for
  // alternate memory at every time step.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, tanh0, negate0, tanh1, negate1, tanh2, negate2,
                         tanh3, negate3, tanh4, negate4, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {4, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  // Expect only negates to be in alternate memory space. Not all might fit but
  // make sure at least one does.
  std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
                                                      negate3, negate4};
  int64 num_negates_in_alternate_mem = absl::c_count_if(
      negate_instructions, [&](const HloInstruction* instruction) {
        return instruction->shape().layout().memory_space() ==
               kAlternateMemorySpace;
      });
  EXPECT_GE(num_negates_in_alternate_mem, 1);
  // Every tanh must have stayed in default memory.
  EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh3, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
}
| |
| TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { |
| Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); |
| Shape f32v1 = ShapeUtil::MakeShape(F32, {1}); |
| Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1}); |
| auto module = CreateNewVerifiedModule("SimpleWhile"); |
| HloSchedule schedule(module.get()); |
| |
| // A simple compare-to-limit (x < 4) computation for a While. |
| // |
| // condition: |
| // const4[s32] -----------------------------------\ |
| // \ |
| // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than |
| // |
| HloComputation* cond_computation; |
| { |
| auto builder = HloComputation::Builder("WhileCond"); |
| auto const4 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4))); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, t_s32_f32v1, "x")); |
| auto index = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); |
| auto compare = builder.AddInstruction( |
| HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index, |
| const4, ComparisonDirection::kLt)); |
| cond_computation = module->AddEmbeddedComputation(builder.Build()); |
| schedule.set_sequence(cond_computation, {const4, param, index, compare}); |
| } |
| |
| // Builds a simple body computation for a While. |
| // |
| // body: |
| // constv[f32[1]] --------------------------------------\ |
| // \ |
| // /--- get-tuple-elementv[1] --- addv ---\ |
| // param[(s32,f32[1])] ---| tuple |
| // \--- get-tuple-elementc[0] --- addc ---/ |
| // / |
| // const1[s32] -----------------------------------------/ |
| // |
| HloComputation* body_computation; |
| { |
| auto builder = HloComputation::Builder("WhileBody"); |
| auto const1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1))); |
| auto constv = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f}))); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, t_s32_f32v1, "x")); |
| auto indexc = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(const1->shape(), param, 0)); |
| auto addc = builder.AddInstruction(HloInstruction::CreateBinary( |
| indexc->shape(), HloOpcode::kAdd, indexc, const1)); |
| auto indexv = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(constv->shape(), param, 1)); |
| auto addv = builder.AddInstruction(HloInstruction::CreateBinary( |
| constv->shape(), HloOpcode::kAdd, indexv, constv)); |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({addc, addv})); |
| body_computation = module->AddEmbeddedComputation(builder.Build()); |
| schedule.set_sequence(body_computation, {const1, constv, param, indexc, |
| addc, indexv, addv, tuple}); |
| } |
| |
| // This tests a simple while loop where the parameters are aliased with the |
| // output buffers. |
| auto builder = HloComputation::Builder("SimpleWhile"); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, t_s32_f32v1, "param")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(s32, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(f32v1, param, 1)); |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); |
| auto while0 = builder.AddInstruction(HloInstruction::CreateWhile( |
| t_s32_f32v1, cond_computation, body_computation, tuple)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, |
| /*max_prefetch_interval=*/50); |
| |
| // Ensure all parameters and while are placed in default memory. |
| Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout( |
| F32, {4, 6}, |
| /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kDefaultMemorySpace); |
| Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout( |
| xla::S32, {}, |
| /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kDefaultMemorySpace); |
| Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout( |
| F32, {1}, |
| /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0, |
| kDefaultMemorySpace); |
| Shape t_s32_f32v1_in_default_mem = |
| ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem}); |
| EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem)); |
| EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem)); |
| } |
| |
| TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) { |
| // This test reproduces an eviction scheduling bug where evictions to default |
| // memory can happen later than intended, causing memory corruption. This test |
| // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3] |
| // tensors instead, so at most two tensors should fit in the alternate memory |
| // space at a given time. We have a number of redundant operations |
| // (tanh_redundant ops) that do not have users. The bug was due to |
| // SimplifyGraph removing dead instructions, and removing them from the |
| // schedule. However, the CopyStart/CopyDone insertion relies on the schedule |
| // indexes, so they could be inserted too late. |
| HloComputation::Builder builder(TestName()); |
| Shape shape = ShapeUtil::MakeShape(F32, {4, 3}); |
| HloInstruction* p0 = |
| builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); |
| HloInstruction* tanh0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant4 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant5 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* tanh_redundant6 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); |
| HloInstruction* negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0)); |
| HloInstruction* tanh1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0)); |
| HloInstruction* negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); |
| HloInstruction* tanh2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1)); |
| HloInstruction* negate2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); |
| HloInstruction* tanh3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2)); |
| HloInstruction* negate3 = builder.AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); |
| HloInstruction* tuple = builder.AddInstruction( |
| HloInstruction::CreateTuple({tanh3, negate3, tanh0})); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence( |
| computation, |
| {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2, |
| tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6, |
| negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpaceUsingCostAnalysis(module.get()); |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, |
| HloAliasAnalysis::Run(module.get())); |
| TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range, |
| HloLiveRange::Run(module->schedule(), *alias_analysis, |
| module->entry_computation())); |
| |
| std::vector<int> num_live_buffers_in_alternate_mem( |
| hlo_live_range->flattened_instruction_sequence().size() + 1, 0); |
| |
| // Go through each value and for those that are allocated in the alternate |
| // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for |
| // every time step that they are live. |
| for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { |
| const Shape& shape = value->shape(); |
| if (!shape.has_layout() || |
| shape.layout().memory_space() == kDefaultMemorySpace) { |
| continue; |
| } |
| |
| HloLiveRange::TimeBound time_bound = |
| hlo_live_range->buffer_live_ranges().at(value); |
| for (int i = time_bound.start; i <= time_bound.end; ++i) { |
| ++num_live_buffers_in_alternate_mem[i]; |
| } |
| } |
| |
| // The test memory can at most hold two f32[4,3] buffers at a time. If there |
| // is more than that, it means we have memory corruption. |
| for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) { |
| EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2); |
| } |
| } |
| |
TEST_P(MemorySpaceAssignmentTest,
       InputOutputsInAlternateMemShouldntBeAssigned) {
  // When input/outputs are marked to be in the alternate memory (e.g.
  // go/tpu-fast-mem-inference), do not allocate those and assume they will live
  // in the alternate memory for the entire computation. The BufferAssignment
  // pass, which is run after this, will allocate those buffers.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  // p0 is in the default memory space.
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  // p1 is in the alternate memory space.
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
  // Chain of negates so p1's use (in add) is far from its definition in the
  // schedule; without the special-casing, a prefetch would be tempting here.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1));
  // Index {0} of the root instruction is in the alternate memory space, index
  // {1} is in the default memory space.
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({add, negate5}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         negate5, negate6, add, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  std::unique_ptr<PresetAssignments> preset_assignments =
      AssignMemorySpace(module.get());

  // Ensure that p1 is in the alternate memory and add, which has p1 as an
  // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
  // Make sure add is still in the alternate memory space.
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));

  // Check the preset assignments and ensure the inputs/outputs in the alternate
  // memory space aren't in the preset assignments. Inputs/outputs in the
  // alternate memory space are left to BufferAssignment to be allocated.
  for (const auto& position_and_chunk : preset_assignments->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    EXPECT_NE(position.instruction, p1);
    EXPECT_NE(position.instruction, add);
  }
}
| |
TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
  // Tests a memory corruption bug where the allocated chunk overlaps with a
  // pending chunk. To test this, we provide a new buffer interval compare where
  // we prioritize the allocation of sine, cosine, and tanh to create the
  // situation:
  //
  //    Max memory
  //  -------------------------------------------
  //      +------------+
  //      |     b      |
  //      +------------+
  //  +-------+
  //  |       |
  //  |       |
  //  |   a   |
  //  |       |                 +------------+
  //  |       |                 |     n      |
  //  +-------+                 +------------+
  //  -------------------------------------------
  //    Min memory          time ->
  //
  //
  //  Then allocating for buffer d, we have these two prefetch buffers
  //  overlapping:
  //
  //    Max memory
  //  -------------------------------------------
  //      +------------+ +----------+
  //      |     b      | | prefetch |
  //      +------------+ | for o    |
  //  +-------+     +---------+     |
  //  |       |     |    |    |     |
  //  |       |     |    |    |     |
  //  |   a   |     |    +----|-----+
  //  |       |     | prefetch| +------------+
  //  |       |     | for m   | |     n      |
  //  +-------+     +---------+ +------------+
  //  -------------------------------------------
  //    Min memory          time ->
  //
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY %Entry {
    %param0 = f32[8,3] parameter(0)
    %param1 = f32[2,4] parameter(1)
    %a = f32[8,3] sine(%param0)
    %b = f32[2,4] cosine(%param1)
    %d = f32[8,3] tanh(%a)
    %c = f32[8,3] negate(%a)
    %e = f32[2,4] negate(%b)
    %f = f32[2,4] negate(%e)
    %g = f32[2,4] negate(%f)
    %h = f32[2,4] negate(%g)
    %i = f32[2,4] negate(%h)
    %j = f32[2,4] negate(%i)
    %k = f32[2,4] negate(%j)
    %l = f32[2,4] negate(%k)
    %m = f32[8,3] negate(%d)
    %n = f32[2,4] sine(%l)
    %o = f32[8,3] negate(%d)
    %p = f32[2,4] negate(%n)
    %q = f32[8,3] negate(%m)
    ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
  }
  )";

  // Custom ordering: sin first, then cos, then tanh, then everything else.
  // This forces the allocation order described in the diagrams above.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto get_opcode_priority = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };

        return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
               get_opcode_priority(b.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // min/max prefetch interval of 2/10 instructions; the test passes if
  // assignment completes without tripping internal overlap checks.
  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);
}
| |
| TEST_P(MemorySpaceAssignmentTest, Determinism) { |
| // Run memory space assignment a few times to make sure every time it compiles |
| // to the same thing. |
| std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule(); |
| |
| AssignMemorySpace(module.get()); |
| std::string module_str = module->ToString(); |
| |
| for (int i = 0; i < 10; ++i) { |
| std::unique_ptr<HloModule> other_module = CreateEvictAndPrefetchModule(); |
| AssignMemorySpace(other_module.get()); |
| EXPECT_EQ(module_str, other_module->ToString()); |
| } |
| } |
| |
// Run every MemorySpaceAssignmentTest with both values of the bool test
// parameter. (The parameter's meaning is defined by the fixture's helpers
// outside this excerpt — presumably an allocation-mode toggle; confirm at the
// fixture definition.)
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                         MemorySpaceAssignmentTest,
                         ::testing::Values(false, true));

// The AsynchronousCopyOrdering tests need no fixture state.
using AsynchronousCopyOrderingTest = ::testing::Test;
| |
| TEST_F(AsynchronousCopyOrderingTest, Simple) { |
| // Given asynchronous copies like the following, ensure the pipelining order |
| // is maintained (earlier start time must have earlier end time). |
| // 3,11 +-------+ OK |
| // 1,8 +------+ OK |
| // 5,14 +--------+ OK |
| // 7,14 +------+ OK |
| // 2,16 +-------------+ Violate |
| // 9,12 +--+ Violate |
| // 6,17 +----------+ Violate |
| // 5,13 +-------+ OK (same start as 5,14) |
| // 5,14 +--------+ OK (same as 5,14) |
| auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate; |
| AsynchronousCopyOrdering ordering; |
| EXPECT_FALSE(ordering.ViolatesOrdering(3, 11)); |
| ordering.AddCopy({3, 11, alternate_mem_space}); |
| EXPECT_FALSE(ordering.ViolatesOrdering(1, 8)); |
| ordering.AddCopy({1, 8, alternate_mem_space}); |
| EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); |
| ordering.AddCopy({5, 14, alternate_mem_space}); |
| EXPECT_FALSE(ordering.ViolatesOrdering(7, 14)); |
| ordering.AddCopy({7, 14, alternate_mem_space}); |
| EXPECT_TRUE(ordering.ViolatesOrdering(2, 16)); |
| EXPECT_TRUE(ordering.ViolatesOrdering(9, 12)); |
| EXPECT_TRUE(ordering.ViolatesOrdering(6, 17)); |
| EXPECT_FALSE(ordering.ViolatesOrdering(5, 13)); |
| ordering.AddCopy({5, 13, alternate_mem_space}); |
| EXPECT_FALSE(ordering.ViolatesOrdering(5, 14)); |
| ordering.AddCopy({5, 14, alternate_mem_space}); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
  // A dot whose rhs comes straight out of the entry parameter tuple should be
  // selected for cross-program prefetch (parameter 0, tuple index {1}).
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Exactly one cross-program prefetch: parameter 0 at shape index {1} (the
  // dot's rhs operand).
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
| |
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) {
  // Like CrossProgramPrefetchTest, but the dot's rhs goes through a bitcast;
  // the parameter buffer behind the bitcast should still be cross-program
  // prefetched.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
  // Bitcast flips the rhs to the layout-compatible [kFeature, kOutput] shape
  // the dot expects.
  auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  auto bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Exactly one cross-program prefetch: parameter 0 at shape index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
| |
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
  // Same dot as CrossProgramPrefetchTest, but the operands sit one level
  // deeper, inside a nested tuple parameter; in that case no cross-program
  // prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  // The parameter is a tuple containing the (lhs, rhs) tuple.
  auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));

  auto gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Nested tuple operands are not cross-program prefetched.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
| |
| TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) { |
| HloComputation::Builder builder(TestName()); |
| |
| constexpr int kFeature = 8; |
| constexpr int kOutput = 2; |
| |
| auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); |
| HloInstruction* param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, rhs_shape, "p0")); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {param}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| auto cross_program_prefetches = module->CrossProgramPrefetches(); |
| EXPECT_EQ(cross_program_prefetches.size(), 0); |
| } |
| |
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
  // Identical graph to CrossProgramPrefetchTest except kOutput is 8 instead of
  // 2, making the operand too big for cross-program prefetch (presumably it
  // exceeds the test's alternate memory budget — confirm against the fixture's
  // memory size); no prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 8;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // No cross-program prefetch for the oversized operand.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
| |
| TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) { |
| HloComputation::Builder builder(TestName()); |
| |
| constexpr int kBatch = 2; |
| constexpr int kFeature = 2; |
| constexpr int kOutput = 2; |
| |
| auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); |
| auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); |
| auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); |
| auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); |
| |
| auto module = CreateNewVerifiedModule(); |
| HloComputation::Builder fusion_builder("fusion"); |
| { |
| HloInstruction* param = fusion_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto lhs = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); |
| auto rhs = fusion_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); |
| DotDimensionNumbers dot_dnums; |
| dot_dnums.add_lhs_contracting_dimensions(1); |
| dot_dnums.add_rhs_contracting_dimensions(0); |
| auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot( |
| result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); |
| (void)dot; |
| } |
| HloComputation* fusion_computation = |
| module->AddEmbeddedComputation(fusion_builder.Build()); |
| |
| auto activations = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}}))); |
| auto weights = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}}))); |
| HloInstruction* tuple = builder.AddInstruction( |
| HloInstruction::CreateTuple({activations, weights})); |
| HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion( |
| result_shape, HloInstruction::FusionKind::kCustom, {tuple}, |
| fusion_computation)); |
| |
| HloComputation* computation = module->AddEntryComputation(builder.Build()); |
| |
| HloSchedule schedule(module.get()); |
| schedule.set_sequence(computation, {activations, weights, tuple, fusion}); |
| TF_CHECK_OK(module->set_schedule(schedule)); |
| |
| AssignMemorySpace(module.get()); |
| |
| auto cross_program_prefetches = module->CrossProgramPrefetches(); |
| EXPECT_EQ(cross_program_prefetches.size(), 0); |
| } |
| |
// For testing purposes, we define a cost analysis where we can control the
// elapsed times of each HLO and asynchronous copy. All elapsed-time queries
// are overridden below with fixed constants, so the bandwidth values passed
// to the base class are irrelevant to these tests.
class FakeMemorySpaceAssignmentCostAnalysis
    : public MemorySpaceAssignmentCostAnalysis {
 public:
  // Builds a fake cost analysis over `module`. Runs alias analysis and live
  // range analysis on the module's schedule first because the base class
  // requires both (plus a call graph) at construction time.
  static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
  Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
    TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
    TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                        HloLiveRange::Run(module.schedule(), *alias_analysis,
                                          module.entry_computation()));
    auto call_graph = CallGraph::Build(&module);
    return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
        cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
        /*alternate_mem_bandwidth_bytes_per_second=*/1,
        std::move(alias_analysis), std::move(hlo_live_range),
        std::move(call_graph)));
  }

  // Every instruction takes exactly 1.0 time units in default memory.
  float GetInstructionElapsed(
      const HloInstruction& instruction) const override {
    return 1.0;
  }

  // An instruction reading any operand from alternate memory takes 0.5 time
  // units; otherwise 1.0. Note output placement has no effect on the result.
  float GetInstructionElapsedInAlternateMemory(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem,
      bool output_in_alternate_mem) const override {
    if (operand_in_alternate_mem) {
      return 0.5;
    } else {
      return 1.0;
    }
  }

  // Every asynchronous copy takes 3.0 time units regardless of shape.
  float GetAsyncCopyElapsed(const Shape& shape) const override { return 3.0; }

 protected:
  // Forwards everything to the base class. Protected so that construction
  // goes through the status-returning Create() factory above.
  FakeMemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      std::unique_ptr<HloAliasAnalysis> alias_analysis,
      std::unique_ptr<HloLiveRange> hlo_live_range,
      std::unique_ptr<CallGraph> call_graph)
      : MemorySpaceAssignmentCostAnalysis(
            cost_analysis, async_copy_bandwidth_bytes_per_second,
            alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
            std::move(hlo_live_range), std::move(call_graph)) {}
};
| |
// The interval picker tests below do not use the bool parameterization of
// MemorySpaceAssignmentTest, so a plain HloTestBase fixture suffices.
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
| |
| TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { |
| absl::string_view hlo_string = R"( |
| HloModule bug, is_scheduled=true |
| |
| ENTRY Entry { |
| param0 = f32[2,4] parameter(0) |
| a = f32[2,4] negate(param0) |
| b = f32[2,4] negate(a) |
| c = f32[2,4] negate(b) |
| d = f32[2,4] negate(c) |
| e = f32[2,4] negate(d) |
| f = f32[2,4] negate(e) |
| g = f32[2,4] negate(f) |
| h = f32[2,4] negate(g) |
| i = f32[2,4] negate(h) |
| j = f32[2,4] negate(i) |
| k = f32[2,4] negate(j) |
| l = f32[2,4] negate(k) |
| m = f32[2,4] negate(l) |
| n = f32[2,4] negate(m) |
| o = f32[2,4] negate(n) |
| p = f32[2,4] negate(o) |
| q = f32[2,4] negate(p) |
| r = f32[2,4] negate(q) |
| s = f32[2,4] negate(r) |
| t = f32[2,4] negate(s) |
| u = f32[2,4] negate(t) |
| ROOT v = f32[2,4] add(u, param0) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| |
| HloCostAnalysis hlo_cost_analysis(ShapeSize); |
| TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, |
| FakeMemorySpaceAssignmentCostAnalysis::Create( |
| hlo_cost_analysis, *module)); |
| CostAnalysisPrefetchIntervalPicker interval_picker( |
| *cost_analysis, |
| /*min_async_copy_to_overlap_ratio=*/1.0, |
| /*max_async_copy_to_overlap_ratio=*/4.0, |
| /*preferred_async_copy_to_overlap_ratio=*/2.0); |
| |
| HloInstruction* root = module->entry_computation()->root_instruction(); |
| const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; |
| interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22); |
| |
| // Expect that the first interval is (15, 22), which has elapsed time of 6.0, |
| // twice of the async copy elased (3.0). Then we expect that intervals will be |
| // visited in alternating increasing and decreasing orders until hitting the |
| // min and max async copy overlap ratios, which are the intervals (18, 22) |
| // and (9, 22) respectively. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 15); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 16); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 14); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 17); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 13); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 12); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 11); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 10); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_TRUE(interval_picker.Done()); |
| |
| // Expect that if the time between start_time and end_time is too short, there |
| // won't be any available intervals. |
| interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_TRUE(interval_picker.Done()); |
| } |
| |
| TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { |
| absl::string_view hlo_string = R"( |
| HloModule bug, is_scheduled=true |
| |
| while_condition { |
| param1 = (f32[2,4]) parameter(0) // 19 |
| ROOT cond = pred[] constant(true) // 20 |
| } |
| |
| while_body { |
| param2 = (f32[2,4]) parameter(0) // 21 |
| gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 |
| add = f32[2,4] add(gte2, gte2) // 23 |
| ROOT tuple2 = (f32[2,4]) tuple(add) // 24 |
| } |
| |
| ENTRY Entry { |
| param0 = f32[2,4] parameter(0) // 0 |
| a = f32[2,4] negate(param0) // 1 |
| b = f32[2,4] negate(a) // 2 |
| c = f32[2,4] negate(b) // 3 |
| d = f32[2,4] negate(c) // 4 |
| e = f32[2,4] negate(d) // 5 |
| f = f32[2,4] negate(e) // 6 |
| g = f32[2,4] negate(f) // 7 |
| h = f32[2,4] negate(g) // 8 |
| i = f32[2,4] negate(h) // 9 |
| j = f32[2,4] negate(i) // 10 |
| k = f32[2,4] negate(j) // 11 |
| l = f32[2,4] negate(k) // 12 |
| m = f32[2,4] negate(l) // 13 |
| n = f32[2,4] negate(m) // 14 |
| o = f32[2,4] negate(n) // 15 |
| p = f32[2,4] negate(o) // 16 |
| q = f32[2,4] negate(p) // 17 |
| tuple = (f32[2,4]) tuple(q) // 18 |
| while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 |
| gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 |
| r = f32[2,4] negate(gte1) // 27 |
| s = f32[2,4] negate(r) // 28 |
| t = f32[2,4] negate(s) // 29 |
| u = f32[2,4] negate(t) // 30 |
| ROOT v = f32[2,4] add(u, param0) // 31 |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(auto module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| |
| HloCostAnalysis hlo_cost_analysis(ShapeSize); |
| TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, |
| FakeMemorySpaceAssignmentCostAnalysis::Create( |
| hlo_cost_analysis, *module)); |
| CostAnalysisPrefetchIntervalPicker interval_picker( |
| *cost_analysis, |
| /*min_async_copy_to_overlap_ratio=*/1.0, |
| /*max_async_copy_to_overlap_ratio=*/12.0, |
| /*preferred_async_copy_to_overlap_ratio=*/2.0); |
| |
| HloInstruction* root = module->entry_computation()->root_instruction(); |
| const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; |
| interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31); |
| |
| // Because there are while loop computations between [19, 24], we ensure that |
| // the interval picker avoids this interval. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 25); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 26); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 18); |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. |
| LOG(INFO) << interval_picker.ToDebugString(); |
| EXPECT_TRUE(interval_picker.Done()); |
| } |
| |
| } // namespace |
| } // namespace xla |