[XLA] Support hoisting copy instructions in conditional code motion.
PiperOrigin-RevId: 328883064
Change-Id: I9f34fe6633a79c2ada618a0b5888d45ab05c3848
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc
index ce80b4c..6f180ac 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc
@@ -114,6 +114,8 @@
case HloOpcode::kConstant:
case HloOpcode::kGetTupleElement:
return 0;
+ case HloOpcode::kConditional:
+ return 10;
default:
// Assume fusion will not happen anyway if user count > 1)
if (op->user_count() > 1) {
@@ -587,6 +589,15 @@
VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
computation->set_root_instruction(new_root,
/*accept_different_shape*/ true);
+ // Decrement the index of the conditional's get-tuple-element users above use_index.
+ if (use_index != -1) {
+ for (auto* user : conditional->users()) {
+ if (user->opcode() == HloOpcode::kGetTupleElement &&
+ user->tuple_index() > use_index) {
+ user->set_tuple_index(user->tuple_index() - 1);
+ }
+ }
+ }
if (old_root->opcode() == HloOpcode::kTuple) {
TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
}
@@ -677,7 +688,7 @@
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive) {}
- // Returns true if `instruction` is worth hoisting out.
+ // Returns true if `instruction` is worth hoisting.
bool WorthHoisting(HloInstruction* instruction) {
// This is needed for the "moving-in" transformation, to prevent the root
// of the parent computation (which contains the conditional) to be moved
@@ -708,6 +719,7 @@
case HloOpcode::kAllReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
+ case HloOpcode::kCopy:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
index b91f381..a4cb598 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
@@ -728,6 +728,59 @@
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
+TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) {
+ absl::string_view hlo_string =
+ R"(
+HloModule RemoveIdenticalInstruction
+
+branch1 {
+ arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0)
+ constant.1 = s32[] constant(4)
+ get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0
+ add.1 = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1
+ slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2),
+ slice={[0:4:1], [0:3:1]}
+ ROOT tuple.1 = (s32[],f32[4,3]{0,1}) tuple(add.1, slice.1)
+}
+
+branch2 {
+ arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0
+ copy.1 = s32[] copy(get-tuple-element.3)
+ get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1
+ copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4)
+ ROOT tuple.2 = (s32[],f32[4,3]{0,1}) tuple(copy.1, copy.2)
+}
+
+ENTRY main {
+ pred.1 = pred[] parameter(0)
+ tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1)
+ tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2)
+ conditional = (s32[],f32[4,3]{0,1})
+ conditional(pred.1, tuple.3, tuple.4), true_computation=branch1,
+ false_computation=branch2
+ get-first-index = f32[4,3]{0,1} get-tuple-element(conditional), index=1
+ get-zero-index = s32[] get-tuple-element(conditional), index=0
+ copy.3 = f32[4,3]{1,0} copy(get-first-index)
+ ROOT tuple.5 = (s32[], f32[4,3]{0,1}) tuple(get-zero-index, copy.3)
+}
+)";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+ ConditionalCodeMotion pass(true, true);
+ ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+ VLOG(1) << module->ToString();
+
+ const HloInstruction* conditional =
+ FindInstruction(module.get(), "conditional");
+ const HloComputation* on_true = conditional->branch_computation(0);
+ ASSERT_EQ(on_true->instruction_count(), 8);
+ const HloComputation* on_false = conditional->branch_computation(1);
+ ASSERT_EQ(on_false->instruction_count(), 7);
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, AllOf(op::Conditional()));
+}
+
} // namespace conditional_opt
} // namespace xla