[XLA] InstructionFusion: Don't fuse across a root instruction

There may be instructions after the root, but we only need the result of the
root. Don't fuse it with anything that comes below as that would change which
value is returned.

PiperOrigin-RevId: 403381025
Change-Id: If534c77fa5ff236f8fb64c892a6af22c2a4ec362
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index df22b16..4a26851 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -840,6 +840,12 @@
                                    int64_t operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
 
+  // Don't fuse across a root instruction.
+  if (producer == producer->parent()->root_instruction()) {
+    VLOG(4) << "Not fusing into the output of the root instruction";
+    return false;
+  }
+
   // Cost condition: don't duplicate expensive instructions.
   if (FusionWouldDuplicate(*producer, *consumer) &&
       (!may_duplicate_ || is_expensive_(*producer)) &&
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index f0b7690..c7d0524 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -567,4 +567,26 @@
   EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter()));
 }
 
+TEST_F(InstructionFusionTest, DontFuseAcrossRoot) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  HloModule test_module
+  ENTRY entry_computation {
+    p0 = f32[4,3]{1,0} parameter(0)
+    mul = f32[4,3]{1,0} multiply(p0, p0)
+    ROOT add = f32[4,3]{1,0} add(mul, p0)
+    sub = f32[4,3]{1,0} subtract(p0, add)
+  })")
+                    .ValueOrDie();
+  EXPECT_TRUE(
+      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false)
+          .Run(module.get())
+          .ValueOrDie())
+      << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion(op::Parameter()));
+  EXPECT_THAT(
+      root->fused_expression_root(),
+      op::Add(op::Multiply(op::Parameter(), op::Parameter()), op::Parameter()));
+}
+
 }  // namespace xla