[XLA GPU] [NFC] Use the same function to check that two reductions are fusible
PiperOrigin-RevId: 283365847
Change-Id: Ifde925a68fce329b812fd34acaf0925e2c28b6c6
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 599eef4..2473868 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -154,11 +154,9 @@
// operand shape) and the reduction dimensions need to match.
auto* instr_1 = get_real_hero(&instr1);
auto* instr_2 = get_real_hero(&instr2);
- // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
IsReductionFromOrToContiguousDimensions(*instr_2) &&
- (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) ||
- instr_1->dimensions() != instr_2->dimensions())) {
+ !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
return false;
}
// The elementwise output shapes must be the same (including layout).
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 26a6deb..72f69ca 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -405,5 +405,33 @@
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)));
}
+bool AreFusedReductionOutputsConsistent(  // True iff no output fails its check below.
+ absl::Span<const HloInstruction* const> output_instructions,
+ const HloInstruction* first_reduce) {  // Reference that every output is compared to.
+ for (const HloInstruction* inst : output_instructions) {
+ if (IsReductionFromOrToContiguousDimensions(*inst)) {
+ // Reduce output: its result shape, the shapes of operands 0 and 1, and its
+ // reduction dimensions must all equal those of `first_reduce`.
+ // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
+ if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
+ ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+ inst->operand(0)->shape()) &&
+ ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+ inst->operand(1)->shape()) &&
+ first_reduce->dimensions() == inst->dimensions())) {
+ return false;  // First mismatch decides; later outputs are not examined.
+ }
+ } else {  // Any other output: element type is ignored, but the layout must match.
+ if (!(ShapeUtil::CompatibleIgnoringElementType(
+ first_reduce->operand(0)->shape(), inst->shape()) &&  // Compared with the reduce's operand 0, not its result.
+ LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+ inst->shape().layout()))) {
+ return false;
+ }
+ }
+ }
+ return true;  // Also reached when `output_instructions` is empty.
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index f269cf8..db3cd22 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -200,6 +200,11 @@
// block 0 of the kernel.
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
+// Returns true iff every instruction in `output_instructions` is consistent with `first_reduce`.
+bool AreFusedReductionOutputsConsistent(
+ absl::Span<const HloInstruction* const> output_instructions,  // Fusion outputs to check.
+ const HloInstruction* first_reduce);  // Reduction the outputs are compared against.
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 06a00d2..dbc2c95 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2779,32 +2779,6 @@
}
namespace {
-// Checks that the outputs of a fusion with reduction are consistent.
-Status AreFusedReductionOutputsConsistent(
- absl::Span<HloInstruction* const> output_instructions,
- const HloInstruction* first_reduce) {
- for (const HloInstruction* inst : output_instructions) {
- if (IsReductionFromOrToContiguousDimensions(*inst)) {
- // Shapes, layouts and dimensions must be the same for all reduces
- // inside of this fusion.
- TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
- TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
- inst->operand(0)->shape()));
- TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
- inst->operand(1)->shape()));
- TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
- } else {
- // For extra outputs we can relax shape equality to allow different
- // types (with the same number of elements). Layouts still have to
- // match.
- TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
- first_reduce->operand(0)->shape(), inst->shape()));
- TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
- inst->shape().layout()));
- }
- }
- return Status::OK();
-}
// Returns true if all the transitive users of hlo before hitting users in
// use_chain_endings are elementwise operations.
@@ -2994,8 +2968,10 @@
const HloInstruction* first_reduce = reduce_instructions.at(0);
if (output_instructions.size() > 1) {
- TF_RETURN_IF_ERROR(
- AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
+ if (!AreFusedReductionOutputsConsistent(output_instructions,
+ first_reduce)) {
+ return InternalError("Inconsistent reduction fusion outputs");
+ }
}
// Build a kernel thunk to compute all the outputs.