[mhlo] Flatten tuple-return from mhlo::ReduceScatterOp's compulation-block during import.
During import (from HLO to MHLO) we flatten the tuple return-type in the
imported region-blocks. MHLO ReduceScatterOp::verifier ensures that the
flattened return-type is comaptible with the op-specification.
PiperOrigin-RevId: 448584018
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index a25efb9..c147460 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -1151,7 +1151,8 @@
func_builder->create<mlir::mhlo::ReduceScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
- &reduce_scatter_op.computation()));
+ &reduce_scatter_op.computation(),
+ /*flatten_region_arg_tuple=*/true));
return reduce_scatter_op.getOperation();
}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 885b453..731da33 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -895,6 +895,13 @@
ROOT add = f32[] add(lhs, rhs)
}
+%reduce_helper_add_returning_tuple {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ add = f32[] add(lhs, rhs)
+ ROOT tple = (f32[]) tuple(add)
+}
+
// CHECK-LABEL: func private @test_reduce_scatter
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
%test_reduce_scatter {
@@ -910,6 +917,16 @@
ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add
}
+// CHECK-LABEL: func private @test_reduce_scatter_with_region_returning_tuple
+%test_reduce_scatter_with_region_returning_tuple {
+ input = f32[4,8] parameter(0)
+ // CHECK-NEXT: "mhlo.reduce_scatter"
+ // CHECK-NEXT: ^bb0
+ // CHECK-NEXT: [[ADD:%.*]] = mhlo.add
+ // CHECK-NEXT: "mhlo.return"([[ADD]])
+ ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add_returning_tuple
+}
+
// CHECK-LABEL: func private @test_reduce_scatter_with_channel
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
%test_reduce_scatter_with_channel {