[mhlo] Remove the tuple when inlining a variadic reduce region
In proto HLO the variadic reducer returns a tuple, but mhlo.reduce doesn't
accept it. Flatten the tuple when importing.
PiperOrigin-RevId: 426433281
Change-Id: I45adaa67bbf6325e0ebcb554b81f282ea8cc17dc
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index bb8749b..c84c791 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -999,8 +999,9 @@
llvm::makeArrayRef(operands).take_front(num_inputs),
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
- TF_RETURN_IF_ERROR(
- ImportAsRegion(*instruction->to_apply(), &reduce.body()));
+ TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
+ &reduce.body(),
+ /*flatten_region_arg_tuple=*/true));
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index b54796e..81f312b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -844,8 +844,9 @@
// CHECK: mhlo.reduce([[ARG0]] init: [[ARG2]]), ([[ARG0]] init: [[ARG2]])
// CHECK-SAME: dimensions = [0, 1]
- // CHECK: mhlo.add{{.*}} : tensor<f32>
- // CHECK: mhlo.add{{.*}} : tensor<f32>
+ // CHECK: %[[A:.*]] = mhlo.add{{.*}} : tensor<f32>
+ // CHECK: %[[B:.*]] = mhlo.add{{.*}} : tensor<f32>
+ // CHECK: "mhlo.return"(%[[A]], %[[B]]) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.tuple"(%0#0, %0#1) {xla_shape = {{.*}}} : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1