| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/copy_insertion.h" |
| |
| #include <set> |
| |
| #include "tensorflow/compiler/xla/debug_options_flags.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_matchers.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_runner.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/test.h" |
| #include "tensorflow/compiler/xla/test_helpers.h" |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/platform/test_benchmark.h" |
| |
| namespace op = xla::testing::opcode_matchers; |
| |
| namespace xla { |
| namespace { |
| |
| using ::testing::UnorderedElementsAre; |
| |
| int64 CountCopies(const HloComputation& computation) { |
| int64 count = 0; |
| for (const auto& instruction : computation.instructions()) { |
| if (instruction->opcode() == HloOpcode::kCopy) { |
| count++; |
| } |
| } |
| return count; |
| } |
| |
| int64 CountCopies(const HloModule& module) { |
| int64 count = 0; |
| for (const auto& computation : module.computations()) { |
| count += CountCopies(*computation); |
| } |
| return count; |
| } |
| |
| int64 CountControlEdges(const HloComputation& computation) { |
| int64 count = 0; |
| for (const auto& instruction : computation.instructions()) { |
| count += instruction->control_successors().size(); |
| } |
| return count; |
| } |
| |
| int64 CountControlEdges(const HloModule& module) { |
| int64 count = 0; |
| for (const auto& computation : module.computations()) { |
| count += CountControlEdges(*computation); |
| } |
| return count; |
| } |
| |
| class CopyInsertionTest : public HloTestBase { |
| protected: |
| void InsertCopies(HloModule* module) { |
| CopyInsertion copy_insertion; |
| ASSERT_IS_OK(copy_insertion.Run(module).status()); |
| } |
| |
| const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); |
| }; |
| |
| TEST_F(CopyInsertionTest, SingleParameter) { |
| // Computation is a single parameter passed into a tuple. The parameter should |
| // be copied before entering the tuple. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* x = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({x})); |
| |
| EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(x))); |
| } |
| |
| TEST_F(CopyInsertionTest, SingleConstant) { |
| // Computation is a single constant passed into a tuple. The parameter should |
| // be copied before entering the tuple. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| HloInstruction* tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({constant})); |
| |
| EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(constant))); |
| } |
| |
| TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { |
| // Verify that kCopy instructions which change layout and exist before |
| // copy-insertion remain in the graph after copy-insertion. |
| auto module = CreateNewVerifiedModule(); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* constant = |
| builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}}))); |
| auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); |
| Layout reversed_layout = |
| LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); |
| Shape copy_shape = constant->shape(); |
| *copy_shape.mutable_layout() = reversed_layout; |
| HloInstruction* copy_1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); |
| HloInstruction* copy_2 = builder.AddInstruction( |
| HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); |
| HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( |
| constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); |
| builder.AddInstruction( |
| HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); |
| |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(CountCopies(*module), 3); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 2); |
| |
| EXPECT_EQ(module->entry_computation()->root_instruction(), add); |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); |
| } |
| |
| TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { |
| // Create a computation with more than one constant and parameter. Only one of |
| // each constant/parameter is pointed to by the output tuple. Only these |
| // instructions should be copied. |
| auto builder = HloComputation::Builder(TestName()); |
| |
| HloInstruction* constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| HloInstruction* constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); |
| |
| HloInstruction* x = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); |
| HloInstruction* y = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y")); |
| |
| HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( |
| ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y)); |
| |
| builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 2); |
| |
| EXPECT_THAT( |
| module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); |
| } |
| |
| TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { |
| // Create a computation using select which has an ambiguous points-to set for |
| // the computation result. Verify that copies are added properly. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| HloInstruction* constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); |
| HloInstruction* constant3 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0))); |
| |
| HloInstruction* tuple1 = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant1, constant2})); |
| HloInstruction* tuple2 = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant3, constant2})); |
| |
| HloInstruction* pred = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| builder.AddInstruction(HloInstruction::CreateTernary( |
| tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); |
| |
| EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); |
| EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); |
| EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| HloInstruction* old_root = module->entry_computation()->root_instruction(); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 2); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(op::GetTupleElement(old_root)), |
| op::Copy(op::GetTupleElement(old_root)))); |
| } |
| |
| TEST_F(CopyInsertionTest, BitcastParameter) { |
| // The output of a bitcast is its operand (same buffer), so a bitcast |
| // parameter feeding the result must have a copy added. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* x = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); |
| HloInstruction* bitcast = builder.AddInstruction( |
| HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), x)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); |
| |
| HloInstruction* old_root = module->entry_computation()->root_instruction(); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Copy(old_root)); |
| } |
| |
| TEST_F(CopyInsertionTest, BitcastConstant) { |
| // The output of a bitcast is its operand (same buffer), so a bitcast |
| // constant feeding the result must have a copy added. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* constant = |
| builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR1<float>({1.0, 42.0}))); |
| HloInstruction* bitcast = builder.AddInstruction( |
| HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2}), constant)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); |
| |
| HloInstruction* old_root = module->entry_computation()->root_instruction(); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Copy(old_root)); |
| } |
| |
| TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { |
| // Same as BitcastParameter, but the bitcast is wrapped in a tuple. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* x = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); |
| HloInstruction* bitcast = builder.AddInstruction( |
| HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), x)); |
| builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); |
| |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(bitcast))); |
| } |
| |
| TEST_F(CopyInsertionTest, NestedTupleParameter) { |
| // Construct a trivial computation where the root of the computation is a |
| // nested tuple-shaped parameter. The parameter should be deep copied and the |
| // copy should be the root of the computation. |
| auto builder = HloComputation::Builder(TestName()); |
| |
| // Param shape is: ((F32[], S32[1,2,3]), F32[42]) |
| builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, |
| ShapeUtil::MakeTupleShape( |
| {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), |
| ShapeUtil::MakeShape(S32, {1, 2, 3})}), |
| ShapeUtil::MakeShape(F32, {42})}), |
| "param0")); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(HloOpcode::kParameter, |
| module->entry_computation()->root_instruction()->opcode()); |
| |
| HloInstruction* old_root = module->entry_computation()->root_instruction(); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 3); |
| |
| HloInstruction* new_root = module->entry_computation()->root_instruction(); |
| EXPECT_NE(old_root, new_root); |
| |
| EXPECT_THAT( |
| new_root, |
| op::Tuple( |
| op::Tuple( |
| op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))), |
| op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))), |
| op::Copy(op::GetTupleElement(old_root)))); |
| } |
| |
| TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { |
| // Construct a computation where the root of the computation is a tuple |
| // element of a nested tuple-shaped parameter. |
| auto builder = HloComputation::Builder(TestName()); |
| |
| // Param shape is: ((F32[], S32[1,2,3]), F32[42]) |
| auto param = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, |
| ShapeUtil::MakeTupleShape( |
| {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), |
| ShapeUtil::MakeShape(S32, {1, 2, 3})}), |
| ShapeUtil::MakeShape(F32, {42})}), |
| "param0")); |
| |
| // The return value of the computation is the zero-th element of the nested |
| // tuple. This element is itself a tuple. |
| auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(gte, module->entry_computation()->root_instruction()); |
| |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 2); |
| |
| EXPECT_THAT( |
| module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), |
| op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); |
| } |
| |
| TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { |
| // Create a computation using select which has an ambiguous points-to set for |
| // the top-level buffer of the root of the computation. Verify that a shallow |
| // copy is added. |
| auto builder = HloComputation::Builder(TestName()); |
| HloInstruction* constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| HloInstruction* constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); |
| |
| HloInstruction* tuple1 = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant1, constant2})); |
| HloInstruction* tuple2 = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant2, constant1})); |
| |
| HloInstruction* pred = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( |
| tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); |
| HloInstruction* gte = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); |
| |
| auto module = CreateNewVerifiedModule(); |
| module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(gte, module->entry_computation()->root_instruction()); |
| |
| HloInstruction* old_root = module->entry_computation()->root_instruction(); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Copy(old_root)); |
| } |
| |
| class WhileCopyInsertionTest : public CopyInsertionTest { |
| protected: |
| WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {} |
| |
| // Builds a While condition computation which reads the induction variable |
| // from the tuple parameter, and returns a predicate indicating whether this |
| // value is less than the constant '10'. |
| // The parameter 'nested' specifies the loop state shape from which to |
| // read the induction variable. |
| std::unique_ptr<HloComputation> BuildConditionComputation( |
| const Shape& loop_state_shape) { |
| auto builder = HloComputation::Builder(TestName() + ".Condition"); |
| auto limit_const = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10))); |
| auto loop_state = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); |
| auto induction_variable = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| limit_const->shape(), loop_state, 0)); |
| builder.AddInstruction(HloInstruction::CreateCompare( |
| condition_result_shape_, induction_variable, limit_const, |
| ComparisonDirection::kLt)); |
| return builder.Build(); |
| } |
| |
| // Builds a While body computation with one output tuple element dependent on |
| // both input tuple elements. |
| // EX: |
| // Body({in0, in1}) |
| // out0 = Add(in0, 1) |
| // out1 = Add(BCast(in0), in1) |
| // Tuple(out0, out1) |
| std::unique_ptr<HloComputation> BuildDependentBodyComputation() { |
| auto builder = HloComputation::Builder(TestName() + ".Body"); |
| // Create param instruction to access loop state. |
| auto loop_state = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); |
| // Update the induction variable GTE(0). |
| auto induction_variable = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| induction_variable_shape_, loop_state, 0)); |
| auto inc = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( |
| induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); |
| // Update data GTE(1). |
| auto data = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); |
| // Use 'induction_variable' in computation with no path to output tuple. |
| Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); |
| auto convert = builder.AddInstruction( |
| HloInstruction::CreateConvert(f32_scalar_shape, induction_variable)); |
| auto update = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, convert, {})); |
| auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kAdd, data, update)); |
| // Create output Tuple. |
| builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); |
| return builder.Build(); |
| } |
| |
| // Builds a While body computation with two output tuple elements dependent on |
| // both input tuple elements. |
| // |
| // EX: Body({in0, in1, in2}) |
| // out0 = Add(in0, 1) |
| // out1 = in1 |
| // out2 = in2 |
| // Tuple(out0, out1, out2) |
| std::unique_ptr<HloComputation> BuildDependentBodyComputation2() { |
| auto builder = HloComputation::Builder(TestName() + ".Body"); |
| |
| const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( |
| {induction_variable_shape_, data_shape_, data_shape_}); |
| |
| auto loop_state = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); |
| |
| // Update the induction variable GTE(0). |
| auto induction_variable = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| induction_variable_shape_, loop_state, 0)); |
| auto inc = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| |
| // add0 = Add(in0, 1) |
| auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( |
| induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); |
| // data1 = GTE(1). |
| HloInstruction* data1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); |
| |
| // data2 = GTE(2). |
| HloInstruction* data2 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); |
| |
| // Create output Tuple. |
| builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); |
| |
| return builder.Build(); |
| } |
| |
| // Builds a While body computation with read-only tuple element 0. |
| // EX: |
| // Body({in0, in1}) |
| // out0 = in0 |
| // out1 = Add(BCast(in0), in1) |
| // Tuple(out0, out1) |
| std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() { |
| auto builder = HloComputation::Builder(TestName() + ".Body"); |
| // Create param instruction to access loop state. |
| auto loop_state = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); |
| // Update the induction variable GTE(0). |
| auto induction_variable = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| induction_variable_shape_, loop_state, 0)); |
| // Update data GTE(1). |
| auto data = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); |
| |
| // Use 'induction_variable' in computation with no path to output tuple. |
| Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); |
| auto convert = builder.AddInstruction( |
| HloInstruction::CreateConvert(f32_scalar_shape, induction_variable)); |
| auto update = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, convert, {})); |
| auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kAdd, data, update)); |
| // Create output Tuple. |
| builder.AddInstruction( |
| HloInstruction::CreateTuple({induction_variable, add1})); |
| return builder.Build(); |
| } |
| |
| // Builds a While body computation with independent outputs. |
| // EX: |
| // Body({in0, in1}) |
| // out0 = Add(in0, 1) |
| // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) |
| // Tuple(out0, out1) |
| std::unique_ptr<HloComputation> BuildIndependentBodyComputation( |
| bool nested = false) { |
| auto builder = HloComputation::Builder(TestName() + ".Body"); |
| // Create param instruction to access loop state. |
| const Shape& loop_state_shape = |
| nested ? nested_loop_state_shape_ : loop_state_shape_; |
| |
| auto loop_state = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); |
| // Update the induction variable GTE(0). |
| auto induction_variable = |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| induction_variable_shape_, loop_state, 0)); |
| auto inc = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| // add0 = Add(in0, 1) |
| auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( |
| induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); |
| // Update data GTE(1). |
| HloInstruction* data = nullptr; |
| if (nested) { |
| data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| nested_tuple_shape_, loop_state, 1)); |
| data = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, data, 0)); |
| } else { |
| data = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); |
| } |
| auto update = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( |
| {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); |
| // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) |
| auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kAdd, data, update)); |
| // Create output Tuple. |
| if (nested) { |
| auto nested_tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({add1, add1})); |
| builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple})); |
| } else { |
| builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); |
| } |
| return builder.Build(); |
| } |
| |
| // Builds a While body computation with the following nested tuple |
| // sub-computation: |
| // | |
| // GTE(loop_state, 1) |
| // / \ |
| // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) |
| // | | |
| // Add Reverse |
| // | | |
| std::unique_ptr<HloComputation> BuildNestedBodyComputation() { |
| auto builder = HloComputation::Builder(TestName() + ".Body"); |
| // Create param instruction to access loop state. |
| auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, nested_loop_state_shape_, "loop_state")); |
| // Update GTE(0). |
| auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| induction_variable_shape_, loop_state, 0)); |
| auto inc = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( |
| gte0->shape(), HloOpcode::kAdd, gte0, inc)); |
| |
| // GTE(loop_state, 1) |
| auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| nested_tuple_shape_, loop_state, 1)); |
| // GTE(GTE(loop_state, 1), 0) -> Add |
| auto gte10 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); |
| auto update10 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( |
| {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); |
| auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kAdd, gte10, update10)); |
| |
| // GTE(GTE(loop_state, 1), 1) -> Reverse |
| auto gte11 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1)); |
| auto rev11 = builder.AddInstruction( |
| HloInstruction::CreateReverse(data_shape_, gte11, {0})); |
| |
| // Create output Tuple. |
| auto inner_tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11})); |
| builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple})); |
| return builder.Build(); |
| } |
| |
| // Builds a While instruction using 'condition' and 'body' sub-computations. |
| // Init operand is initialized to zeros of appropriate shape. |
| HloInstruction* BuildWhileInstruction(HloComputation* condition, |
| HloComputation* body, |
| bool nested = false) { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto induction_var_init = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); |
| |
| auto data_init = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( |
| {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); |
| |
| if (nested) { |
| auto inner_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({data_init, data_init})); |
| auto loop_state_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({induction_var_init, inner_init})); |
| auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_init->shape(), condition, body, loop_state_init)); |
| module_->AddEntryComputation(builder.Build()); |
| return while_hlo; |
| } |
| |
| auto loop_state_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({induction_var_init, data_init})); |
| auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape_, condition, body, loop_state_init)); |
| module_->AddEntryComputation(builder.Build()); |
| return while_hlo; |
| } |
| |
| HloInstruction* BuildWhileInstruction_InitPointsToConstant() { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto data_init = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( |
| {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); |
| return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, |
| &builder); |
| } |
| |
| HloInstruction* BuildWhileInstruction_InitPointsToParameter() { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto data_init = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, data_shape_, "data_init")); |
| return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, |
| &builder); |
| } |
| |
| HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| |
| auto one = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto v1 = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, one, {})); |
| auto zero = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto v2 = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, zero, {})); |
| |
| auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2})); |
| auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); |
| |
| auto pred = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( |
| nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); |
| |
| return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, |
| data_init, &builder); |
| } |
| |
| HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| |
| auto one = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto one_vec = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, one, {})); |
| auto data_init = |
| builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec})); |
| |
| return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, |
| data_init, &builder); |
| } |
| |
| HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto one = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto data_init = builder.AddInstruction( |
| HloInstruction::CreateBroadcast(data_shape_, one, {})); |
| auto one_vec = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( |
| {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); |
| // Take a reference to 'data_init' to make it interfere with while result. |
| auto add = builder.AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kAdd, data_init, one_vec)); |
| |
| auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, |
| data_init, &builder); |
| |
| // Add an additional binary operation operating on the while and the |
| // interfering add so that neither operation is dead. |
| auto gte = xla_while->parent()->AddInstruction( |
| HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); |
| auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( |
| data_shape_, HloOpcode::kSubtract, add, gte)); |
| auto gte0 = xla_while->parent()->AddInstruction( |
| HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); |
| auto tuple = xla_while->parent()->AddInstruction( |
| HloInstruction::CreateTuple({gte0, sub})); |
| |
| xla_while->parent()->set_root_instruction(tuple); |
| |
| return xla_while; |
| } |
| |
| HloInstruction* BuildWhileInstructionWithCustomInit( |
| const Shape& loop_state_shape, HloInstruction* data_init, |
| HloComputation::Builder* builder) { |
| const bool nested = |
| ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); |
| auto induction_var_init = builder->AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); |
| auto condition = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape)); |
| auto body = module_->AddEmbeddedComputation( |
| BuildIndependentBodyComputation(nested)); |
| auto loop_state_init = builder->AddInstruction( |
| HloInstruction::CreateTuple({induction_var_init, data_init})); |
| auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape, condition, body, loop_state_init)); |
| module_->AddEntryComputation(builder->Build()); |
| return while_hlo; |
| } |
| |
| std::unique_ptr<HloModule> module_; |
| Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {}); |
| Shape data_shape_ = ShapeUtil::MakeShape(F32, {8}); |
| Shape loop_state_shape_ = |
| ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_}); |
| Shape nested_tuple_shape_ = |
| ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); |
| Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape( |
| {induction_variable_shape_, nested_tuple_shape_}); |
| Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {}); |
| }; |
| |
| // Tests while body computation with independent tuple elements: |
| // |
| // While.Body({in0, in1}) |
| // out0 = Add(in0, 1) |
| // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) |
| // Tuple(out0, out1) |
| // |
| // CopyInsertion pass should not generate any copies. |
| // |
| TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { |
| auto condition = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto body = |
| module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); |
| auto while_hlo = BuildWhileInstruction(condition, body); |
| |
| InsertCopies(module_.get()); |
| |
| // Body should have no copies as the adds can be done inplace. |
| EXPECT_EQ(CountCopies(*body), 0); |
| EXPECT_EQ(CountControlEdges(*module_), 0); |
| |
| // Both init indices need copies as they are constants. |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); |
| } |
| |
| // Tests Copy Insertion when a while feeds another while |
| // PARAMETER |
| // | | |
| // GTE(0) GTE(1) |
| // | | |
| // X = CreateTuple(GTE(0), GTE(1)) |
| // | |
| // WHILE(X) (root) |
| TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterWithCopies) { |
| const string& hlo_string = R"( |
| HloModule DependentTupleElements |
| |
| %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) { |
| %loop_state.1 = (s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0 |
| %constant.1 = s32[] constant(1) |
| %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) |
| %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1 |
| %convert = f32[] convert(s32[] %get-tuple-element.1) |
| %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={} |
| %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast) |
| ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1) |
| } |
| |
| %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] { |
| %loop_state = (s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0 |
| %constant = s32[] constant(10) |
| ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT |
| } |
| |
| ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) { |
| %constant.2 = s32[] constant(0) |
| %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0}) |
| %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3) |
| ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body |
| } |
| )"; |
| auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); |
| auto module_ = module_or_status.ConsumeValueOrDie(); |
| auto while_hlo = module_->entry_computation()->root_instruction(); |
| // module_ and while_hlo are the pre-existing module and hlo, the below |
| // code generates a clone of the existing while and replaces that while |
| // with itself. The body of the new while calls the previous while |
| HloComputation* outer_while_condition = |
| module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone()); |
| HloComputation* outer_while_body = |
| module_->AddEmbeddedComputation(while_hlo->while_body()->Clone()); |
| HloInstruction* outer_while = |
| while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), outer_while_condition, outer_while_body, |
| while_hlo->mutable_operand(0))); |
| HloInstruction* outer_param = outer_while_body->parameter_instruction(0); |
| std::vector<HloInstruction*> materialized_gtes; |
| for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) { |
| materialized_gtes.push_back( |
| outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( |
| outer_param->shape().tuple_shapes(i), outer_param, i))); |
| } |
| HloInstruction* dual_init = outer_while_body->AddInstruction( |
| HloInstruction::CreateTuple(materialized_gtes)); |
| HloInstruction* dual_while = |
| outer_while_body->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), while_hlo->while_condition(), |
| while_hlo->while_body(), dual_init)); |
| TF_CHECK_OK(outer_while_body->ReplaceInstruction( |
| outer_while_body->root_instruction(), dual_while)); |
| TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while)); |
| InsertCopies(module_.get()); |
| } |
| |
| // Tests Copy Insertion when a while feeds another while |
| // PARAMETER |
| // | | |
| // \ / |
| // WHILE(PARAMETER) (root) |
| TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterNoCopies) { |
| const string& hlo_string = R"( |
| HloModule DependentTupleElements |
| |
| %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) { |
| %loop_state.1 = (s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0 |
| %constant.1 = s32[] constant(1) |
| %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) |
| %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1 |
| %convert = f32[] convert(s32[] %get-tuple-element.1) |
| %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={} |
| %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast) |
| ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1) |
| } |
| |
| %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] { |
| %loop_state = (s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0 |
| %constant = s32[] constant(10) |
| ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT |
| } |
| |
| ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) { |
| %constant.2 = s32[] constant(0) |
| %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0}) |
| %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3) |
| ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body |
| } |
| )"; |
| auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); |
| auto module_ = module_or_status.ConsumeValueOrDie(); |
| auto while_hlo = module_->entry_computation()->root_instruction(); |
| // module_ and while_hlo are the pre-existing module and hlo, the below |
| // code generates a clone of the existing while and replaces that while |
| // with itself. The body of the new while calls the previous while |
| HloComputation* outer_while_condition = |
| module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone()); |
| HloComputation* outer_while_body = |
| module_->AddEmbeddedComputation(while_hlo->while_body()->Clone()); |
| HloInstruction* outer_while = |
| while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), outer_while_condition, outer_while_body, |
| while_hlo->mutable_operand(0))); |
| HloInstruction* outer_param = outer_while_body->parameter_instruction(0); |
| HloInstruction* dual_while = |
| outer_while_body->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), while_hlo->while_condition(), |
| while_hlo->while_body(), outer_param)); |
| TF_CHECK_OK(outer_while_body->ReplaceInstruction( |
| outer_while_body->root_instruction(), dual_while)); |
| TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while)); |
| InsertCopies(module_.get()); |
| } |
| |
| // Tests Copy Insertion when a while feeds another while |
| // PARAMETER |
| // | | |
| // \ / |
| // WHILE(PARAMETER) (root) |
| TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterBig) { |
| const string& hlo_string = R"( |
| HloModule DependentTupleElements |
| |
| %DependentTupleElements.Body (loop_state.1: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) { |
| %loop_state.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=0 |
| %constant.1 = s32[] constant(1) |
| %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) |
| %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=1 |
| %convert = f32[] convert(s32[] %get-tuple-element.1) |
| %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={} |
| %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast) |
| ROOT %tuple = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1) |
| } |
| |
| %DependentTupleElements.Condition (loop_state: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> pred[] { |
| %loop_state = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0) |
| %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state), index=0 |
| %constant = s32[] constant(10) |
| ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT |
| } |
| |
| ENTRY %DependentTupleElements.While () -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) { |
| %constant.2 = s32[] constant(0) |
| %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0}) |
| %tuple.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3) |
| ROOT %while.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) while( (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body |
| } |
| )"; |
| auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); |
| auto module_ = module_or_status.ConsumeValueOrDie(); |
| auto while_hlo = module_->entry_computation()->root_instruction(); |
| // module_ and while_hlo are the pre-existing module and hlo, the below |
| // code generates a clone of the existing while and replaces that while |
| // with itself. The body of the new while calls the previous while |
| HloComputation* outer_while_condition = |
| module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone()); |
| HloComputation* outer_while_body = |
| module_->AddEmbeddedComputation(while_hlo->while_body()->Clone()); |
| HloInstruction* outer_while = |
| while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), outer_while_condition, outer_while_body, |
| while_hlo->mutable_operand(0))); |
| HloInstruction* outer_param = outer_while_body->parameter_instruction(0); |
| std::vector<HloInstruction*> materialized_gtes; |
| for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) { |
| materialized_gtes.push_back( |
| outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( |
| outer_param->shape().tuple_shapes(i), outer_param, i))); |
| } |
| HloInstruction* dual_init = outer_while_body->AddInstruction( |
| HloInstruction::CreateTuple(materialized_gtes)); |
| HloInstruction* dual_while = |
| outer_while_body->AddInstruction(HloInstruction::CreateWhile( |
| while_hlo->shape(), while_hlo->while_condition(), |
| while_hlo->while_body(), dual_init)); |
| TF_CHECK_OK(outer_while_body->ReplaceInstruction( |
| outer_while_body->root_instruction(), dual_while)); |
| TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while)); |
| InsertCopies(module_.get()); |
| } |
| |
| // Tests while body computation with dependent tuple elements: |
| // |
| // While.Body({in0, in1}) |
| // out0 = Add(in0, 1) |
| // out1 = Add(BCast(in0), in1) |
| // Tuple(out0, out1) |
| // |
| // CopyInsertion pass should convert the root instruction to: |
| // |
| // Tuple(Copy(out0), out1) |
| // |
| TEST_F(WhileCopyInsertionTest, DependentTupleElements) { |
| auto condition = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); |
| auto while_hlo = BuildWhileInstruction(condition, body); |
| |
| InsertCopies(module_.get()); |
| |
| EXPECT_EQ(CountCopies(*body), 1); |
| EXPECT_EQ(CountControlEdges(*body), 0); |
| |
| EXPECT_THAT( |
| body->root_instruction(), |
| op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); |
| |
| auto add = body->root_instruction()->operand(0); |
| auto bcast = body->root_instruction()->operand(1)->operand(1); |
| ASSERT_EQ(add->opcode(), HloOpcode::kAdd); |
| ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); |
| |
| EXPECT_THAT(while_hlo->while_body()->root_instruction(), |
| op::Tuple(op::Add(op::Copy(), op::Constant()), |
| op::Add(op::GetTupleElement(), |
| op::Broadcast(op::Convert(op::Copy()))))); |
| |
| // Both init indices need copies as they are constants. |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); |
| } |
| |
| // Tests while body computation with read-only tuple element 0: |
| // |
| // PARAMETER |
| // / \ |
| // GTE(0) GTE(1) |
| // | \ | |
| // | BCAST | |
| // | \ | |
| // | ADD |
| // | | |
| // \ / |
| // TUPLE (root) |
| // |
| // CopyInsertion pass should not generate any copies for the while body. |
| TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { |
| auto condition = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto body = module_->AddEmbeddedComputation( |
| BuildDependentBodyOneReadOnlyComputation()); |
| BuildWhileInstruction(condition, body); |
| |
| InsertCopies(module_.get()); |
| |
| // No copies or control edges should be inserted. The body is legal as is. |
| EXPECT_EQ(CountCopies(*body), 0); |
| EXPECT_EQ(CountControlEdges(*body), 0); |
| } |
| |
| // Same as above, but with two while loops, sharing entry parameters. |
| TEST_F(WhileCopyInsertionTest, |
| DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { |
| auto condition1 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto condition2 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto body1 = module_->AddEmbeddedComputation( |
| BuildDependentBodyOneReadOnlyComputation()); |
| auto body2 = module_->AddEmbeddedComputation( |
| BuildDependentBodyOneReadOnlyComputation()); |
| |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto iter_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); |
| auto data_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, data_shape_, "data")); |
| auto loop_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({iter_param, data_param})); |
| |
| auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape_, condition1, body1, loop_init)); |
| auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape_, condition2, body2, loop_init)); |
| |
| // Add a couple elements from each of the while so both whiles are live. |
| auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); |
| auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); |
| builder.AddInstruction( |
| HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); |
| |
| auto entry = module_->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module_.get()); |
| |
| // Neither body should have any copies or control edges in them. |
| EXPECT_EQ(CountCopies(*body1), 0); |
| EXPECT_EQ(CountCopies(*body2), 0); |
| EXPECT_EQ(CountControlEdges(*body1), 0); |
| EXPECT_EQ(CountControlEdges(*body2), 0); |
| |
| // Only two copies should be necessary. Each of the whiles should have |
| // a copy of tuple element 1 (init value is a parameter, and the element is |
| // not non-read-only) so each of the while bodies gets its own buffer to write |
| // element 1 into. |
| EXPECT_EQ(CountCopies(*entry), 2); |
| |
| EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); |
| EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); |
| |
| // The two copies of element 1 should be different. |
| EXPECT_NE(while_hlo1->operand(0)->operand(1), |
| while_hlo2->operand(0)->operand(1)); |
| } |
| |
| // Same as above, but with two while loops, sharing non-parameters. |
| TEST_F(WhileCopyInsertionTest, |
| DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { |
| auto condition1 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto condition2 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape_)); |
| auto body1 = module_->AddEmbeddedComputation( |
| BuildDependentBodyOneReadOnlyComputation()); |
| auto body2 = module_->AddEmbeddedComputation( |
| BuildDependentBodyOneReadOnlyComputation()); |
| |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| auto iter_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); |
| auto data_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, data_shape_, "data")); |
| // Add dummy ops to ensure loop_init elements aren't entry parameters. |
| Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); |
| auto convert = builder.AddInstruction( |
| HloInstruction::CreateConvert(f32_scalar_shape, iter_param)); |
| auto iter_value = builder.AddInstruction( |
| HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert)); |
| auto convert2 = builder.AddInstruction( |
| HloInstruction::CreateConvert(induction_variable_shape_, iter_value)); |
| auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( |
| data_param->shape(), HloOpcode::kExp, data_param)); |
| auto loop_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({convert2, data_value})); |
| |
| auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape_, condition1, body1, loop_init)); |
| auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape_, condition2, body2, loop_init)); |
| |
| // Add a couple elements from each of the while so both whiles are not dead. |
| auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); |
| auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); |
| builder.AddInstruction( |
| HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); |
| auto entry = module_->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module_.get()); |
| |
| // Ideally only one copy should be necessary. One of the whiles should |
| // have a copy of tuple element 1 (the non-read-only element) so each of the |
| // while bodies gets its own buffer to write element 1 into. However, the |
| // analysis isn't perfect and adds an additional copy of element 0. |
| EXPECT_EQ(CountCopies(*entry), 2); |
| |
| EXPECT_THAT(while_hlo1->operand(0), |
| op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp()))); |
| EXPECT_THAT(while_hlo2->operand(0), |
| op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp()))); |
| } |
| |
| // Tests while body computation with nested tuple elements: |
| // |
| // | |
| // GTE(loop_state, 1) |
| // / \ |
| // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) |
| // | | |
| // Add Reverse |
| // | | |
| // |
| // CopyInsertion pass will conceptually generate the following, but with the |
| // actual GTE and Tuple instructions optimized away: |
| // |
| // Tuple // old root |
| // / \ |
| // / \ |
| // GTE(0) GTE(1) |
| // | / \ |
| // | / \ |
| // | GTE(0) GTE(1) |
| // | | | |
| // | | Copy |
| // | | | |
| // \ | / |
| // \ Tuple // "inner" tuple. |
| // \ / |
| // \ / |
| // Tuple // new root |
| // |
| TEST_F(WhileCopyInsertionTest, NestedTupleElements) { |
| auto condition = module_->AddEmbeddedComputation( |
| BuildConditionComputation(nested_loop_state_shape_)); |
| auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); |
| BuildWhileInstruction(condition, body, true); |
| |
| // HloInstruction* old_root = body->root_instruction(); |
| InsertCopies(module_.get()); |
| |
| // The only copy necessary is for the kReverse as it cannot be done |
| // in-place (instruction can share buffer with operand). The other elements of |
| // the loop state are kAdd instructions which can be done in-place. |
| EXPECT_EQ(CountCopies(*body), 1); |
| |
| // Each element of the init needs a copy as all are constants. |
| EXPECT_EQ(CountCopies(*module_), 4); |
| |
| // Either the kReverse itself must be copied or the operand of the kReverse |
| // must be copied. |
| if (body->root_instruction()->operand(1)->operand(1)->opcode() == |
| HloOpcode::kCopy) { |
| EXPECT_THAT( |
| body->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); |
| } else { |
| EXPECT_THAT( |
| body->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); |
| } |
| } |
| |
| // Tests while init instruction which points-to a constant. |
| // |
| // init = Tuple(Constant(S32, {}), Constant(F32, {8})) |
| // |
| // CopyInsertion pass should add copies for both constants. |
| // |
| TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { |
| auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); |
| |
| InsertCopies(module_.get()); |
| EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); |
| EXPECT_EQ(CountCopies(*module_), 2); |
| |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); |
| } |
| |
| // Tests while init instruction which points-to a parameter. |
| // |
| // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) |
| // |
| // CopyInsertion pass should add copies for both the constant and parameter. |
| // |
| TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { |
| auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); |
| |
| InsertCopies(module_.get()); |
| EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); |
| EXPECT_EQ(CountCopies(*module_), 2); |
| |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); |
| } |
| |
| // Tests while init instruction which has an ambiguous points-to set. |
| // |
| // select = Select(pred, tuple1, tuple2) |
| // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) |
| // |
| // CopyInsertion pass will conceptually generate the following, but with some of |
| // the actual GTE and Tuple instructions optimized away: |
| // |
| // Tuple // old init |
| // / \ |
| // / \ |
| // GTE(0) GTE(1) |
| // | / \ |
| // | / \ |
| // | GTE(0) GTE(1) |
| // | | | |
| // Copy Copy Copy |
| // | | | |
| // \ | / |
| // \ Tuple |
| // \ / |
| // \ / |
| // Tuple // new init |
| // |
| TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { |
| auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); |
| |
| InsertCopies(module_.get()); |
| EXPECT_EQ(CountCopies(*module_), 4); |
| // The entry computation requires three copies to resolve the ambiguity of two |
| // init elements and the constant passed in as one of the init elements. |
| EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), |
| op::Tuple(op::Copy(op::GetTupleElement()), |
| op::Copy(op::GetTupleElement())))); |
| |
| // The body requires one copy because the buffer set is not distinct: the |
| // result of one of the adds is written into two elements of the output of the |
| // loop body. Either element might be copied. |
| EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); |
| if (while_hlo->while_body() |
| ->root_instruction() |
| ->operand(1) |
| ->operand(0) |
| ->opcode() == HloOpcode::kCopy) { |
| EXPECT_THAT( |
| while_hlo->while_body()->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); |
| } else { |
| EXPECT_THAT( |
| while_hlo->while_body()->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); |
| } |
| } |
| |
| // Tests while init instruction which has a non-distinct points-to set. |
| // |
| // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) |
| // |
| // CopyInsertion pass will conceptually generate the following, but with some of |
| // the actual GTE and Tuple instructions optimized away: |
| // |
| // Tuple // old init |
| // / \ |
| // / \ |
| // GTE(0) GTE(1) |
| // | / \ |
| // | / \ |
| // | GTE(0) GTE(1) |
| // | | | |
| // Copy Copy Copy |
| // | | | |
| // \ | / |
| // \ Tuple |
| // \ / |
| // \ / |
| // Tuple // new init |
| // |
| TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { |
| auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); |
| |
| InsertCopies(module_.get()); |
| |
| // The entry computation requires two copies to resolve the non-distinctness |
| // of two init elements and the constant passed in as one of the init |
| // elements. Either element can be copied for the distinctness issue. |
| EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); |
| if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == |
| HloOpcode::kCopy) { |
| EXPECT_THAT( |
| while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), |
| op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); |
| } else { |
| EXPECT_THAT( |
| while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), |
| op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); |
| } |
| |
| // The body requires one copy because the buffer set is not distinct: the |
| // result of one of the adds is written into two elements of the output of the |
| // loop body. Either element might be copied. |
| EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); |
| if (while_hlo->while_body() |
| ->root_instruction() |
| ->operand(1) |
| ->operand(0) |
| ->opcode() == HloOpcode::kCopy) { |
| EXPECT_THAT( |
| while_hlo->while_body()->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); |
| } else { |
| EXPECT_THAT( |
| while_hlo->while_body()->root_instruction(), |
| op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); |
| } |
| } |
| |
| // Tests while init instruction buffer which interferes with while result |
| // buffer. |
| // |
| // init_data = Broadcast(...) |
| // add_unrelated = Add(init_data) // takes a reference to cause interference |
| // init = Tuple(Constant(S32, {}), init_data)) |
| // |
| // CopyInsertion pass should copy both operands. |
| // |
| TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { |
| auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); |
| |
| InsertCopies(module_.get()); |
| EXPECT_EQ(CountCopies(*module_), 2); |
| EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); |
| |
| EXPECT_THAT(while_hlo->operand(0), |
| op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); |
| } |
| |
| // Tests while init instruction buffer which has a non-distinct points-to set: |
| // |
| // init = Tuple(Parameter(S32, {}), Parameter(F32, {8}, |
| // Parameter(F32, {8}))) |
| // |
| // where the second and third parameters are identical *and* the tuple shared |
| // by another while instruction. |
| // |
| // Verifies that the resulting point-to set is distinct in the resulting Tuple |
| // (non-identical Copys). In other words, verifies that copy sharing does not |
| // insert identical copies to the resulting tuple. |
| TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { |
| // Loop body that outputs tuple comprises two elements dependent on the init |
| // tuple. |
| const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( |
| {induction_variable_shape_, data_shape_, data_shape_}); |
| |
| auto condition1 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape)); |
| auto condition2 = module_->AddEmbeddedComputation( |
| BuildConditionComputation(loop_state_shape)); |
| auto body1 = |
| module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); |
| auto body2 = |
| module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); |
| |
| auto builder = HloComputation::Builder(TestName() + ".While"); |
| |
| auto iter_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); |
| auto data_param = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, data_shape_, "data")); |
| |
| // Loop init tuple contains two identical parameter buffers. |
| auto loop_init = builder.AddInstruction( |
| HloInstruction::CreateTuple({iter_param, data_param, data_param})); |
| |
| // Two while loops share the same loop init tuple. |
| auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape, condition1, body1, loop_init)); |
| auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape, condition2, body2, loop_init)); |
| |
| // Add add instruction so neither while is dead. |
| auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); |
| auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); |
| builder.AddInstruction( |
| HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); |
| |
| module_->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module_.get()); |
| |
| // None of the bodies should have copies or control flow edges. |
| EXPECT_EQ(CountCopies(*body1), 0); |
| EXPECT_EQ(CountCopies(*body2), 0); |
| |
| // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally |
| // these should not need to be copied before either while. However, copy |
| // insertion is not able to reason about the transparency of elements through |
| // while bodies in all circumstances so extra copies are added (b/xxx). |
| EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); |
| |
| EXPECT_THAT(while_hlo1->operand(0), |
| op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); |
| EXPECT_THAT(while_hlo2->operand(0), |
| op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); |
| } |
| |
| TEST_F(CopyInsertionTest, SwizzlingWhile) { |
| // Test a while instruction with a body which permutes its tuple parameter |
| // elements. |
| auto module = CreateNewVerifiedModule(); |
| const Shape loop_state_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| // Body simply interchanges the two tuple elements in the loop state. |
| auto body_builder = HloComputation::Builder("body"); |
| auto body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto body_element_0 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); |
| auto body_element_1 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); |
| body_builder.AddInstruction( |
| HloInstruction::CreateTuple({body_element_1, body_element_0})); |
| HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto cond_builder = HloComputation::Builder("condition"); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto cond_constant = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| cond_builder.AddInstruction(HloInstruction::CreateUnary( |
| cond_constant->shape(), HloOpcode::kNot, cond_constant)); |
| HloComputation* condition = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); |
| auto tuple = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant1, constant2})); |
| auto xla_while = builder.AddInstruction( |
| HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 6); |
| |
| // The loop state elements should be copied at the parameter and at the root |
| // with a control edge in between (see DeepCopyAndAddControlEdges). This is |
| // technically one more copy than is strictly necessary, but in order to have |
| // only three copies the copies of different loop state elements must be |
| // ordered with a control edge. |
| EXPECT_EQ(CountCopies(*body), 4); |
| EXPECT_EQ(CountControlEdges(*body), 2); |
| |
| EXPECT_THAT(body->root_instruction(), |
| op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); |
| |
| EXPECT_EQ(CountCopies(*module->entry_computation()), 2); |
| EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); |
| } |
| |
| TEST_F(CopyInsertionTest, CrossingParameters) { |
| // Test a case where two parameters' dataflow cross with each other while |
| // input and output are aliased with same index: |
| // |
| // (p0 , p1) |
| // | \ /| |
| // | \ / | |
| // alias X alias |
| // | / \ | |
| // | / \| |
| // (p1 , p0) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); |
| module->AddEntryComputation(builder.Build()); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 4); |
| } |
| |
| TEST_F(CopyInsertionTest, ParametersAliasing) { |
| // Test a case where two parameters' dataflow don't interfere with each other |
| // while aliased. |
| // |
| // (p0 , p1) |
| // | | |
| // | | |
| // alias alias |
| // | | |
| // | | |
| // (p0 , p1) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); |
| module->AddEntryComputation(builder.Build()); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { |
| // Test a case where no parameter is aliased with result. In this case, copy |
| // should be added |
| // |
| // (p0 , p1) |
| // | | |
| // | | |
| // | | |
| // | | |
| // | | |
| // (p0 , p1) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); |
| module->AddEntryComputation(builder.Build()); |
| InsertCopies(module.get()); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(op::GetTupleElement(param, 0)), |
| op::Copy(op::GetTupleElement(param, 1)))); |
| |
| EXPECT_EQ(CountCopies(*module), 2); |
| } |
| |
| TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { |
| // Test a case where one parameter is aliased with result while another one |
| // isn't. |
| // |
| // (p0 , p1) |
| // | | |
| // | | |
| // alias | |
| // | | |
| // | | |
| // (p0 , p1) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); |
| module->AddEntryComputation(builder.Build()); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); |
| InsertCopies(module.get()); |
| |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::GetTupleElement(param, 0), |
| op::Copy(op::GetTupleElement(param, 1)))); |
| |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { |
| // Test a case where one parameter is aliased with result while another one |
| // isn't. |
| // |
| // +-- (p0 , p1) |
| // | | | |
| // | | | |
| // alias Negate Negate |
| // | | | |
| // | | | |
| // +-- (p0 , p1) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| |
| auto negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); |
| |
| auto negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); |
| module->AddEntryComputation(builder.Build()); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { |
| // Test a case where one parameter is aliased with result while another one |
| // isn't. |
| // |
| // +-- (p0 , p1) |
| // | | | |
| // | | | |
| // alias Negate Negate |
| // | | | |
| // | Add----+ |
| // | | | |
| // +-- (p0 , p1) |
| auto module = CreateNewVerifiedModule(); |
| const Shape tuple_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, tuple_shape, "p0")); |
| auto gte0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); |
| auto gte1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); |
| |
| auto negate0 = builder.AddInstruction( |
| HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); |
| |
| auto negate1 = builder.AddInstruction( |
| HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); |
| |
| auto add = builder.AddInstruction(HloInstruction::CreateBinary( |
| scalar_shape_, HloOpcode::kAdd, negate0, negate1)); |
| builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); |
| module->AddEntryComputation(builder.Build()); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { |
| // Test a while instruction with a body which permutes its tuple parameter |
| // elements and applies one operation to one of the elements. The addition of |
| // the operation (instruction) on the element makes the live range of the |
| // respective input and output elements different than if the instruction were |
| // not there (as in the SwizzlingWhile test above). |
| auto module = CreateNewVerifiedModule(); |
| const Shape loop_state_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| // Body interchanges the two tuple elements in the loop state and negates one |
| // of them. |
| auto body_builder = HloComputation::Builder("body"); |
| auto body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto body_element_0 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); |
| auto body_element_1 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); |
| auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( |
| scalar_shape_, HloOpcode::kNegate, body_element_1)); |
| body_builder.AddInstruction( |
| HloInstruction::CreateTuple({negate, body_element_0})); |
| HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto cond_builder = HloComputation::Builder("condition"); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto cond_constant = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| cond_builder.AddInstruction(HloInstruction::CreateUnary( |
| cond_constant->shape(), HloOpcode::kNot, cond_constant)); |
| HloComputation* condition = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); |
| auto tuple = builder.AddInstruction( |
| HloInstruction::CreateTuple({constant1, constant2})); |
| auto xla_while = builder.AddInstruction( |
| HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 6); |
| |
| // The loop state elements should be copied at the parameter and at the root |
| // with a control edge in between (see DeepCopyAndAddControlEdges). |
| EXPECT_EQ(CountCopies(*body), 4); |
| EXPECT_EQ(CountControlEdges(*body), 2); |
| |
| EXPECT_THAT( |
| body->root_instruction(), |
| op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); |
| |
| EXPECT_EQ(CountCopies(*module->entry_computation()), 2); |
| EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); |
| } |
| |
| TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { |
| // Test a while instruction with a body which permutes it's tuple parameter |
| // elements similar to SwizzlinWhile above. However, in this test the input to |
| // the while body is a single constant (both loop state elements are the same |
| // constant). This means no copies are necessary because both loop state |
| // elements are the same so interchanging them is a no-op. |
| auto module = CreateNewVerifiedModule(); |
| const Shape loop_state_shape = |
| ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); |
| |
| // Body simply interchanges the two tuple elements in the loop state. |
| auto body_builder = HloComputation::Builder("body"); |
| auto body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto body_element_0 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); |
| auto body_element_1 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); |
| body_builder.AddInstruction( |
| HloInstruction::CreateTuple({body_element_1, body_element_0})); |
| HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto cond_builder = HloComputation::Builder("condition"); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto cond_constant = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| cond_builder.AddInstruction(HloInstruction::CreateUnary( |
| cond_constant->shape(), HloOpcode::kNot, cond_constant)); |
| HloComputation* condition = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); |
| builder.AddInstruction( |
| HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 2); |
| EXPECT_EQ(CountCopies(*body), 0); |
| |
| EXPECT_EQ(CountCopies(*module->entry_computation()), 2); |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(), op::Copy())); |
| } |
| |
| TEST_F(CopyInsertionTest, SequentialWhiles) { |
| // Construct a computation with a series of sequential while instructions |
| // containing four loop state elements: |
| // |
| // element 0 is passed to each while directly from an entry parameter. |
| // |
| // element 1 is passed transparently in series through all the while bodies. |
| // |
| // element 2 is negated in each while body. (in-place possible) |
| // |
| // element 3 is reversed in each while body. (in-place not possible) |
| // |
| const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); |
| const Shape loop_state_shape = ShapeUtil::MakeTupleShape( |
| {element_shape, element_shape, element_shape, element_shape}); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto builder = HloComputation::Builder(TestName()); |
| auto param_0 = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, element_shape, "param_0")); |
| auto param_1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, element_shape, "param_1")); |
| auto param_2 = builder.AddInstruction( |
| HloInstruction::CreateParameter(2, element_shape, "param_2")); |
| auto param_3 = builder.AddInstruction( |
| HloInstruction::CreateParameter(3, element_shape, "param_3")); |
| |
| // The number of sequential kWhile instructions. |
| const int kNumWhiles = 3; |
| |
| HloInstruction* prev_element_1 = param_1; |
| HloInstruction* prev_element_2 = param_2; |
| HloInstruction* prev_element_3 = param_3; |
| |
| // Vector containing all of the while instructions. |
| std::vector<const HloInstruction*> whiles; |
| for (int i = 0; i < kNumWhiles; ++i) { |
| auto body_builder = HloComputation::Builder("body"); |
| auto body_param = body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto body_element_0 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); |
| auto body_element_1 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); |
| auto body_element_2 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); |
| auto body_element_3 = body_builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); |
| auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( |
| element_shape, HloOpcode::kNegate, body_element_2)); |
| auto reverse = body_builder.AddInstruction( |
| HloInstruction::CreateReverse(element_shape, body_element_3, {0})); |
| body_builder.AddInstruction(HloInstruction::CreateTuple( |
| {body_element_0, body_element_1, negate, reverse})); |
| HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto cond_builder = HloComputation::Builder("condition"); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "param")); |
| auto cond_constant = cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| cond_builder.AddInstruction(HloInstruction::CreateUnary( |
| cond_constant->shape(), HloOpcode::kNot, cond_constant)); |
| HloComputation* condition = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( |
| {param_0, prev_element_1, prev_element_2, prev_element_3})); |
| |
| auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( |
| loop_state_shape, condition, body, while_init)); |
| whiles.push_back(xla_while); |
| if (i != kNumWhiles - 1) { |
| prev_element_1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); |
| prev_element_2 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); |
| prev_element_3 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); |
| } |
| } |
| |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| // Each while body has one copy. And each loop state element is copied once in |
| // the entry computation. |
| EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); |
| |
| // Each while body should have exactly one copy for element three which is an |
| // op (kReverse) which cannot be done in place. |
| for (const HloInstruction* xla_while : whiles) { |
| EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); |
| } |
| |
| EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), |
| op::Copy(), op::Copy())); |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), |
| op::GetTupleElement())); |
| } |
| |
| TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { |
| // Test a while body and condition which are each simply a constant (root of |
| // computation is a constant). The body constant should be copied. |
| auto module = CreateNewVerifiedModule(); |
| auto builder = HloComputation::Builder(TestName()); |
| auto param_0 = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); |
| |
| auto body_builder = HloComputation::Builder("body"); |
| body_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, scalar_shape_, "param")); |
| body_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0))); |
| HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); |
| |
| auto cond_builder = HloComputation::Builder("condition"); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, scalar_shape_, "param")); |
| cond_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| HloComputation* condition = |
| module->AddEmbeddedComputation(cond_builder.Build()); |
| |
| auto xla_while = builder.AddInstruction( |
| HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); |
| |
| module->AddEntryComputation(builder.Build()); |
| |
| InsertCopies(module.get()); |
| |
| EXPECT_EQ(CountCopies(*module), 2); |
| |
| EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); |
| EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); |
| EXPECT_THAT(condition->root_instruction(), op::Constant()); |
| } |
| |
| TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { |
| string module_string = R"( |
| HloModule TokensShouldNotBeCopied |
| |
| %Body (param.1: (s32[], token[])) -> (s32[], token[]) { |
| %param.1 = (s32[], token[]) parameter(0) |
| %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 |
| %constant.1 = s32[] constant(1) |
| %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) |
| %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 |
| %after-all = token[] after-all(token[] %get-tuple-element.2) |
| ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) |
| } |
| |
| %Cond (param: (s32[], token[])) -> pred[] { |
| %param = (s32[], token[]) parameter(0) |
| %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 |
| %constant = s32[] constant(42) |
| ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT |
| } |
| |
| ENTRY %TokensShouldNotBeCopied () -> s32[] { |
| %one = s32[] constant(1) |
| %negative_one = s32[] negate(%one) |
| %init_token = token[] after-all() |
| %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) |
| %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body |
| ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(module_string)); |
| InsertCopies(module.get()); |
| |
| // There should be no copies added because tokens should not be copied. |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) { |
| auto builder = HloComputation::Builder("trivial_condition"); |
| builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "loop_state")); |
| auto constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); |
| builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kNot, constant)); |
| return builder.Build(); |
| } |
| |
| std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() { |
| auto builder = HloComputation::Builder("benchmark_loop_body"); |
| const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); |
| const Shape loop_state_shape = |
| ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); |
| HloInstruction* param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); |
| HloInstruction* element_0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, param, 0)); |
| HloInstruction* element_1 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, param, 1)); |
| HloInstruction* element_2 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, param, 2)); |
| |
| HloInstruction* rev_1 = builder.AddInstruction( |
| HloInstruction::CreateReverse(element_shape, element_1, {0})); |
| HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( |
| element_shape, HloOpcode::kAdd, element_1, element_2)); |
| |
| builder.AddInstruction( |
| HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); |
| return builder.Build(); |
| } |
| |
| void BM_SequentialWhiles(::testing::benchmark::State& state) { |
| const int num_whiles = state.range(0); |
| |
| // This benchmark constructs a chain of sequential while instructions. |
| // Timer starts automatically at the first iteration of this loop |
| // and ends after the last one. |
| for (auto s : state) { |
| state.PauseTiming(); |
| HloModuleConfig config; |
| config.set_debug_options(GetDebugOptionsFromFlags()); |
| HloModule module("BM_SequentialWhiles", config); |
| |
| auto builder = HloComputation::Builder("BM_SequentialWhiles"); |
| HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, ShapeUtil::MakeShape(F32, {42}), "x")); |
| HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( |
| 1, ShapeUtil::MakeShape(F32, {42}), "y")); |
| HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( |
| 2, ShapeUtil::MakeShape(F32, {42}), "z")); |
| HloInstruction* init = |
| builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); |
| |
| HloInstruction* prev_loop_state = init; |
| for (int w = 0; w < num_whiles; ++w) { |
| HloComputation* condition = |
| module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); |
| HloComputation* body = |
| module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); |
| prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( |
| init->shape(), condition, body, prev_loop_state)); |
| } |
| module.AddEntryComputation(builder.Build()); |
| |
| CopyInsertion copy_insertion; |
| |
| state.ResumeTiming(); |
| ASSERT_IS_OK(copy_insertion.Run(&module).status()); |
| state.PauseTiming(); |
| |
| // The entry computation should have three copies, and each body has one. |
| ASSERT_EQ(CountCopies(module), 3 + num_whiles); |
| state.ResumeTiming(); |
| } |
| } |
| |
| void BM_ParallelWhiles(::testing::benchmark::State& state) { |
| const int num_whiles = state.range(0); |
| |
| // This benchmark constructs a fan-out of parallel while instructions. |
| for (auto s : state) { |
| state.PauseTiming(); |
| HloModuleConfig config; |
| config.set_debug_options(GetDebugOptionsFromFlags()); |
| HloModule module("BM_SequentialWhiles", config); |
| |
| auto builder = HloComputation::Builder("BM_ParallelWhiles"); |
| HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, ShapeUtil::MakeShape(F32, {42}), "x")); |
| HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( |
| 1, ShapeUtil::MakeShape(F32, {42}), "y")); |
| HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( |
| 2, ShapeUtil::MakeShape(F32, {42}), "z")); |
| HloInstruction* init = |
| builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); |
| |
| HloInstruction* sum = nullptr; |
| for (int w = 0; w < num_whiles; ++w) { |
| HloComputation* condition = |
| module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); |
| HloComputation* body = |
| module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); |
| |
| HloInstruction* xla_while = builder.AddInstruction( |
| HloInstruction::CreateWhile(init->shape(), condition, body, init)); |
| |
| if (sum == nullptr) { |
| sum = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); |
| } else { |
| HloInstruction* element_0 = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); |
| sum = builder.AddInstruction(HloInstruction::CreateBinary( |
| x->shape(), HloOpcode::kAdd, sum, element_0)); |
| } |
| } |
| module.AddEntryComputation(builder.Build()); |
| |
| CopyInsertion copy_insertion; |
| |
| state.ResumeTiming(); |
| ASSERT_IS_OK(copy_insertion.Run(&module).status()); |
| state.PauseTiming(); |
| |
| // Each body receives of copy of two of the parameters (the corresponding |
| // elements in the body are modified), and there is one copy in each body. |
| ASSERT_EQ(CountCopies(module), 3 * num_whiles); |
| } |
| } |
| |
| std::unique_ptr<HloComputation> MakeBenchmarkWhileBody( |
| const int num_tuple_inputs) { |
| auto builder = HloComputation::Builder("benchmark_loop_body"); |
| const Shape element_shape = ShapeUtil::MakeShape(F32, {}); |
| std::vector<Shape> input_shape(num_tuple_inputs, element_shape); |
| const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); |
| HloInstruction* param = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); |
| std::vector<HloInstruction*> gte_nodes(num_tuple_inputs); |
| for (int i = 0; i < num_tuple_inputs; ++i) { |
| gte_nodes[i] = builder.AddInstruction( |
| HloInstruction::CreateGetTupleElement(element_shape, param, i)); |
| } |
| builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); |
| return builder.Build(); |
| } |
| |
| void BM_ManyElementTuple(::testing::benchmark::State& state) { |
| const int num_tuple_inputs = state.range(0); |
| HloModuleConfig config; |
| config.set_debug_options(GetDebugOptionsFromFlags()); |
| CopyInsertion copy_insertion; |
| const Shape element_shape = ShapeUtil::MakeShape(F32, {}); |
| std::vector<HloInstruction*> tuple_params(num_tuple_inputs); |
| for (auto s : state) { |
| state.PauseTiming(); |
| auto builder = HloComputation::Builder("BM_ParallelWhiles"); |
| HloModule module("BM_ManyElementTuple", config); |
| for (int j = 0; j < num_tuple_inputs; ++j) { |
| tuple_params[j] = builder.AddInstruction( |
| HloInstruction::CreateParameter(j, element_shape, "")); |
| } |
| HloInstruction* init = |
| builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); |
| HloComputation* condition = |
| module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); |
| HloComputation* body = |
| module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); |
| HloInstruction* xla_while = builder.AddInstruction( |
| HloInstruction::CreateWhile(init->shape(), condition, body, init)); |
| builder.AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::MakeShape(F32, {}), xla_while, 0)); |
| module.AddEntryComputation(builder.Build()); |
| state.ResumeTiming(); |
| ASSERT_IS_OK(copy_insertion.Run(&module).status()); |
| } |
| } |
| |
| BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); |
| BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); |
| BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); |
| |
| TEST_F(CopyInsertionTest, SimpleControlFlowTest) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| if-body.v5 { |
| constant.3 = s32[] constant(-1) |
| p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 |
| get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 |
| get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 |
| add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) |
| tuple.33 = (s32[]) tuple(add.3) |
| ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) |
| } |
| |
| if-condition.v4 { |
| p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 |
| constant.4 = s32[] constant(0) |
| ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ |
| } |
| |
| _functionalize_body_1__.v28 { |
| arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 |
| constant.7 = s32[] constant(1) |
| add.4 = s32[] add(get-tuple-element.68, constant.7) |
| get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 |
| get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 |
| less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE |
| constant.8 = s32[] constant(0) |
| select = s32[] select(less-than-or-equal-to, constant.8, constant.7) |
| get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 |
| tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) |
| tuple.36 = (s32[]) tuple(constant.8) |
| tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) |
| while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 |
| get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 |
| get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 |
| ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) |
| } |
| |
| cond_wrapper.v3.1 { |
| inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 |
| constant.11 = s32[] constant(7) |
| ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT |
| } |
| |
| _functionalize_body_2__.v25 { |
| arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 |
| get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 |
| get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 |
| get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 |
| tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) |
| while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 |
| get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 |
| get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 |
| constant.12 = s32[] constant(1) |
| add.5 = s32[] add(get-tuple-element.81, constant.12) |
| get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 |
| ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) |
| } |
| |
| cond_wrapper.v3.2 { |
| inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 |
| constant.13 = s32[] constant(5) |
| ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT |
| } |
| |
| ENTRY TestComputation { |
| arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 |
| } |
| )"; |
| auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); |
| auto module = module_or_status.ConsumeValueOrDie(); |
| InsertCopies(module.get()); |
| } |
| |
| TEST_F(CopyInsertionTest, ControlFlowTest) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| if-body.v5 { |
| constant.3 = s32[] constant(-1) |
| p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 |
| get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 |
| get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 |
| add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) |
| tuple.33 = (s32[]) tuple(add.3) |
| ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) |
| } |
| |
| if-condition.v4 { |
| p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 |
| constant.4 = s32[] constant(0) |
| ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ |
| } |
| |
| if-body.v5.1 { |
| constant.5 = s32[] constant(-1) |
| p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 |
| get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2 |
| multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) |
| tuple.35 = (s32[]) tuple(multiply.1) |
| ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) |
| } |
| |
| if-condition.v4.1 { |
| p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) |
| get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 |
| constant.6 = s32[] constant(1) |
| ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ |
| } |
| |
| _functionalize_body_1__.v28 { |
| arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 |
| constant.7 = s32[] constant(1) |
| add.4 = s32[] add(get-tuple-element.72, constant.7) |
| get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 |
| get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 |
| less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE |
| constant.8 = s32[] constant(0) |
| select = s32[] select(less-than-or-equal-to, constant.8, constant.7) |
| get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 |
| tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) |
| tuple.38 = (s32[]) tuple(constant.8) |
| tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) |
| while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 |
| while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 |
| get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 |
| get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 |
| ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) |
| } |
| |
| cond_wrapper.v3.1 { |
| inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 |
| constant.11 = s32[] constant(7) |
| ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT |
| } |
| |
| _functionalize_body_2__.v25 { |
| arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 |
| get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 |
| get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 |
| get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 |
| tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) |
| while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 |
| get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 |
| get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 |
| constant.12 = s32[] constant(1) |
| add.5 = s32[] add(get-tuple-element.84, constant.12) |
| get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 |
| ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) |
| } |
| |
| cond_wrapper.v3.2 { |
| inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 |
| constant.13 = s32[] constant(5) |
| ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT |
| } |
| |
| ENTRY TestComputation { |
| arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) |
| ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 |
| } |
| )"; |
| auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); |
| auto module = module_or_status.ConsumeValueOrDie(); |
| InsertCopies(module.get()); |
| } |
| |
| TEST_F(CopyInsertionTest, NestedWhiles) { |
| // Verify that only no unnecessary copies remain after copy insertion for |
| // trivial nested whiles (b/112472605). |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| cond.inner { |
| ROOT param.cond.inner = pred[] parameter(0) |
| } |
| |
| body.inner { |
| param.body.inner = pred[] parameter(0) |
| ROOT not = pred[] not(param.body.inner) |
| } |
| |
| cond.outer { |
| ROOT param.cond.outer = pred[] parameter(0) |
| } |
| |
| body.outer { |
| param.cond.outer = pred[] parameter(0) |
| ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner |
| } |
| |
| ENTRY TestComputation { |
| entry_param = pred[] parameter(0) |
| ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| |
| // There should only be a single copy inserted, and it's in the entry |
| // computation. |
| EXPECT_EQ(CountCopies(*module), 1); |
| EXPECT_THAT(module->entry_computation()->root_instruction(), |
| op::While(op::Copy(op::Parameter()))); |
| } |
| |
| TEST_F(CopyInsertionTest, NestedWhileAndConditional2) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| on_true |
| { |
| v1 = f32[2] parameter(0) |
| v2 = f32[2] add(v1,v1) |
| ROOT t1 = (f32[2], f32[2]) tuple(v1,v2) |
| } |
| |
| on_false |
| { |
| v1 = f32[2] parameter(0) |
| v2 = f32[2] multiply(v1,v1) |
| ROOT t2 = (f32[2], f32[2]) tuple(v1,v2) |
| } |
| |
| cond.outer { |
| param.1 = (pred[], f32[2], f32[2]) parameter(0) |
| ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 |
| } |
| |
| body.outer { |
| param.1 = (pred[], f32[2], f32[2]) parameter(0) |
| pred.1 = pred[] get-tuple-element(param.1), index=0 |
| arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 |
| if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false |
| e1 = f32[2] get-tuple-element(if), index=0 |
| e2 = f32[2] get-tuple-element(if), index=1 |
| ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2) |
| } |
| |
| ENTRY TestComputation { |
| entry_param.1 = pred[] parameter(0) |
| float_param = f32[2] parameter(1) |
| entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param) |
| ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| VLOG(2) << module->ToString() << "\n"; |
| |
| // An extra copy must be kept inside the loop due to uses in the conditional. |
| EXPECT_EQ(CountCopies(*module), 3); |
| } |
| |
| TEST_F(CopyInsertionTest, NestedWhileAndConditional) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| on_true |
| { |
| v1 = f32[2] parameter(0) |
| ROOT v2 = f32[2] add(v1,v1) |
| } |
| |
| on_false |
| { |
| v1 = f32[2] parameter(0) |
| ROOT v2 = f32[2] multiply(v1,v1) |
| } |
| |
| cond.outer { |
| param.1 = (pred[], f32[2]) parameter(0) |
| ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 |
| } |
| |
| body.outer { |
| param.1 = (pred[], f32[2]) parameter(0) |
| pred.1 = pred[] get-tuple-element(param.1), index=0 |
| arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 |
| if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false |
| ROOT res = (pred[], f32[2]) tuple(pred.1,if) |
| } |
| |
| ENTRY TestComputation { |
| entry_param.1 = pred[] parameter(0) |
| float_param = f32[2] parameter(1) |
| entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param) |
| ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| VLOG(2) << module->ToString() << "\n"; |
| |
| // There should only be two copies inserted, and in the entry and exit of the |
| // computation. |
| EXPECT_EQ(CountCopies(*module), 2); |
| } |
| |
| TEST_F(CopyInsertionTest, FixpointComputationRequired) { |
| const string& hlo_string = R"( |
| HloModule Module |
| |
| fused_computation { |
| param0 = f32[3,3,96,1] parameter(0) |
| param1 = f32[] parameter(1) |
| broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={} |
| ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast) |
| } |
| |
| ENTRY entry_computation { |
| arg0 = f32[3,3,96,1] parameter(0) |
| arg1 = f32[] parameter(1) |
| fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1), |
| kind=kLoop, calls=fused_computation |
| negate = f32[] negate(f32[] arg1) |
| ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple( |
| f32[3,3,96,1] fusion, |
| f32[3,3,96,1] arg0, |
| f32[] negate, |
| f32[] arg1) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| // Set up the aliasing manually which normally would be set by |
| // alias_passthrough_params pass. |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{1}, |
| /*param_number=*/0, |
| /*param_index=*/{})); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{3}, |
| /*param_number=*/1, |
| /*param_index=*/{})); |
| |
| InsertCopies(module.get()); |
| |
| // There should be no copies inserted. |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, NoAliasCheckViolation) { |
| const string& hlo_string = R"( |
| HloModule cluster |
| |
| ENTRY Entry { |
| %arg = f32[8,28,28,1] parameter(0) |
| %bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg) |
| ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{1}, |
| /*param_number=*/0, |
| /*param_index=*/{})); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate = f32[1280,1,128] negate(param) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| fused_computation { |
| param0 = f32[1280,1,128] parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) |
| } |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate = f32[1280,1,128] negate(param) |
| ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate = f32[1280,1,128] negate(param) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| add = f32[1280,1,128] add(negate, negate) |
| dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) |
| ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| fused_computation { |
| param0 = f32[1280,1,128] parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) |
| } |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate = f32[1280,1,128] negate(param) |
| add = f32[1280,1,128] add(negate, negate) |
| fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation |
| ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| ENTRY main { |
| state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} |
| get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 |
| get-tuple-element.3 = s32[] get-tuple-element(state), index=0 |
| constant.2 = s32[] constant(128) |
| add.5 = s32[] add(get-tuple-element.3, constant.2) |
| constant.3 = s32[] constant(0) |
| dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) |
| dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) |
| ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) { |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| fused_computation.1 { |
| param0 = f32[1280,1,128] parameter(0) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) |
| } |
| |
| fused_computation.2 { |
| param0 = f32[1280,1,128] parameter(0) |
| param1 = f32[1280,1,128] parameter(1) |
| slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]} |
| constant.3 = s32[] constant(0) |
| ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3) |
| } |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate = f32[1280,1,128] negate(param) |
| add = f32[1280,1,128] add(negate, negate) |
| fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1 |
| ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2 |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) { |
| // Tests multi-output fusion with two DUS outputs, requiring two copies. |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| fused_computation { |
| param0 = f32[1280,1,128] parameter(0) |
| param1 = f32[1280,1,128] parameter(1) |
| param2 = f32[1280,1,128] parameter(2) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| add.1 = f32[1280,1,128] add(param0, param0) |
| dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) |
| dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) |
| ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) |
| } |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate0 = f32[1280,1,128] negate(param) |
| negate1 = f32[1280,1,128] negate(param) |
| negate2 = f32[1280,1,128] negate(param) |
| fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation |
| gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 |
| gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 |
| gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 |
| add0 = f32[1280,1,128] add(negate0, gte0) |
| add1 = f32[1280,1,128] add(negate1, gte1) |
| add2 = f32[1280,1,128] add(negate2, gte2) |
| ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 2); |
| } |
| |
| TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) { |
| // Same as above, but negate1 is not used beyond fusion, so it only needs one |
| // copy for negate0. |
| absl::string_view hlo_string = R"( |
| HloModule Module |
| |
| fused_computation { |
| param0 = f32[1280,1,128] parameter(0) |
| param1 = f32[1280,1,128] parameter(1) |
| param2 = f32[1280,1,128] parameter(2) |
| constant.1 = f32[] constant(0) |
| broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} |
| constant.3 = s32[] constant(0) |
| add.1 = f32[1280,1,128] add(param0, param0) |
| dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) |
| dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) |
| ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) |
| } |
| |
| ENTRY main { |
| param = f32[1280,1,128] parameter(0) |
| negate0 = f32[1280,1,128] negate(param) |
| negate1 = f32[1280,1,128] negate(param) |
| negate2 = f32[1280,1,128] negate(param) |
| fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation |
| gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 |
| gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 |
| gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 |
| add0 = f32[1280,1,128] add(negate0, gte0) |
| add1 = f32[1280,1,128] add(gte1, gte1) |
| add2 = f32[1280,1,128] add(negate2, gte2) |
| ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| EXPECT_EQ(CountCopies(*module), 1); |
| } |
| |
| TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) { |
| const string& hlo_string = R"( |
| HloModule test |
| |
| fused_computation { |
| p0 = f32[10,20] parameter(0) |
| p1 = f32[10,20] parameter(1) |
| p2 = f32[10,10] parameter(2) |
| p3 = f32[10,10] parameter(3) |
| add0 = f32[10, 20] add(p0, p1) |
| sub0 = f32[10, 10] subtract(p2, p3) |
| reshape0 = f32[200] reshape(add0) |
| reshape1 = f32[100] reshape(sub0) |
| concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0} |
| slice0 = f32[200] slice(concat0), slice={[0:200]} |
| slice1 = f32[100] slice(concat0), slice={[200:300]} |
| ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1) |
| } |
| |
| ENTRY test { |
| p0 = f32[10,20] parameter(0) |
| p1 = f32[10,20] parameter(1) |
| p2 = f32[10,10] parameter(2) |
| p3 = f32[10,10] parameter(3) |
| fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation |
| gte0 = f32[200] get-tuple-element(fusion), index=0 |
| gte1 = f32[100] get-tuple-element(fusion), index=1 |
| bitcast0 = f32[10,20] bitcast(gte0) |
| bitcast1 = f32[10,10] bitcast(gte1) |
| ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{0}, |
| /*param_number=*/0, |
| /*param_index=*/{})); |
| ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( |
| /*output_index=*/{1}, |
| /*param_number=*/3, |
| /*param_index=*/{})); |
| |
| InsertCopies(module.get()); |
| |
| // There should be no copies inserted. |
| EXPECT_EQ(CountCopies(*module), 0); |
| } |
| |
| TEST_F(CopyInsertionTest, NestedWhileAndConditional3) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| on_true.1 |
| { |
| ROOT v1 = f32[2] parameter(0) |
| } |
| |
| on_false.1 |
| { |
| v1 = f32[2] parameter(0) |
| ROOT v2 = f32[2] multiply(v1,v1) |
| } |
| |
| on_true |
| { |
| v1 = f32[2] parameter(0) |
| v2 = f32[2] add(v1,v1) |
| v3 = (f32[2],f32[2]) tuple(v1,v2) |
| v4 = f32[2] get-tuple-element(v3), index=1 |
| v5 = f32[2] multiply(v4,v2) |
| ROOT t1 = (f32[2], f32[2]) tuple(v5,v2) |
| } |
| |
| on_false |
| { |
| v1 = f32[2] parameter(0) |
| v2 = f32[2] multiply(v1,v1) |
| pred.1 = pred[] constant(true) |
| v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1 |
| v5 = f32[2] multiply(v4,v2) |
| ROOT t2 = (f32[2], f32[2]) tuple(v2,v5) |
| |
| } |
| |
| cond.outer { |
| param.1 = (pred[], f32[2], f32[2]) parameter(0) |
| ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 |
| } |
| |
| body.outer { |
| param.1 = (pred[], f32[2], f32[2]) parameter(0) |
| pred.1 = pred[] get-tuple-element(param.1), index=0 |
| arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 |
| if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false |
| e1 = f32[2] get-tuple-element(if), index=0 |
| e2 = f32[2] get-tuple-element(if), index=1 |
| ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2) |
| } |
| |
| ENTRY TestComputation { |
| entry_param.1 = pred[] parameter(0) |
| float_param = f32[2] parameter(1) |
| entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param) |
| ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| // An extra copy must be kept inside the loop due to uses in the conditional |
| EXPECT_EQ(CountCopies(*module), 4); |
| } |
| |
| TEST_F(CopyInsertionTest, ConditionalBranchMustCopy1) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| branch_0_comp.5.clone { |
| %parameter.0 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0 |
| %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element) |
| %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate) |
| ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy) |
| } |
| |
| branch_1_comp.12.clone { |
| %parameter.4 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0 |
| %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5) |
| ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1) |
| } |
| |
| ENTRY TestComputation { |
| %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3) |
| %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2) |
| %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0 |
| ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| CopyInsertion copy_insertion(nullptr, |
| /*use_region_based_live_range_analysis=*/true); |
| ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); |
| VLOG(3) << module->ToString(); |
| // The copy.1 must be kept due to modification in the other branch. |
| auto conditional18 = FindInstruction(module.get(), "conditional.18"); |
| CHECK_NE(conditional18, nullptr); |
| auto tuple6 = conditional18->branch_computation(1)->root_instruction(); |
| CHECK_EQ(tuple6->opcode(), HloOpcode::kTuple); |
| auto copy1 = tuple6->operand(0); |
| CHECK_EQ(copy1->opcode(), HloOpcode::kCopy); |
| } |
| |
| TEST_F(CopyInsertionTest, ConditionalBranchMustCopy2) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| branch_0_comp.5.clone { |
| %parameter.0 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0 |
| %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element) |
| %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate) |
| ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy) |
| } |
| |
| branch_1_comp.12.clone { |
| %parameter.4 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0 |
| %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5) |
| %constant.1 = s32[] constant(0) |
| %broadcast.6 = s32[2] broadcast(constant.1), dimensions={} |
| dynamic-update-slice.5 = s32[2] dynamic-update-slice(%copy.1, %broadcast.6, %constant.1) |
| %add.1 = s32[2]{0:T(128)} add(dynamic-update-slice.5, %copy.1) |
| ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%add.1) |
| } |
| |
| ENTRY TestComputation { |
| %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3) |
| %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2) |
| %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0 |
| ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| // TODO(b/189898980): the region based live range analysis currently |
| // does not enforce a strict ordering of the merged live ranges. This wil |
| // cause the following invocation to fail when run under UNDEBUG mode. |
| #if 0 |
| CopyInsertion copy_insertion(nullptr, |
| /*use_region_based_live_range_analysis=*/true); |
| ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); |
| // The copy.1 must be kept due to modification in the other branch. |
| auto conditional18 = FindInstruction(module.get(), "conditional.18"); |
| CHECK_NE(conditional18, nullptr); |
| auto tuple6 = conditional18->branch_computation(1)->root_instruction(); |
| CHECK_EQ(tuple6->opcode(), HloOpcode::kTuple); |
| auto add1 = tuple6->operand(0); |
| CHECK_EQ(add1->opcode(), HloOpcode::kAdd); |
| auto dus = add1->operand(0); |
| auto copy1 = dus->operand(0); |
| CHECK_EQ(copy1->opcode(), HloOpcode::kCopy); |
| #endif |
| } |
| |
| TEST_F(CopyInsertionTest, ConditionalBranchMustCopy3) { |
| const string& hlo_string = R"( |
| HloModule primitive_computation_cond.19 |
| %branch_0_comp.5.clone (parameter.0: (s32[2])) -> (s32[2]) { |
| %parameter.0 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0 |
| %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element) |
| %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate) |
| ROOT %tuple.5 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy) |
| } |
| |
| %branch_1_comp.12.clone (parameter.4: (s32[2])) -> (s32[2]) { |
| %parameter.4 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0 |
| %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5) |
| ROOT %tuple.6 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy.1) |
| } |
| |
| ENTRY %primitive_computation_cond.19 (parameter.1: s32[], parameter.2: s32[2], parameter.3: s32[2]) -> (s32[2]) { |
| %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3) |
| %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2) |
| ROOT %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| InsertCopies(module.get()); |
| CopyInsertion copy_insertion(nullptr, |
| /*use_region_based_live_range_analysis=*/true); |
| ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); |
| VLOG(3) << module->ToString(); |
| // The copy.1 must be kept b/c aliasing of parameter and root is not allowed. |
| auto conditional18 = FindInstruction(module.get(), "conditional.18"); |
| CHECK_NE(conditional18, nullptr); |
| auto tuple6 = conditional18->branch_computation(1)->root_instruction(); |
| CHECK_EQ(tuple6->opcode(), HloOpcode::kTuple); |
| auto copy1 = tuple6->operand(0); |
| CHECK_EQ(copy1->opcode(), HloOpcode::kCopy); |
| } |
| |
| TEST_F(CopyInsertionTest, ConditionalBranchDoNotCopy1) { |
| const string& hlo_string = R"( |
| HloModule TestModule |
| |
| branch_0_comp.5.clone { |
| %parameter.0 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0 |
| %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element) |
| %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate) |
| ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy) |
| } |
| |
| branch_1_comp.12.clone { |
| %parameter.4 = (s32[2]{0:T(128)}) parameter(0) |
| %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0 |
| %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5) |
| ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1) |
| } |
| |
| ENTRY TestComputation { |
| %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3) |
| %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2) |
| %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"} |
| %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0 |
| ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(gte.1, gte.1) |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| CopyInsertion copy_insertion(nullptr, |
| /*use_region_based_live_range_analysis=*/true); |
| ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); |
| VLOG(3) << module->ToString() << "\n"; |
| |
| // The copy.1 must be kept due to modification in the other branch. |
| auto conditional18 = FindInstruction(module.get(), "conditional.18"); |
| CHECK_NE(conditional18, nullptr); |
| auto tuple6 = conditional18->branch_computation(1)->root_instruction(); |
| CHECK_EQ(tuple6->opcode(), HloOpcode::kParameter); |
| } |
| |
| TEST_F(CopyInsertionTest, RootInstructionNotLast) { |
| // This is a test for b/189219227. When the root instruction is scheduled not |
| // as the last instruction, it still lives out. So, we make sure that the copy |
| // after the root cannot be removed. |
| const string& hlo_string = R"( |
| HloModule module, is_scheduled=true |
| |
| body2 { |
| p_body2 = (f32[2]{0}) parameter(0) |
| p_body2.1 = f32[2]{0} get-tuple-element(p_body2), index=0 |
| add.3 = f32[2]{0} add(p_body2.1, p_body2.1) |
| ROOT root2 = (f32[2]{0}) tuple(add.3) |
| } |
| |
| condition2 { |
| p_cond2 = (f32[2]{0}) parameter(0) |
| ROOT result = pred[] constant(true) |
| } |
| |
| body { |
| p_body = (f32[2]{0}) parameter(0) |
| p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 |
| ROOT root = (f32[2]{0}) tuple(p_body.1) |
| copy = f32[2]{0} copy(p_body.1) |
| tuple = (f32[2]{0}) tuple(copy) |
| while.1 = (f32[2]{0}) while(tuple), condition=condition2, body=body2 |
| } |
| |
| condition { |
| p_cond = (f32[2]{0}) parameter(0) |
| ROOT result = pred[] constant(true) |
| } |
| |
| ENTRY entry { |
| const0 = f32[2]{0} constant({1, 2}) |
| while_init = (f32[2]{0}) tuple(const0) |
| ROOT while.0 = (f32[2]{0}) while(while_init), condition=condition, body=body |
| } |
| )"; |
| TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, |
| ParseAndReturnVerifiedModule(hlo_string)); |
| CopyInsertion copy_insertion(nullptr, |
| /*use_region_based_live_range_analysis=*/true); |
| SequentialHloOrdering ordering(module->schedule()); |
| ASSERT_IS_OK(copy_insertion.RemoveUnnecessaryCopies(&ordering, module.get())); |
| auto while_1 = FindInstruction(module.get(), "while.1"); |
| EXPECT_THAT(while_1, op::While(op::Tuple(op::Copy()))); |
| } |
| |
| } // namespace |
| } // namespace xla |