[mhlo] Remove the tuple when inlining a variadic reduce region

In proto HLO the variadic reducer returns a tuple, but mhlo.reduce doesn't
accept it. Flatten the tuple when importing.

PiperOrigin-RevId: 426433281
Change-Id: I45adaa67bbf6325e0ebcb554b81f282ea8cc17dc
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index bb8749b..c84c791 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -999,8 +999,9 @@
           llvm::makeArrayRef(operands).take_front(num_inputs),
           llvm::makeArrayRef(operands).drop_front(num_inputs),
           ConvertDimensions(instruction->dimensions()));
-      TF_RETURN_IF_ERROR(
-          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
+      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
+                                        &reduce.body(),
+                                        /*flatten_region_arg_tuple=*/true));
 
       // Check if the output needs to be tupled.
       if (return_types.size() == 1 && return_types.front() == result_type) {
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index b54796e..81f312b 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -844,8 +844,9 @@
 
   // CHECK:  mhlo.reduce([[ARG0]] init: [[ARG2]]), ([[ARG0]] init: [[ARG2]])
   // CHECK-SAME:  dimensions = [0, 1]
-  // CHECK:  mhlo.add{{.*}} : tensor<f32>
-  // CHECK:  mhlo.add{{.*}} : tensor<f32>
+  // CHECK:  %[[A:.*]] = mhlo.add{{.*}} : tensor<f32>
+  // CHECK:  %[[B:.*]] = mhlo.add{{.*}} : tensor<f32>
+  // CHECK: "mhlo.return"(%[[A]], %[[B]]) : (tensor<f32>, tensor<f32>) -> ()
   // CHECK: "mhlo.tuple"(%0#0, %0#1) {xla_shape = {{.*}}} : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
   %reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1