Expose use_global_device_ids as an attribute in MHLO AllReduceOp

This time, also add a backwards-compatible custom builder method so that
existing call sites that do not pass use_global_device_ids keep compiling
(see the sketch below).
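
A minimal sketch of how the two builders are meant to be used.
BuildAllReduce and its arguments are illustrative names, not part of this
change, and the computation region must still be populated by the caller
(e.g. via BuildReduceBody):

  #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

  mlir::mhlo::AllReduceOp BuildAllReduce(
      mlir::OpBuilder& builder, mlir::Location loc, mlir::Type result_type,
      mlir::Value operand, mlir::DenseIntElementsAttr replica_groups,
      mlir::mhlo::ChannelHandleAttr channel_handle,
      bool use_global_device_ids) {
    if (!use_global_device_ids) {
      // Backwards-compatible builder added by this change: existing call
      // sites that never set the flag keep compiling unchanged.
      return builder.create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operand, replica_groups, channel_handle);
    }
    // Generated builder: a non-null UnitAttr marks the replica group ids
    // as global device ids.
    return builder.create<mlir::mhlo::AllReduceOp>(
        loc, result_type, operand, replica_groups, channel_handle,
        builder.getUnitAttr());
  }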

PiperOrigin-RevId: 467222612
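
On the client side, the new optional parameter threads through to
HloInstructionProto's use_global_device_ids field. A minimal usage sketch;
AllReduceGlobal and its parameters are placeholders, not part of this
change:

  #include <optional>

  #include "tensorflow/compiler/xla/client/xla_builder.h"

  xla::XlaOp AllReduceGlobal(xla::XlaOp operand,
                             const xla::XlaComputation& reduction,
                             absl::Span<const xla::ReplicaGroup> groups,
                             const xla::ChannelHandle& handle) {
    // Passing true sets use_global_device_ids on the emitted all-reduce;
    // leaving it as std::nullopt preserves the previous behavior.
    return xla::AllReduce(operand, reduction, groups,
                          /*channel_id=*/handle,
                          /*shape_with_layout=*/std::nullopt,
                          /*use_global_device_ids=*/true);
  }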
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 3f66b7e..ec8e7d7 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -991,6 +991,8 @@
       if (all_reduce->channel_id().has_value())
         attributes.push_back(
             ConvertChannelHandle(all_reduce->channel_id().value()));
+      if (all_reduce->use_global_device_ids())
+        attributes.push_back(ConvertUseGlobalDeviceIds());
       auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
           loc, result_type, operands, attributes);
       TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
@@ -1709,6 +1711,11 @@
                             context_, channel.handle(), channel.type()));
 }
 
+mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
+  return builder_->getNamedAttr("use_global_device_ids",
+                                builder_->getUnitAttr());
+}
+
 void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
                                            const Shape& shape,
                                            llvm::StringRef attr_name) {
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index f74b931..5e8954d 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -226,6 +226,9 @@
   // Converts channel id to attribute
   mlir::NamedAttribute ConvertChannelHandle(std::optional<int64_t> channel_id);
 
+  // Converts the use_global_device_ids flag to attribute
+  mlir::NamedAttribute ConvertUseGlobalDeviceIds();
+
   // Converts channel handle to attribute
   mlir::NamedAttribute ConvertChannelHandle(const xla::ChannelHandle& channel);
 
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index d6100cf..2cb512a 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 
 #include <memory>
+#include <optional>
 #include <string>
 
 #include "llvm/ADT/DenseMap.h"
@@ -178,6 +179,12 @@
   return xla::ConvertNx2Attribute(padding).ValueOrDie();
 }
 
+static std::optional<bool> Convert_use_global_device_ids(
+    llvm::Optional<bool> use_global_device_ids) {
+  if (!use_global_device_ids) return {};
+  return *use_global_device_ids;
+}
+
 static std::vector<std::pair<int64_t, int64_t>> Convert_source_target_pairs(
     llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
   return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
@@ -757,9 +764,10 @@
   xla::XlaOp operand;
   if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
 
-  value_map[op] = xla::AllReduce(operand, computation,
-                                 Convert_replica_groups(op.replica_groups()),
-                                 Convert_channel_handle(op.channel_handle()));
+  value_map[op] = xla::AllReduce(
+      operand, computation, Convert_replica_groups(op.replica_groups()),
+      Convert_channel_handle(op.channel_handle()), std::nullopt,
+      Convert_use_global_device_ids(op.use_global_device_ids()));
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 3d20110..adefedd 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -134,6 +134,36 @@
 // -----
 
 // CHECK:  HloModule
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = "mhlo.all_reduce"(%arg0) ({
+  // Perform max reduction inside the region
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
+  })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 2
+    >,
+    use_global_device_ids
+  } : (tensor<10xf32>) -> tensor<10xf32>
+  func.return %0 : tensor<10xf32>
+}
+
+// CHECK:  %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
+// CHECK:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[10] parameter(0)
+// CHECK:  ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]])
+// CHECK-SAME:  channel_id=5
+// CHECK-SAME{LITERAL}:  replica_groups={{0,2,4,6},{1,3,5,7}}
+// CHECK-SAME:  use_global_device_ids=true
+// CHECK-SAME:  to_apply=%[[COMPUTATION]]
+
+// -----
+
+// CHECK:  HloModule
 func.func @main(%arg0: tensor<10xf32>) -> tensor<5xf32> {
   %0 = "mhlo.reduce_scatter"(%arg0) ({
   // Perform max reduction inside the region
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index e1f19fc..9bfd995 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -84,10 +84,28 @@
   // CHECK:    mhlo.return [[ADD]] : tensor<f32>
   // CHECK:  }) {
   // CHECK-SAME:  channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-NOT: use_global_device_ids
   // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+  // CHECK-NOT: use_global_device_ids
+  // CHECK-SAME: :
   ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
 }
 
+// CHECK-LABEL:  func private @test_all_reduce_global
+// CHECK-SAME:  ([[INPUT:%.*]]: tensor<8xf32>)
+%test_all_reduce_global {
+  input = f32[8] parameter(0)
+  // CHECK-NEXT:  "mhlo.all_reduce"([[INPUT]])
+  // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+  // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+  // CHECK:    mhlo.return [[ADD]] : tensor<f32>
+  // CHECK:  }) {
+  // CHECK-SAME:  channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
+  // CHECK-SAME: use_global_device_ids
+  ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, use_global_device_ids=true, to_apply=add
+}
+
 
 // CHECK-LABEL:  func private @test_and
 %test_and (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
index 9b97d2c..3399e3d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc
@@ -139,8 +139,8 @@
   ChannelHandleAttr channel_handle = ConvertChannel(builder, channel_id, mode);
   Location loc = op->getLoc();
   Type element_type = getElementTypeOrSelf(input.getType());
-  auto all_reduce = builder.create<AllReduceOp>(loc, result_type, input,
-                                                replica_groups, channel_handle);
+  auto all_reduce = builder.create<AllReduceOp>(
+      loc, result_type, input, replica_groups, channel_handle, nullptr);
   if (merge_op == "Add") {
     BuildReduceBody<AddOp>(element_type, &all_reduce.computation(), &builder);
   } else if (merge_op == "Mul") {
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e9b25fc..75ba3c6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2936,7 +2936,8 @@
 XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                             absl::Span<const ReplicaGroup> replica_groups,
                             const std::optional<ChannelHandle>& channel_id,
-                            const std::optional<Shape>& shape_with_layout) {
+                            const std::optional<Shape>& shape_with_layout,
+                            const std::optional<bool> use_global_device_ids) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2994,6 +2995,10 @@
       instr.set_channel_id(channel_id->handle());
     }
 
+    if (use_global_device_ids.has_value()) {
+      instr.set_use_global_device_ids(*use_global_device_ids);
+    }
+
     AddCalledComputation(computation, &instr);
 
     TF_ASSIGN_OR_RETURN(
@@ -4664,9 +4669,11 @@
 XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                 absl::Span<const ReplicaGroup> replica_groups,
                 const std::optional<ChannelHandle>& channel_id,
-                const std::optional<Shape>& shape_with_layout) {
+                const std::optional<Shape>& shape_with_layout,
+                const std::optional<bool> use_global_device_ids) {
   return operand.builder()->AllReduce(operand, computation, replica_groups,
-                                      channel_id, shape_with_layout);
+                                      channel_id, shape_with_layout,
+                                      use_global_device_ids);
 }
 
 XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 0e672ee..47a37e4 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -751,10 +751,12 @@
       const std::optional<Layout>& layout = std::nullopt,
       const std::optional<bool> use_global_device_ids = std::nullopt);
 
-  XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
-                  absl::Span<const ReplicaGroup> replica_groups = {},
-                  const std::optional<ChannelHandle>& channel_id = std::nullopt,
-                  const std::optional<Shape>& shape_with_layout = std::nullopt);
+  XlaOp AllReduce(
+      XlaOp operand, const XlaComputation& computation,
+      absl::Span<const ReplicaGroup> replica_groups = {},
+      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+      const std::optional<Shape>& shape_with_layout = std::nullopt,
+      const std::optional<bool> use_global_device_ids = std::nullopt);
 
   XlaOp ReduceScatter(
       XlaOp operand, const XlaComputation& computation,
@@ -1367,7 +1369,8 @@
   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                          absl::Span<const ReplicaGroup> replica_groups,
                          const std::optional<ChannelHandle>& channel_id,
-                         const std::optional<Shape>& shape_with_layout);
+                         const std::optional<Shape>& shape_with_layout,
+                         const std::optional<bool> use_global_device_ids);
   friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
                              int64_t scatter_dimension, int64_t shard_count,
                              absl::Span<const ReplicaGroup> replica_groups,
@@ -2348,7 +2351,8 @@
 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                 absl::Span<const ReplicaGroup> replica_groups = {},
                 const std::optional<ChannelHandle>& channel_id = std::nullopt,
-                const std::optional<Shape>& shape_with_layout = std::nullopt);
+                const std::optional<Shape>& shape_with_layout = std::nullopt,
+                const std::optional<bool> use_global_device_ids = std::nullopt);
 
 XlaOp ReduceScatter(
     XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index aecca02..023b827 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1045,10 +1045,18 @@
   let arguments = (ins
     HLO_Tensor:$operand,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<ChannelHandle>:$channel_handle
+    OptionalAttr<ChannelHandle>:$channel_handle,
+    UnitAttr:$use_global_device_ids
   );
   let regions = (region SizedRegion<1>:$computation);
   let results = (outs HLO_Tensor);
+  // use_global_device_ids is rarely used, so we add a simplified
+  // builder method for convenience.
+  let builders = [
+    OpBuilder<(ins
+      "::mlir::Type":$result_type, "::mlir::Value":$operand,
+      "::mlir::DenseIntElementsAttr":$replica_groups,
+      "::mlir::mhlo::ChannelHandleAttr":$channel_handle)>];
 
   let hasCustomHLOConverter = 1;
 }
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index d0145cd..6a246e9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -617,6 +617,19 @@
 }  // namespace
 
 //===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+void AllReduceOp::build(
+    ::mlir::OpBuilder& ods_builder, ::mlir::OperationState& ods_state,
+    ::mlir::Type result_type, ::mlir::Value operand,
+    ::mlir::DenseIntElementsAttr replica_groups,
+    /*optional*/ ::mlir::mhlo::ChannelHandleAttr channel_handle) {
+  AllReduceOp::build(ods_builder, ods_state, result_type, operand,
+                     replica_groups, channel_handle, nullptr);
+}
+
+//===----------------------------------------------------------------------===//
 // ReduceScatterOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index 1e0c770..9e5028d 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -41,6 +41,26 @@
 
 // -----
 
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = "mhlo.all_reduce"(%arg0) ({
+  // Perform max reduction inside the region
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
+  })
+  {
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
+    channel_handle = #mhlo.channel_handle<
+      handle = 5,
+      type = 2
+    >,
+    use_global_device_ids
+  } : (tensor<10xf32>) -> tensor<10xf32>
+  func.return %0 : tensor<10xf32>
+}
+
+// -----
+
 func.func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> {
   // expected-error@+1 {{operand scatter dimension cannot be zero}}
   %0 = "mhlo.reduce_scatter"(%data) ({
diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc
index 62025b4..7400eee 100644
--- a/tensorflow/compiler/xla/python/ops.cc
+++ b/tensorflow/compiler/xla/python/ops.cc
@@ -76,16 +76,16 @@
           py::arg("channel_id") = std::nullopt,
           py::arg("shape_with_layout") = std::nullopt,
           py::arg("use_global_device_ids") = std::nullopt);
-  ops.def(
-      "AllReduce",
-      static_cast<XlaOp (*)(
-          XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
-          const std::optional<ChannelHandle>&, const std::optional<Shape>&)>(
-          &AllReduce),
-      py::arg("operand"), py::arg("computation"),
-      py::arg("replica_groups") = py::list(),
-      py::arg("channel_id") = std::nullopt,
-      py::arg("shape_with_layout") = std::nullopt);
+  ops.def("AllReduce",
+          static_cast<XlaOp (*)(
+              XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
+              const std::optional<ChannelHandle>&, const std::optional<Shape>&,
+              const std::optional<bool>)>(&AllReduce),
+          py::arg("operand"), py::arg("computation"),
+          py::arg("replica_groups") = py::list(),
+          py::arg("channel_id") = std::nullopt,
+          py::arg("shape_with_layout") = std::nullopt,
+          py::arg("use_global_device_ids") = std::nullopt);
   ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
           py::arg("computation"), py::arg("scatter_dimension"),
           py::arg("shard_count"), py::arg("replica_groups") = py::list(),