| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/conditional_simplifier.h" |
| |
| #include <string> |
| #include <utility> |
| |
| #include "tensorflow/compiler/xla/literal_util.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_matchers.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/test.h" |
| #include "tensorflow/compiler/xla/tests/hlo_test_base.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| namespace xla { |
| namespace { |
| |
| namespace op = xla::testing::opcode_matchers; |
| |
| class ConditionalSimplifierTest : public HloTestBase { |
| public: |
| // Makes a computation that contains a conditional with constant predicate. |
| HloComputation* MakeConditional(HloModule* module, bool is_constant = true); |
| }; |
| |
| HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module, |
| bool is_constant) { |
| HloComputation::Builder builder(TestName()); |
| |
| // true_computation returns param+1. |
| HloComputation* true_computation; |
| { |
| HloComputation::Builder true_computation_builder(TestName() + |
| ".true_computation"); |
| auto param = |
| true_computation_builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, ShapeUtil::MakeShape(S32, {}), "param")); |
| auto one = true_computation_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| |
| true_computation_builder.AddInstruction(HloInstruction::CreateBinary( |
| ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); |
| |
| true_computation = |
| module->AddEmbeddedComputation(true_computation_builder.Build()); |
| } |
| |
| // false_computation returns param+42. |
| HloComputation* false_computation; |
| { |
| HloComputation::Builder false_computation_builder(TestName() + |
| ".false_computation"); |
| auto param = false_computation_builder.AddInstruction( |
| HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), |
| "param")); |
| auto forty_two = false_computation_builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))); |
| |
| false_computation_builder.AddInstruction(HloInstruction::CreateBinary( |
| ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); |
| false_computation = |
| module->AddEmbeddedComputation(false_computation_builder.Build()); |
| } |
| |
| auto false_instrn = builder.AddInstruction( |
| is_constant |
| ? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)) |
| : HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}), |
| "cond")); |
| auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( |
| 0, ShapeUtil::MakeShape(S32, {}), "false_param")); |
| auto one = builder.AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); |
| |
| builder.AddInstruction(HloInstruction::CreateConditional( |
| ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, |
| false_param, false_computation)); |
| |
| return module->AddEntryComputation(builder.Build()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get()); |
| ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| EXPECT_THAT(computation->root_instruction(), |
| op::Add(op::Parameter(), op::Constant())); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, BranchGetsInlined) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get(), /*is_constant=*/false); |
| ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| EXPECT_THAT( |
| computation->root_instruction(), |
| op::Select(op::Parameter(1), op::Add(op::Constant(), op::Constant()), |
| op::Add(op::Parameter(0), op::Constant()))); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get()); |
| |
| auto* true_op = computation->AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))); |
| TF_ASSERT_OK( |
| true_op->AddControlDependencyTo(computation->root_instruction())); |
| |
| EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get()); |
| auto* conditional = computation->root_instruction(); |
| ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); |
| |
| auto* true_computation = conditional->true_computation(); |
| auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); |
| auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( |
| true_computation->AddInstruction( |
| HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))), |
| token, /*channel_id=*/0)); |
| true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); |
| EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get()); |
| auto* conditional = computation->root_instruction(); |
| ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); |
| |
| auto* true_computation = conditional->true_computation(); |
| auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); |
| auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( |
| ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); |
| true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); |
| EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { |
| auto m = CreateNewVerifiedModule(); |
| HloComputation* computation = MakeConditional(m.get()); |
| auto* conditional = computation->root_instruction(); |
| ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); |
| auto* false_computation = conditional->false_computation(); |
| auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); |
| false_computation->AddInstruction(HloInstruction::CreateInfeed( |
| ShapeUtil::MakeShape(F32, {1}), token, "config")); |
| EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, TrivalOperandsRemoved) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule UnusedTupleOperands |
| on_false { |
| t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0) |
| lhs = f32[20,40] get-tuple-element(t), index=0 |
| rhs = f32[40,40] get-tuple-element(t), index=1 |
| dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| ROOT result = (f32[20,40]) tuple(dot) |
| } |
| |
| on_true { |
| t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0) |
| lhs = f32[20,40] get-tuple-element(t), index=2 |
| rhs = f32[40,40] get-tuple-element(t), index=3 |
| dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| ROOT result = (f32[20,40]) tuple(dot) |
| } |
| |
| ENTRY main { |
| c0_0 = f32[20,40] parameter(0) |
| c0_1 = f32[40,40] parameter(1) |
| c1_0 = f32[20,40] parameter(2) |
| c1_1 = f32[40,40] parameter(3) |
| p = pred[] parameter(4) |
| t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) tuple(c0_0, c0_1, c1_0, c1_1) |
| ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true |
| } |
| )"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_TRUE( |
| ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_EQ(status.ValueOrDie() |
| ->entry_computation() |
| ->root_instruction() |
| ->operand(1) |
| ->shape() |
| .tuple_shapes() |
| .size(), |
| 2); |
| EXPECT_EQ(status.ValueOrDie() |
| ->entry_computation() |
| ->root_instruction() |
| ->operand(2) |
| ->shape() |
| .tuple_shapes() |
| .size(), |
| 2); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, |
| TwoConditionalsCreatedInReversedLexicalOrder) { |
| absl::string_view hlo_string = R"( |
| HloModule DeadConditional |
| computation.1 { |
| param.1 = s64[] parameter(0) |
| constant.1 = s64[] constant(1) |
| ROOT add.1 = s64[] add(param.1, constant.1) |
| } |
| |
| computation.2 { |
| param.2 = s64[] parameter(0) |
| constant.2 = s64[] constant(2) |
| ROOT add.2 = s64[] add(param.2, constant.2) |
| } |
| |
| computation.3 { |
| param.3 = s64[] parameter(0) |
| constant.3 = s64[] constant(3) |
| ROOT add.3 = s64[] add(param.3, constant.3) |
| } |
| |
| computation.4 { |
| param.4 = s64[] parameter(0) |
| constant.4 = s64[] constant(4) |
| ROOT add.4 = s64[] add(param.4, constant.4) |
| } |
| |
| ENTRY KernelEntry { |
| param.1 = s64[] parameter(0) |
| param.2 = s64[] parameter(1) |
| param.3 = s64[] parameter(2) |
| param.4 = pred[] parameter(3) |
| |
| conditional_1 = s64[] conditional(param.4, param.3, param.2), |
| true_computation=computation.3, false_computation=computation.4 |
| constant.1 = pred[] constant(false) |
| ROOT conditional_2 = s64[] conditional(constant.1, conditional_1, |
| param.1), true_computation=computation.1, |
| false_computation=computation.2 |
| })"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| std::unique_ptr<HloModule> module = status.ConsumeValueOrDie(); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(module.get()).status()); |
| |
| // Replace conditional_1 with a clone that is created after conditional_2. |
| HloInstruction* conditional_1 = |
| FindInstruction(module.get(), "conditional_1"); |
| HloInstruction* conditional_1_clone = |
| conditional_1->parent()->AddInstruction(conditional_1->Clone()); |
| TF_ASSERT_OK(conditional_1->ReplaceAllUsesWith(conditional_1_clone)); |
| TF_ASSERT_OK(conditional_1->parent()->RemoveInstruction(conditional_1)); |
| |
| EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie()); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule RemoveDeadRoots |
| on_false { |
| t = (f32[20,40], f32[40,40]) parameter(0) |
| lhs = f32[20,40] get-tuple-element(t), index=0 |
| rhs = f32[40,40] get-tuple-element(t), index=1 |
| dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} |
| after-all = token[] after-all() |
| outfeed = token[] outfeed(dot, after-all) |
| ROOT result = (f32[20,40]) tuple(dot) |
| } |
| |
| on_true { |
| t = (f32[20,40], f32[40,40]) parameter(0) |
| lhs = f32[20,40] get-tuple-element(t), index=0 |
| add = f32[20,40] add(lhs, lhs) |
| ROOT result = (f32[20,40]) tuple(add) |
| } |
| |
| ENTRY main { |
| c0_0 = f32[20,40] parameter(0) |
| c0_1 = f32[40,40] parameter(1) |
| p = pred[] parameter(2) |
| t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1) |
| conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true |
| ROOT result = () tuple() |
| } |
| )"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_TRUE( |
| ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| HloInstruction* conditional = |
| FindInstruction(status.ValueOrDie().get(), "conditional"); |
| // The conditional root should be replaced with an empty tuple. |
| EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, SecondTupleElementUnusedAndRemoved) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule SecondTupleElementUnusedAndRemoved |
| |
| on_true { |
| arg_tuple.7 = (f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0 |
| copy = f32[10,10]{1,0} copy(get-tuple-element.9) |
| ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9) |
| } |
| |
| on_false { |
| constant.17 = f32[] constant(0) |
| constant.18 = f32[] constant(1) |
| rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform |
| arg_tuple.14 = (f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0 |
| ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16) |
| } |
| |
| ENTRY main { |
| constant.38 = pred[] constant(true) |
| arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1 |
| tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21) |
| conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false |
| get-first-index = f32[10,10]{1,0} get-tuple-element(conditional), index=0 |
| ROOT result = (f32[10,10]{1,0}) tuple(get-first-index) |
| } |
| )"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_TRUE( |
| ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| const HloInstruction* conditional = |
| FindInstruction(status.ValueOrDie().get(), "conditional"); |
| // The second element of "conditional" result tuple (f32[10,10], f32[10,10]) |
| // should be removed since it is not referenced by any GTE instructions |
| // (see "get-first-index" instruction in hlo_string). |
| EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1); |
| } |
| |
| TEST_F(ConditionalSimplifierTest, FirstTupleElementUnusedAndRemoved) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule FirstTupleElementUnusedAndRemoved |
| |
| on_true { |
| arg_tuple.7 = (f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.9 = f32[10,10]{1,0} get-tuple-element(arg_tuple.7), index=0 |
| copy = f32[10,10]{1,0} copy(get-tuple-element.9) |
| ROOT tuple.6 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(copy, get-tuple-element.9) |
| } |
| |
| on_false { |
| constant.17 = f32[] constant(0) |
| constant.18 = f32[] constant(1) |
| rng.19 = f32[10,10]{1,0} rng(constant.17, constant.18), distribution=rng_uniform |
| arg_tuple.14 = (f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.16 = f32[10,10]{1,0} get-tuple-element(arg_tuple.14), index=0 |
| ROOT tuple.7 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(rng.19, get-tuple-element.16) |
| } |
| |
| ENTRY main { |
| constant.38 = pred[] constant(true) |
| arg_tuple.30 = (s32[], f32[10,10]{1,0}) parameter(0) |
| get-tuple-element.21 = f32[10,10]{1,0} get-tuple-element(arg_tuple.30), index=1 |
| tuple.1 = (f32[10,10]{1,0}) tuple(get-tuple-element.21) |
| conditional = (f32[10,10]{1,0}, f32[10,10]{1,0}) conditional(constant.38, tuple.1, tuple.1), true_computation=on_true, false_computation=on_false |
| get-second-index = f32[10,10]{1,0} get-tuple-element(conditional), index=1 |
| ROOT result = (f32[10,10]{1,0}) tuple(get-second-index) |
| } |
| )"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_TRUE( |
| ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| const HloInstruction* conditional = |
| FindInstruction(status.ValueOrDie().get(), "conditional"); |
| // The first element of "conditional" result tuple (f32[10,10], f32[10,10]) |
| // should be removed since it is not referenced by any GTE instructions (see |
| // "get-second-index" instruction in hlo_string). |
| EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1); |
| } |
| |
// Before:
//      gte           rng
//     |   |         |   |
//    on_true       on_false
//   (f32, f32)    (f32, f32)
//        |            |
//         \          /
//         conditional
//          (f32, f32)
//
// After:
//      gte           rng
//       |             |
//    on_true       on_false
//     (f32)         (f32)
//        |            |
//         \          /
//         conditional
//            (f32)
//
// Each branch returns the same value in both tuple elements, so the two
// elements should be merged into one. The 'rng' in on_false adds a side effect
// so that the conditional is not simplified away and replaced with a 'select'
// instruction by TryRemoveConditional.
| TEST_F(ConditionalSimplifierTest, MergeDuplicateTupleElements) { |
| absl::string_view hlo_string = |
| R"( |
| HloModule MergeDuplicateTupleElements |
| |
| on_true { |
| param-true = (f32[]) parameter(0) |
| gte-true = f32[] get-tuple-element(param-true), index=0 |
| ROOT tuple-true = (f32[], f32[]) tuple(gte-true, gte-true) |
| } |
| |
| on_false { |
| param-false = (f32[]) parameter(0) |
| constant.0 = f32[] constant(0) |
| constant.1 = f32[] constant(1) |
| rng = f32[] rng(constant.0, constant.1), distribution=rng_uniform |
| ROOT tuple-false = (f32[], f32[]) tuple(rng, rng) |
| } |
| |
| ENTRY main { |
| comp = pred[] parameter(0) |
| arg = (f32[]) parameter(1) |
| conditional = (f32[], f32[]) conditional(comp, arg, arg), true_computation=on_true, false_computation=on_false |
| gte.0 = f32[] get-tuple-element(conditional), index=0 |
| gte.1 = f32[] get-tuple-element(conditional), index=1 |
| ROOT add = f32[] add(gte.0, gte.1) |
| } |
| )"; |
| auto status = ParseAndReturnUnverifiedModule(hlo_string); |
| TF_ASSERT_OK(status.status()); |
| HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| EXPECT_TRUE( |
| ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); |
| TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); |
| const HloInstruction* conditional = |
| FindInstruction(status.ValueOrDie().get(), "conditional"); |
| EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 1); |
| const HloInstruction* gte_0 = |
| FindInstruction(status.ValueOrDie().get(), "gte.0"); |
| const HloInstruction* gte_1 = |
| FindInstruction(status.ValueOrDie().get(), "gte.1"); |
| EXPECT_EQ(gte_0->tuple_index(), 0); |
| EXPECT_EQ(gte_1->tuple_index(), 0); |
| } |
| |
| } // namespace |
| |
| } // namespace xla |