[XLA GPU] [NFC] Use the same function to check that two reductions are fusible

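Previously, ir_emitter_unnested.cc checked that the outputs of a
multi-output reduction fusion are consistent, while gpu_fusible.cc
duplicated a similar shape and dimension check when deciding whether two
reductions are fusible. Move AreFusedReductionOutputsConsistent out of the
anonymous namespace in ir_emitter_unnested.cc into ir_emission_utils,
change its return type from Status to bool, and call it from both places so
that the fusion decision and code emission apply the same criteria.
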
PiperOrigin-RevId: 283365847
Change-Id: Ifde925a68fce329b812fd34acaf0925e2c28b6c6
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 599eef4..2473868 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -154,11 +154,9 @@
   // operand shape) and the reduction dimensions need to match.
   auto* instr_1 = get_real_hero(&instr1);
   auto* instr_2 = get_real_hero(&instr2);
-  // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
   if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
       IsReductionFromOrToContiguousDimensions(*instr_2) &&
-      (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) ||
-       instr_1->dimensions() != instr_2->dimensions())) {
+      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
     return false;
   }
   // The elementwise output shapes must be the same (including layout).
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 26a6deb..72f69ca 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -405,5 +405,36 @@
           EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)));
 }
 
+bool AreFusedReductionOutputsConsistent(
+    absl::Span<const HloInstruction* const> output_instructions,
+    const HloInstruction* first_reduce) {
+  for (const HloInstruction* inst : output_instructions) {
+    if (IsReductionFromOrToContiguousDimensions(*inst)) {
+      // Shapes, layouts, and dimensions must be the same for all reduces
+      // inside this fusion.
+      // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
+      if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
+            ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+                             inst->operand(0)->shape()) &&
+            ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+                             inst->operand(1)->shape()) &&
+            first_reduce->dimensions() == inst->dimensions())) {
+        return false;
+      }
+    } else {
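+      // For extra outputs we can relax shape equality to allow different
+      // types (with the same number of elements). Layouts still have to
+      // match.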
+      if (!(ShapeUtil::CompatibleIgnoringElementType(
+                first_reduce->operand(0)->shape(), inst->shape()) &&
+            LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+                              inst->shape().layout()))) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index f269cf8..db3cd22 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -200,6 +200,14 @@
 // block 0 of the kernel.
 llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
 
+// Returns whether the outputs of a fusion with reduction are consistent:
+// all reduces must have the same output shape, operand shapes, and reduce
+// dimensions, and every non-reduce output must match the shape of the first
+// reduce's input operand (ignoring element type) as well as its layout.
+bool AreFusedReductionOutputsConsistent(
+    absl::Span<const HloInstruction* const> output_instructions,
+    const HloInstruction* first_reduce);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 06a00d2..dbc2c95 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2779,32 +2779,6 @@
 }
 
 namespace {
-// Checks that the outputs of a fusion with reduction are consistent.
-Status AreFusedReductionOutputsConsistent(
-    absl::Span<HloInstruction* const> output_instructions,
-    const HloInstruction* first_reduce) {
-  for (const HloInstruction* inst : output_instructions) {
-    if (IsReductionFromOrToContiguousDimensions(*inst)) {
-      // Shapes, layouts and dimensions must be the same for all reduces
-      // inside of this fusion.
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
-                                    inst->operand(0)->shape()));
-      TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
-                                    inst->operand(1)->shape()));
-      TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
-    } else {
-      // For extra outputs we can relax shape equality to allow different
-      // types (with the same number of elements). Layouts still have to
-      // match.
-      TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
-          first_reduce->operand(0)->shape(), inst->shape()));
-      TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
-                                     inst->shape().layout()));
-    }
-  }
-  return Status::OK();
-}
 
 // Returns true if all the transitive users of hlo before hitting users in
 // use_chain_endings are elementwise operations.
@@ -2994,8 +2968,12 @@
 
   const HloInstruction* first_reduce = reduce_instructions.at(0);
   if (output_instructions.size() > 1) {
-    TF_RETURN_IF_ERROR(
-        AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
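+    // gpu_fusible.cc applies the same predicate when forming multi-output
+    // fusions, so an inconsistency at this point is an internal error.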
+    if (!AreFusedReductionOutputsConsistent(output_instructions,
+                                            first_reduce)) {
+      return InternalError("Inconsistent reduction fusion outputs");
+    }
   }
 
   // Build a kernel thunk to compute all the outputs.