Expose use_global_device_ids as an attribute in MHLO AllReduceOp
This time, also add a backwards-compatible custom builder method so that existing call sites that do not pass the new attribute keep compiling.
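With this change the flag round-trips through MHLO as a UnitAttr, e.g. (mirroring the tests added below):

  %0 = "mhlo.all_reduce"(%arg0) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    "mhlo.return"(%max) : (tensor<f32>) -> ()
  }) {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
    channel_handle = #mhlo.channel_handle<handle = 5, type = 2>,
    use_global_device_ids
  } : (tensor<10xf32>) -> tensor<10xf32>

which exports to HLO with use_global_device_ids=true.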
PiperOrigin-RevId: 467222612
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 3f66b7e..ec8e7d7 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -991,6 +991,8 @@
if (all_reduce->channel_id().has_value())
attributes.push_back(
ConvertChannelHandle(all_reduce->channel_id().value()));
+ if (all_reduce->use_global_device_ids())
+ attributes.push_back(ConvertUseGlobalDeviceIds());
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
@@ -1709,6 +1711,11 @@
context_, channel.handle(), channel.type()));
}
+mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
+ return builder_->getNamedAttr("use_global_device_ids",
+ builder_->getUnitAttr());
+}
+
void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
const Shape& shape,
llvm::StringRef attr_name) {
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index f74b931..5e8954d 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -226,6 +226,9 @@
// Converts channel id to attribute
mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);
+ // Converts use_global_device_ids flag to attribute
+ mlir::NamedAttribute ConvertUseGlobalDeviceIds();
+
// Converts channel handle to attribute
mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index d6100cf..2cb512a 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -16,6 +16,7 @@
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include <memory>
+#include <optional>
#include <string>
#include "llvm/ADT/DenseMap.h"
@@ -178,6 +179,12 @@
return xla::ConvertNx2Attribute(padding).ValueOrDie();
}
+static std::optional<bool> Convert_use_global_device_ids(
+ llvm::Optional<bool> use_global_device_ids) {
+ if (!use_global_device_ids) return {};
+ return *use_global_device_ids;
+}
+
static std::vector<std::pair<int64_t, int64_t>> Convert_source_target_pairs(
llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
@@ -757,9 +764,10 @@
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
- value_map[op] = xla::AllReduce(operand, computation,
- Convert_replica_groups(op.replica_groups()),
- Convert_channel_handle(op.channel_handle()));
+ value_map[op] = xla::AllReduce(
+ operand, computation, Convert_replica_groups(op.replica_groups()),
+ Convert_channel_handle(op.channel_handle()), std::nullopt,
+ Convert_use_global_device_ids(op.use_global_device_ids()));
return success();
}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 3d20110..adefedd 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -134,6 +134,36 @@
// -----
// CHECK: HloModule
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+ %0 = "mhlo.all_reduce"(%arg0) ({
+ // Perform max reduction inside the region
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+ "mhlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ {
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 2
+ >,
+ use_global_device_ids
+ } : (tensor<10xf32>) -> tensor<10xf32>
+ func.return %0 : tensor<10xf32>
+}
+
+// CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = f32[10] parameter(0)
+// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
+// CHECK-SAME: channel_id=5
+// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME: use_global_device_ids=true
+// CHECK-SAME: to_apply=%[[COMPUTATION]]
+
+// -----
+
+// CHECK: HloModule
func.func @main(%arg0: tensor<10xf32>) -> tensor<5xf32> {
%0 = "mhlo.reduce_scatter"(%arg0) ({
// Perform max reduction inside the region
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index e1f19fc..9bfd995 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -84,10 +84,28 @@
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+ // CHECK-NOT: use_global_device_ids
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+ // CHECK-NOT: use_global_device_ids
+ // CHECK-SAME: :
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
}
+// CHECK-LABEL: func private @test_all_reduce_global
+// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>)
+%test_all_reduce_global {
+ input = f32[8] parameter(0)
+ // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
+ // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+ // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+ // CHECK: mhlo.return [[ADD]] : tensor<f32>
+ // CHECK: }) {
+ // CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+ // CHECK-SAME: use_global_device_ids
+ ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, use_global_device_ids=true, to_apply=add
+}
+
// CHECK-LABEL: func private @test_and
%test_and (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
index 9b97d2c..3399e3d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
@@ -139,8 +139,8 @@
ChannelHandleAttr channel_handle = ConvertChannel(builder, channel_id, mode);
Location loc = op->getLoc();
Type element_type = getElementTypeOrSelf(input.getType());
- auto all_reduce = builder.create<AllReduceOp>(loc, result_type, input,
- replica_groups, channel_handle);
+ auto all_reduce = builder.create<AllReduceOp>(
+ loc, result_type, input, replica_groups, channel_handle, nullptr);
if (merge_op == "Add") {
BuildReduceBody<AddOp>(element_type, &all_reduce.computation(), &builder);
} else if (merge_op == "Mul") {
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e9b25fc..75ba3c6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2936,7 +2936,8 @@
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
- const std::optional<Shape>& shape_with_layout) {
+ const std::optional<Shape>& shape_with_layout,
+ const std::optional<bool> use_global_device_ids) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2994,6 +2995,10 @@
instr.set_channel_id(channel_id->handle());
}
+ if (use_global_device_ids.has_value()) {
+ instr.set_use_global_device_ids(*use_global_device_ids);
+ }
+
AddCalledComputation(computation, &instr);
TF_ASSIGN_OR_RETURN(
@@ -4664,9 +4669,11 @@
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
- const std::optional<Shape>& shape_with_layout) {
+ const std::optional<Shape>& shape_with_layout,
+ const std::optional<bool> use_global_device_ids) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
- channel_id, shape_with_layout);
+ channel_id, shape_with_layout,
+ use_global_device_ids);
}
XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 0e672ee..47a37e4 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -751,10 +751,12 @@
const std::optional<Layout>& layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);
- XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
- absl::Span<const ReplicaGroup> replica_groups = {},
- const std::optional<ChannelHandle>& channel_id = std::nullopt,
- const std::optional<Shape>& shape_with_layout = std::nullopt);
+ XlaOp AllReduce(
+ XlaOp operand, const XlaComputation& computation,
+ absl::Span<const ReplicaGroup> replica_groups = {},
+ const std::optional<ChannelHandle>& channel_id = std::nullopt,
+ const std::optional<Shape>& shape_with_layout = std::nullopt,
+ const std::optional<bool> use_global_device_ids = std::nullopt);
XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation,
@@ -1367,7 +1369,8 @@
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
- const std::optional<Shape>& shape_with_layout);
+ const std::optional<Shape>& shape_with_layout,
+ const std::optional<bool> use_global_device_ids);
friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
int64_t scatter_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
@@ -2348,7 +2351,8 @@
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
- const std::optional<Shape>& shape_with_layout = std::nullopt);
+ const std::optional<Shape>& shape_with_layout = std::nullopt,
+ const std::optional<bool> use_global_device_ids = std::nullopt);
XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
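Because the new parameter defaults to std::nullopt at every layer, existing C++ callers need no changes. A minimal sketch of an opted-in call site, assuming operand, add_computation, groups, and channel are already in scope (those names are illustrative, not from this change):

  // Only the trailing argument is new; omitting it is equivalent to
  // the previous five-argument form.
  xla::XlaOp result = xla::AllReduce(
      operand, add_computation, groups,
      /*channel_id=*/channel,
      /*shape_with_layout=*/std::nullopt,
      /*use_global_device_ids=*/true);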
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index aecca02..023b827 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1045,10 +1045,18 @@
let arguments = (ins
HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups,
- OptionalAttr<ChannelHandle>:$channel_handle
+ OptionalAttr<ChannelHandle>:$channel_handle,
+ UnitAttr:$use_global_device_ids
);
let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor);
+ // use_global_device_ids is rarely used, so we add a simplified
+ // builder method for convenience.
+ let builders = [
+ OpBuilder<(ins
+ "::mlir::Type":$result_type, "::mlir::Value":$operand,
+ "::mlir::DenseIntElementsAttr":$replica_groups,
+ "::mlir::mhlo::ChannelHandleAttr":$channel_handle)>];
let hasCustomHLOConverter = 1;
}
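The custom builder keeps the previous create<> form compiling, while the generated builder gains a trailing UnitAttr parameter. A sketch of both forms, assuming loc, resultType, input, replicaGroups, and channelHandle are in scope:

  // Old form: the custom builder (defined below in hlo_ops.cc)
  // forwards a null use_global_device_ids attribute.
  auto allReduce = builder.create<mhlo::AllReduceOp>(
      loc, resultType, input, replicaGroups, channelHandle);

  // New form: set the unit attribute explicitly.
  auto allReduceGlobal = builder.create<mhlo::AllReduceOp>(
      loc, resultType, input, replicaGroups, channelHandle,
      builder.getUnitAttr());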
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index d0145cd..6a246e9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -617,6 +617,19 @@
} // namespace
//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+void AllReduceOp::build(
+ ::mlir::OpBuilder& ods_builder, ::mlir::OperationState& ods_state,
+ ::mlir::Type result_type, ::mlir::Value operand,
+ ::mlir::DenseIntElementsAttr replica_groups,
+ /*optional*/ ::mlir::mhlo::ChannelHandleAttr channel_handle) {
+ AllReduceOp::build(ods_builder, ods_state, result_type, operand,
+ replica_groups, channel_handle, nullptr);
+}
+
+//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index 1e0c770..9e5028d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -41,6 +41,26 @@
// -----
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+ %0 = "mhlo.all_reduce"(%arg0) ({
+ // Perform max reduction inside the region
+ ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+ %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+ "mhlo.return"(%max) : (tensor<f32>) -> ()
+ })
+ {
+ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+ channel_handle = #mhlo.channel_handle<
+ handle = 5,
+ type = 2
+ >,
+ use_global_device_ids
+ } : (tensor<10xf32>) -> tensor<10xf32>
+ func.return %0 : tensor<10xf32>
+}
+
+// -----
+
func.func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> {
// expected-error@+1 {{operand scatter dimension cannot be zero}}
%0 = "mhlo.reduce_scatter"(%data) ({
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 62025b4..7400eee 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -76,16 +76,16 @@
py::arg("channel_id") = std::nullopt,
py::arg("shape_with_layout") = std::nullopt,
py::arg("use_global_device_ids") = std::nullopt);
- ops.def(
- "AllReduce",
- static_cast<XlaOp (*)(
- XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
- const std::optional<ChannelHandle>&, const std::optional<Shape>&)>(
- &AllReduce),
- py::arg("operand"), py::arg("computation"),
- py::arg("replica_groups") = py::list(),
- py::arg("channel_id") = std::nullopt,
- py::arg("shape_with_layout") = std::nullopt);
+ ops.def("AllReduce",
+ static_cast<XlaOp (*)(
+ XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
+ const std::optional<ChannelHandle>&, const std::optional<Shape>&,
+ const std::optional<bool>)>(&AllReduce),
+ py::arg("operand"), py::arg("computation"),
+ py::arg("replica_groups") = py::list(),
+ py::arg("channel_id") = std::nullopt,
+ py::arg("shape_with_layout") = std::nullopt,
+ py::arg("use_global_device_ids") = std::nullopt);
ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
py::arg("computation"), py::arg("scatter_dimension"),
py::arg("shard_count"), py::arg("replica_groups") = py::list(),