[mhlo] Flatten tuple-return from mhlo::ReduceScatterOp's compulation-block during import.

During import (from HLO to MHLO) we flatten the tuple return-type in the
imported region-blocks. MHLO ReduceScatterOp::verifier ensures that the
flattened return-type is comaptible with the op-specification.

PiperOrigin-RevId: 448584018
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index a25efb9..c147460 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -1151,7 +1151,8 @@
           func_builder->create<mlir::mhlo::ReduceScatterOp>(
               loc, result_type, operands, attributes);
       TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
-                                        &reduce_scatter_op.computation()));
+                                        &reduce_scatter_op.computation(),
+                                        /*flatten_region_arg_tuple=*/true));
 
       return reduce_scatter_op.getOperation();
     }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 885b453..731da33 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -895,6 +895,13 @@
   ROOT add = f32[] add(lhs, rhs)
 }
 
+%reduce_helper_add_returning_tuple {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  add = f32[] add(lhs, rhs)
+  ROOT tple = (f32[]) tuple(add)
+}
+
 // CHECK-LABEL:  func private @test_reduce_scatter
 // CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
 %test_reduce_scatter {
@@ -910,6 +917,16 @@
   ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add
 }
 
+// CHECK-LABEL:  func private @test_reduce_scatter_with_region_returning_tuple
+%test_reduce_scatter_with_region_returning_tuple {
+  input = f32[4,8] parameter(0)
+  // CHECK-NEXT: "mhlo.reduce_scatter"
+  // CHECK-NEXT:   ^bb0
+  // CHECK-NEXT:     [[ADD:%.*]] = mhlo.add
+  // CHECK-NEXT:     "mhlo.return"([[ADD]])
+  ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add_returning_tuple
+}
+
 // CHECK-LABEL:  func private @test_reduce_scatter_with_channel
 // CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
 %test_reduce_scatter_with_channel {