| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/hlo_cse.h" |
| |
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
| |
| namespace op = xla::testing::opcode_matchers; |
| |
| namespace xla { |
| namespace { |
| |
| class HloCseTest : public HloTestBase { |
| protected: |
| HloCseTest() {} |
| }; |
| |
| TEST_F(HloCseTest, CombineTwoConstants) { |
| // Test that two identical constants are commoned. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); |
| builder.AddInstruction(HloInstruction::CreateBinary( |
| constant1->shape(), HloOpcode::kAdd, constant1, constant2)); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(2, computation->instruction_count()); |
| HloInstruction* constant = *computation->instructions().begin(); |
| EXPECT_EQ(42.0f, constant->literal().Get<float>({})); |
| |
| auto result = ExecuteAndTransfer(module->Clone(), {}); |
| auto expected = LiteralUtil::CreateR0<float>(84.0); |
| EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); |
| } |
| |
| TEST_F(HloCseTest, CombineTwoConstantsDifferentLayouts) { |
| // Test that two identical constants with different layouts are *not* |
| // combined. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( |
| {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( |
| {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); |
| auto add = builder.AddInstruction(HloInstruction::CreateBinary( |
| constant1->shape(), HloOpcode::kAdd, constant1, constant2)); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| EXPECT_THAT(add, op::Add(constant1, constant2)); |
| |
| HloCSE cse(/*is_layout_sensitive=*/true); |
| EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| EXPECT_THAT(add, op::Add(constant1, constant2)); |
| |
| auto result = ExecuteAndTransfer(module->Clone(), {}); |
| auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); |
| EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); |
| } |
| |
| TEST_F(HloCseTest, ConstantsSameValueDifferentType) { |
| // Test that constants with the same value but different type are *not* |
| // commoned. |
| auto builder = HloComputation::Builder(TestName()); |
| std::vector<HloInstruction*> constants; |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42)))); |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)))); |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64_t>(42.0)))); |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(42.0)))); |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0)))); |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)))); |
| // Duplicate the float constant to verify something happens. |
| constants.push_back(builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)))); |
| |
| const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); |
| for (int64_t i = 0; i < constants.size(); ++i) { |
| constants[i] = builder.AddInstruction( |
| HloInstruction::CreateConvert(shape_r0, constants[i])); |
| } |
| HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( |
| shape_r0, HloOpcode::kAdd, constants[0], constants[1])); |
| for (int64_t i = 2; i < constants.size(); ++i) { |
| root = builder.AddInstruction(HloInstruction::CreateBinary( |
| shape_r0, HloOpcode::kAdd, root, constants[i])); |
| } |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(20, computation->instruction_count()); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| // CSE will remove both the second float(42.0f) and the corresponding |
| // convert/cast. |
| EXPECT_EQ(18, computation->instruction_count()); |
| } |
| |
| TEST_F(HloCseTest, NonscalarConstants) { |
| // Test that identical nonscalar constants are merged. |
| auto builder = HloComputation::Builder(TestName()); |
| auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); |
| auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); |
| // Create a constant which has the same shape but a different value. |
| auto uncommon_constant = |
| builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}))); |
| |
| // Tie the constants together with a tuple. This makes it easier to refer to |
| // the constant instructions via their use. |
| auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( |
| {common_constant1, common_constant2, uncommon_constant})); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| EXPECT_THAT(tuple, |
| op::Tuple(common_constant1, common_constant2, uncommon_constant)); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| auto first_operand = tuple->operand(0); |
| EXPECT_THAT(first_operand, |
| ::testing::AnyOf(common_constant1, common_constant2)); |
| EXPECT_THAT(tuple, |
| op::Tuple(first_operand, first_operand, uncommon_constant)); |
| } |
| |
| TEST_F(HloCseTest, IdenticalInstructions) { |
| // Test that three identical instructions are commoned. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); |
| auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| auto exp3 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(5, computation->instruction_count()); |
| EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); |
| |
| HloCSE cse(/*is_layout_sensitive=*/true); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| auto first_operand = tuple->operand(0); |
| EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2, exp3)); |
| EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); |
| } |
| |
| // Test two identical while loops with same inputs |
| TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { |
| const char* const hlo_string = R"( |
| HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput |
| |
| %body (param: (f32[], f32[])) -> (f32[], f32[]) { |
| %param = (f32[], f32[]) parameter(0) |
| %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) |
| ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) |
| } |
| |
| %condition (param.1: (f32[], f32[])) -> pred[] { |
| %param.1 = (f32[], f32[]) parameter(0) |
| ROOT %constant = pred[] constant(false) |
| } |
| |
| %condition.1 (param.2: (f32[], f32[])) -> pred[] { |
| %param.2 = (f32[], f32[]) parameter(0) |
| ROOT %constant.1 = pred[] constant(false) |
| } |
| |
| ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[]) |
| { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 = |
| (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[], |
| f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT |
| %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), |
| condition=%condition.1, body=%body |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| auto computation = m->entry_computation(); |
| |
| EXPECT_EQ(5, computation->instruction_count()); |
| HloCSE cse(true); |
| EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); |
| EXPECT_EQ(4, computation->instruction_count()); |
| } |
| |
| // Test two while loops with same conditions, same inputs, but different |
| // bodies |
| TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { |
| const char* const hlo_string = R"( |
| HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies |
| |
| %body (param: (f32[], f32[])) -> (f32[], f32[]) { |
| %param = (f32[], f32[]) parameter(0) |
| %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) |
| ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) |
| } |
| |
| %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) { |
| %param.1 = (f32[], f32[]) parameter(0) |
| %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1), |
| index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1), |
| index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] |
| %get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[] |
| %get-tuple-element.2, f32[] %sub) |
| } |
| |
| %condition (param.2: (f32[], f32[])) -> pred[] { |
| %param.2 = (f32[], f32[]) parameter(0) |
| ROOT %constant = pred[] constant(false) |
| } |
| |
| %condition.1 (param.3: (f32[], f32[])) -> pred[] { |
| %param.3 = (f32[], f32[]) parameter(0) |
| ROOT %constant.1 = pred[] constant(false) |
| } |
| |
| ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () -> |
| (f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) |
| %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) |
| %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), |
| condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], |
| f32[]) %tuple.1), condition=%condition.1, body=%body2 |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| auto computation = m->entry_computation(); |
| |
| EXPECT_EQ(5, computation->instruction_count()); |
| HloCSE cse(true); |
| EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); |
| EXPECT_EQ(5, computation->instruction_count()); |
| } |
| |
| // Test two identical while loops with different inputs |
| TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { |
| const char* const hlo_string = R"( |
| HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput |
| |
| %body (param: (f32[], f32[])) -> (f32[], f32[]) { |
| %param = (f32[], f32[]) parameter(0) |
| %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) |
| ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) |
| } |
| |
| %condition (param.1: (f32[], f32[])) -> pred[] { |
| %param.1 = (f32[], f32[]) parameter(0) |
| ROOT %constant = pred[] constant(false) |
| } |
| |
| %condition.1 (param.2: (f32[], f32[])) -> pred[] { |
| %param.2 = (f32[], f32[]) parameter(0) |
| ROOT %constant.1 = pred[] constant(false) |
| } |
| |
| ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[], |
| f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) |
| %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) |
| %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), |
| condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = |
| f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] |
| %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), |
| condition=%condition.1, body=%body |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| auto computation = m->entry_computation(); |
| |
| EXPECT_EQ(8, computation->instruction_count()); |
| HloCSE cse(true); |
| EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); |
| EXPECT_EQ(8, computation->instruction_count()); |
| } |
| |
| // Test two while loops with identical bodies and same inputs, but different |
| // conditions |
| TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferentConditions) { |
| const char* const hlo_string = R"( |
| HloModule WhileLoopsIdenticalBodiesAndInputDifferentConditions |
| |
| %body (param: (f32[], f32[])) -> (f32[], f32[]) { |
| %param = (f32[], f32[]) parameter(0) |
| %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), |
| index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) |
| ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) |
| } |
| |
| %condition (param.1: (f32[], f32[])) -> pred[] { |
| %param.1 = (f32[], f32[]) parameter(0) |
| ROOT %constant = pred[] constant(false) |
| } |
| |
| %condition.1 (param.2: (f32[], f32[])) -> pred[] { |
| %param.2 = (f32[], f32[]) parameter(0) |
| ROOT %constant.1 = pred[] constant(true) |
| } |
| |
| ENTRY %WhileLoopsIdenticalBodiesAndInputDifferentConditions () -> (f32[], |
| f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) |
| %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) |
| %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), |
| condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], |
| f32[]) %tuple.1), condition=%condition.1, body=%body |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| auto computation = m->entry_computation(); |
| |
| EXPECT_EQ(5, computation->instruction_count()); |
| HloCSE cse(true); |
| EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); |
| EXPECT_EQ(5, computation->instruction_count()); |
| } |
| |
| TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { |
| // Test that two identical instructions with different layouts are *not* |
| // commoned if the pass is layout sensitive. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); |
| |
| auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); |
| |
| auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); |
| |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); |
| |
| HloCSE cse(/*is_layout_sensitive=*/true); |
| EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); |
| } |
| |
| TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { |
| // Test that two identical instructions with different layouts are commoned if |
| // the pass is layout insensitive. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction(HloInstruction::CreateConstant( |
| LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); |
| |
| auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); |
| |
| auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); |
| |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(3, computation->instruction_count()); |
| auto first_operand = tuple->operand(0); |
| EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2)); |
| EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand)); |
| } |
| |
| TEST_F(HloCseTest, FusionInternalCSE) { |
| // Test that we can CSE expressions that live within a fusion node |
| // computation. |
| auto module = CreateNewVerifiedModule(); |
| auto builder = HloComputation::Builder(TestName()); |
| |
| const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); |
| auto param0 = builder.AddInstruction( |
| HloInstruction::CreateParameter(0, shape_r0, "p0")); |
| auto param1 = builder.AddInstruction( |
| HloInstruction::CreateParameter(1, shape_r0, "p1")); |
| auto add1 = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape_r0, HloOpcode::kAdd, param0, param1)); |
| auto add2 = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape_r0, HloOpcode::kAdd, param0, param1)); |
| auto mul = builder.AddInstruction( |
| HloInstruction::CreateBinary(shape_r0, HloOpcode::kMultiply, add1, add2)); |
| |
| auto computation = module->AddEntryComputation(builder.Build()); |
| auto fused_computation = |
| computation |
| ->CreateFusionInstruction({mul, add1, add2}, |
| HloInstruction::FusionKind::kLoop) |
| ->fused_instructions_computation(); |
| |
| EXPECT_EQ(5, fused_computation->instruction_count()); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| EXPECT_EQ(4, fused_computation->instruction_count()); |
| |
| auto root = fused_computation->root_instruction(); |
| EXPECT_THAT(root, op::Multiply(root->operand(0), root->operand(0))); |
| } |
| |
| TEST_F(HloCseTest, IdenticalExpressions) { |
| // Test that two identical expressions are commoned. Build the following |
| // computation: |
| // |
| // constant = 42.0 |
| // negate1 = neg(constant) |
| // exp1 = exp(constant) |
| // add1 = add(negate1, exp1) |
| // negate2 = neg(constant) |
| // exp2 = exp(constant) |
| // add2 = add(negate2, exp2) |
| // tuple = tuple(add1, add2) |
| // |
| // The *1 instructions should be merged with the *2 instructions. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); |
| |
| auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kNegate, constant)); |
| auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( |
| constant->shape(), HloOpcode::kAdd, negate1, exp1)); |
| |
| auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kNegate, constant)); |
| auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( |
| constant->shape(), HloOpcode::kExp, constant)); |
| auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( |
| constant->shape(), HloOpcode::kAdd, negate2, exp2)); |
| |
| auto tuple = |
| builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(8, computation->instruction_count()); |
| EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(5, computation->instruction_count()); |
| auto operand = tuple->operand(0); |
| EXPECT_THAT(tuple, op::Tuple(operand, operand)); |
| EXPECT_THAT(operand, op::Add(op::Negate(), op::Exp())); |
| } |
| |
| TEST_F(HloCseTest, DoNotCombineRng) { |
| // Test that two RNG ops are not commoned. |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); |
| auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( |
| ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, |
| {constant1, constant2})); |
| auto rng2 = builder.AddInstruction(HloInstruction::CreateRng( |
| ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, |
| {constant1, constant2})); |
| |
| builder.AddInstruction(HloInstruction::CreateBinary( |
| constant1->shape(), HloOpcode::kAdd, rng1, rng2)); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| HloInstruction* root = computation->root_instruction(); |
| EXPECT_THAT(root, op::Add(rng1, rng2)); |
| |
| uint32_t count_before = computation->instruction_count(); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); |
| |
| uint32_t count_after = computation->instruction_count(); |
| EXPECT_EQ(count_before, count_after); |
| root = computation->root_instruction(); |
| EXPECT_THAT(root, op::Add(rng1, rng2)); |
| } |
| |
| TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { |
| // Test that two calls to an impure function are not commoned. RNG |
| // is the source of the impurity. |
| |
| auto module = CreateNewVerifiedModule(); |
| |
| // rng_function is an impure function because it does RNG. |
| HloComputation* rng_function = nullptr; |
| { |
| Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); |
| auto builder = HloComputation::Builder(TestName() + "_rng_fun"); |
| auto constant1 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); |
| auto constant2 = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); |
| auto rng = builder.AddInstruction(HloInstruction::CreateRng( |
| scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); |
| auto param = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, ShapeUtil::MakeShape(F32, {}), "param")); |
| builder.AddInstruction(HloInstruction::CreateBinary( |
| scalar_shape, HloOpcode::kAdd, rng, param)); |
| rng_function = module->AddEmbeddedComputation(builder.Build()); |
| } |
| |
| // Computation calls rng_function twice with the same parameter. |
| HloComputation* computation = nullptr; |
| { |
| auto builder = HloComputation::Builder(TestName()); |
| auto constant = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f}))); |
| auto rng1 = builder.AddInstruction( |
| HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); |
| auto rng2 = builder.AddInstruction( |
| HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); |
| builder.AddInstruction(HloInstruction::CreateBinary( |
| constant->shape(), HloOpcode::kAdd, rng1, rng2)); |
| computation = module->AddEntryComputation(builder.Build()); |
| } |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| HloInstruction* root = computation->root_instruction(); |
| EXPECT_THAT(root, op::Add(op::Map(), op::Map())); |
| |
| VLOG(3) << "before: " << module->ToString(); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); |
| |
| VLOG(3) << "after: " << module->ToString(); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| root = computation->root_instruction(); |
| EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); |
| } |
| |
| TEST_F(HloCseTest, CompareComputations) { |
| const char* const hlo_string = R"( |
| HloModule m |
| |
| add_computation { |
| add_lhs = f32[] parameter(0) |
| add_rhs = f32[] parameter(1) |
| ROOT add_root = f32[] add(add_lhs, add_rhs) |
| } |
| |
| add_computation2 { |
| add_lhs2 = f32[] parameter(0) |
| add_rhs2 = f32[] parameter(1) |
| ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) |
| } |
| |
| ENTRY entry { |
| p = f32[10]{0} parameter(0) |
| c = f32[] constant(0) |
| r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation |
| r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 |
| ROOT f2 = (f32[],f32[]) tuple(r1, r2) |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); |
| HloInstruction* root = m->entry_computation()->root_instruction(); |
| EXPECT_EQ(root->operand(0), root->operand(1)); |
| } |
| |
| TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { |
| // Test that constants and iotas with the same value but in different domains |
| // (disjoint in this case) are not collapsed. |
| auto builder = HloComputation::Builder(TestName()); |
| builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42))); |
| builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42))); |
| builder.AddInstruction( |
| HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), 0)); |
| builder.AddInstruction( |
| HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), 0)); |
| |
| auto module = CreateNewVerifiedModule(); |
| auto computation = module->AddEntryComputation(builder.Build()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); |
| |
| EXPECT_EQ(4, computation->instruction_count()); |
| } |
| |
| TEST_F(HloCseTest, Domain) { |
| const char* const hlo_string = R"( |
| HloModule module |
| ENTRY %entry { |
| %param = f32[] parameter(0), sharding={maximal device=0} |
| %domain.0 = f32[] domain(%param), |
| domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} |
| %domain.1 = f32[] domain(%param), |
| domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} |
| %domain.2 = f32[] domain(%param), |
| domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}} |
| %negate.0 = f32[] negate(%domain.0) |
| %negate.1 = f32[] negate(%domain.1) |
| %negate.2 = f32[] negate(%domain.2) |
| %domain.3 = f32[] domain(%negate.0), |
| domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} |
| %domain.4 = f32[] domain(%negate.1), |
| domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} |
| %domain.5 = f32[] domain(%negate.2), |
| domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} |
| %add = f32[] add(%domain.3, %domain.4) |
| ROOT %sub = f32[] subtract(%add, %domain.5) |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); |
| const HloInstruction* sub = m->entry_computation()->root_instruction(); |
| const HloInstruction* add = sub->operand(0); |
| EXPECT_EQ(add->operand(0), add->operand(1)); |
| EXPECT_NE(add->operand(0), sub->operand(1)); |
| EXPECT_NE(add->operand(1), sub->operand(1)); |
| } |
| |
| TEST_F(HloCseTest, Iota) { |
| const char* const hlo_string = R"( |
| HloModule m |
| |
| ENTRY entry { |
| i1 = s64[16,16] iota(), iota_dimension=0 |
| i2 = s64[16,16] iota(), iota_dimension=0 |
| i3 = s64[17,16] iota(), iota_dimension=0 |
| i4 = s64[16,16] iota(), iota_dimension=1 |
| ROOT root = (s64[16,16], s64[16,16], s64[17,16], s64[16,16]) tuple(i1, i2, i3, i4) |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); |
| EXPECT_TRUE(changed); |
| HloInstruction* root = m->entry_computation()->root_instruction(); |
| EXPECT_EQ(root->operand(0), root->operand(1)); |
| EXPECT_NE(root->operand(0), root->operand(2)); |
| EXPECT_NE(root->operand(0), root->operand(3)); |
| } |
| |
| TEST_F(HloCseTest, OptimizationBarrier) { |
| const char* const hlo_string = R"( |
| HloModule m |
| |
| ENTRY entry { |
| %param.0 = f32[] parameter(0) |
| %param.1 = f32[] parameter(1) |
| %add.0 = f32[] add(%param.0, %param.1) |
| %cse_tmp.0 = (f32[], f32[], f32[]) tuple(%param.0, %param.1, %add.0) |
| %cse_tmp.1 = (f32[], f32[], f32[]) opt-barrier(%cse_tmp.0) |
| |
| %param.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=0 |
| %param.1.1 = f32[] get-tuple-element(%cse_tmp.1), index=1 |
| %add.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=2 |
| |
| %add.1 = f32[] add(%param.0.1, %param.1.1) |
| ROOT %add.2 = f32[] add(%add.1, %add.0.1) |
| })"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); |
| EXPECT_FALSE(changed); |
| } |
| |
| class HloCseCustomCallTest |
| : public HloCseTest, |
| public ::testing::WithParamInterface<std::tuple< |
| std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> {}; |
| |
| TEST_P(HloCseCustomCallTest, DoIt) { |
| std::string op1 = std::get<0>(GetParam()); |
| std::string op2 = std::get<1>(GetParam()); |
| bool should_cse = std::get<2>(GetParam()); |
| |
| const char* const hlo_string_tmpl = R"( |
| HloModule m |
| ENTRY entry { |
| p0 = f32[1,1,1] parameter(0) |
| |
| op0 = $0 |
| op1 = $0 |
| op2 = $1 |
| ROOT root = tuple(op0, op1, op2) |
| } |
| )"; |
| std::string hlo_string = absl::Substitute(hlo_string_tmpl, op1, op2); |
| SCOPED_TRACE(absl::StrCat("Module before CSE:\n", hlo_string)); |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); |
| |
| SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString())); |
| EXPECT_EQ(changed, true); // we always CSE op0 and op1, which are identical. |
| HloInstruction* root = m->entry_computation()->root_instruction(); |
| EXPECT_EQ(root->operand(0), root->operand(1)) |
| << "Identical ops should be CSE'ed"; |
| if (should_cse) { |
| EXPECT_EQ(root->operand(0), root->operand(2)) << "Ops should be CSE'ed"; |
| } else { |
| EXPECT_NE(root->operand(0), root->operand(2)) << "Ops should not be CSE'ed"; |
| } |
| } |
| |
| static std::vector< |
| std::tuple<std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> |
| CustomCallTests() { |
| auto build = [](absl::string_view args1, absl::string_view args2) { |
| absl::string_view prefix = |
| "f32[] custom-call(p0), custom_call_target=\"foo\", "; |
| return std::make_tuple(absl::StrCat(prefix, args1), |
| absl::StrCat(prefix, args2), false); |
| }; |
| return { |
| { |
| // metadata shouldn't prevent CSE |
| "f32[] custom-call(p0), custom_call_target=\"foo\"", |
| "f32[] custom-call(p0), custom_call_target=\"foo\", " |
| "metadata={op_name=\"bar\"}", |
| true, |
| }, |
| { |
| "f32[] custom-call(p0), custom_call_target=\"foo\"", |
| "f32[] custom-call(p0, p0), custom_call_target=\"foo\"", |
| false, |
| }, |
| { |
| "f32[1] custom-call(p0), custom_call_target=\"foo\"", |
| "f32[2] custom-call(p0), custom_call_target=\"foo\"", |
| false, |
| }, |
| { |
| "f32[] custom-call(p0), custom_call_target=\"foo\"", |
| "f32[] custom-call(p0), custom_call_target=\"bar\"", |
| false, |
| }, |
| |
| build("window={size=1}", "window={size=2}"), |
| build("dim_labels=b0f_0oi->b0f", "dim_labels=b0f_0oi->bf0"), |
| build("backend_config=\"foo\"", "backend_config=\"bar\""), |
| build("literal=s32[] 0", "literal=s32[] 1"), |
| build("literal=s32[] 0", "literal=f32[] 0"), |
| build("operand_precision={high,default}", |
| "operand_precision={high, high}"), |
| build("api_version=API_VERSION_STATUS_RETURNING", |
| "api_version=API_VERSION_ORIGINAL"), |
| build("feature_group_count=0", "feature_group_count=1"), |
| }; |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(HloCseCustomCallTestSuite, HloCseCustomCallTest, |
| ::testing::ValuesIn(CustomCallTests())); |
| |
| TEST_F(HloCseTest, CustomCallCalledComputations) { |
| const char* const hlo_string = R"( |
| HloModule m |
| |
| comp { |
| lhs = f32[] parameter(0) |
| rhs = f32[] parameter(1) |
| ROOT maximum = f32[] maximum(lhs, rhs) |
| } |
| |
| ENTRY entry { |
| p0 = f32[] parameter(0) |
| |
| op0 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp} |
| op1 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp, comp} |
| ROOT root = tuple(op0, op1) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); |
| |
| SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString())); |
| EXPECT_EQ(changed, false); |
| } |
| |
| TEST_F(HloCseTest, CustomCallSideEffects) { |
| const char* const hlo_string = R"( |
| HloModule m |
| |
| ENTRY entry { |
| p0 = f32[] parameter(0) |
| |
| op0 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true |
| op1 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true |
| ROOT root = tuple(op0, op1) |
| } |
| )"; |
| |
| TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); |
| HloCSE cse(/*is_layout_sensitive=*/false); |
| TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get())); |
| |
| SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString())); |
| EXPECT_EQ(changed, false); |
| } |
| |
| } // namespace |
| } // namespace xla |