[XLA] Support hoisting copy instructions in conditional code motion.

PiperOrigin-RevId: 328883064
Change-Id: I9f34fe6633a79c2ada618a0b5888d45ab05c3848
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc
index ce80b4c..6f180ac 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc
@@ -114,6 +114,8 @@
     case HloOpcode::kConstant:
     case HloOpcode::kGetTupleElement:
       return 0;
+    case HloOpcode::kConditional:
+      return 10;
     default:
       // Assume fusion will not happen anyway if user count > 1)
       if (op->user_count() > 1) {
@@ -587,6 +589,15 @@
     VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
     computation->set_root_instruction(new_root,
                                       /*accept_different_shape*/ true);
+    // Decrement the index of the conditional's GTE users above use_index.
+    if (use_index != -1) {
+      for (auto* user : conditional->users()) {
+        if (user->opcode() == HloOpcode::kGetTupleElement &&
+            user->tuple_index() > use_index) {
+          user->set_tuple_index(user->tuple_index() - 1);
+        }
+      }
+    }
     if (old_root->opcode() == HloOpcode::kTuple) {
       TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
     }
@@ -677,7 +688,7 @@
       : conditional_(conditional),
         conditional_parent_(conditional->parent()),
         is_layout_sensitive_(is_layout_sensitive) {}
-  // Returns true if `instruction` is worth hoisting out.
+  // Returns true if `instruction` is worth hoisting.
   bool WorthHoisting(HloInstruction* instruction) {
     // This is needed for the "moving-in" transformation, to prevent the root
     // of the parent computation (which contains the conditional) to be moved
@@ -708,6 +719,7 @@
       case HloOpcode::kAllReduce:
       case HloOpcode::kAdd:
       case HloOpcode::kPower:
+      case HloOpcode::kCopy:
       case HloOpcode::kConstant:
       case HloOpcode::kSubtract:
       case HloOpcode::kMultiply:
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
index b91f381..a4cb598 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
@@ -728,6 +728,59 @@
   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
 }
 
+TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) {  // branch2 below holds two copy ops
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch1 {
+  arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0)
+  constant.1 = s32[] constant(4)
+  get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0
+  add.1 = s32[] add(get-tuple-element.1, constant.1)
+  get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1
+  slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2),
+   slice={[0:4:1], [0:3:1]}
+  ROOT tuple.1 = (s32[],f32[4,3]{0,1}) tuple(add.1, slice.1)
+}
+
+branch2 {
+  arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0)
+  get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0
+  copy.1 = s32[] copy(get-tuple-element.3)
+  get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1
+  copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4)
+  ROOT tuple.2 = (s32[],f32[4,3]{0,1}) tuple(copy.1, copy.2)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1)
+  tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2)
+  conditional = (s32[],f32[4,3]{0,1})
+    conditional(pred.1, tuple.3, tuple.4), true_computation=branch1,
+    false_computation=branch2
+  get-first-index = f32[4,3]{0,1} get-tuple-element(conditional), index=1
+  get-zero-index = s32[] get-tuple-element(conditional), index=0
+  copy.3 = f32[4,3]{1,0} copy(get-first-index)
+  ROOT tuple.5 = (s32[], f32[4,3]{0,1}) tuple(get-zero-index, copy.3)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());  // presumably true means the module changed -- confirm
+  VLOG(1) << module->ToString();
+
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);  // presumably branch1 (true_computation)
+  ASSERT_EQ(on_true->instruction_count(), 8);  // branch1 is written with 7 instructions above
+  const HloComputation* on_false = conditional->branch_computation(1);  // presumably branch2 (false_computation)
+  ASSERT_EQ(on_false->instruction_count(), 7);  // branch2 is written with 6 instructions above
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Conditional()));  // input root was tuple.5; it must now match a conditional
+}
+
 }  // namespace conditional_opt
 
 }  // namespace xla