[XLA:GPU] Add AllReduce{Start,Done} ops to the MLIR LHLO GPU dialect.
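
For reference, a minimal sketch of the lmhlo_gpu IR produced for an async
all-reduce (buffer names and the exact printed form are illustrative, mirroring
the new case in hlo_text_to_lhlo_no_opt.hlotxt):

  "lmhlo_gpu.all_reduce_start"(%input, %output) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {channel_id = {handle = 1 : i64, type = 0 : i64},
      constrain_layout = false,
      replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>,
      use_global_device_ids = false} : (memref<8xf32>, memref<8xf32>) -> ()
  "lmhlo_gpu.all_reduce_done"(%input, %output)
      : (memref<8xf32>, memref<8xf32>) -> ()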

PiperOrigin-RevId: 378640706
Change-Id: Id6db9a5737cd43b5068b65c69057bb3dd4297099
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 5340d15..fae8798 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -534,6 +534,7 @@
     hdrs = [
         "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",
         "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
+        "include/mlir-hlo/utils/lhlo_utils.h",
     ],
     includes = ["include"],
     deps = [
@@ -578,6 +579,7 @@
         ":hlo_ops_base_structs",
         ":hlo_ops_common",
         ":infer_fusibility_op_interface",
+        ":lhlo",
         ":lhlo_gpu_ops_enums",
         ":lhlo_gpu_ops_inc_gen",
         ":lhlo_gpu_ops_structs",
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index f087a99..2c0b18d 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -230,4 +230,31 @@
     BoolAttr:$is_lower);
 }
 
+def LHLOGPU_AllReduceStartOp :
+  LHLOGPU_Op<"all_reduce_start", [SameOperandsElementType, SameVariadicOperandSize]> {
+  let summary = "AllReduceStart operator";
+  let description = [{
+    Performs an asynchronous custom reduction across replicas.
+  }];
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
+    I64ElementsAttr:$replica_groups,
+    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
+    OptionalAttr<ChannelHandle>:$channel_id,
+    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
+  );
+  let regions = (region SizedRegion<1>:$computation);
+  let verifier = [{ return Verify(*this); }];
+}
+
+def LHLOGPU_AllReduceDoneOp :
+  LHLOGPU_Op<"all_reduce_done", [SameVariadicOperandSize]> {
+  let summary = "AllReduceDone operator";
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results
+  );
+}
+
 #endif // LHLO_GPU_OPS
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h
new file mode 100644
index 0000000..0dbbb18
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h
@@ -0,0 +1,100 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+
+#include "llvm/ADT/SmallSet.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace lmhlo {
+
+// Verifies the replica groups attached to collective communication operations.
+// If the attribute is not empty, it must be a rank-2 tensor of 64-bit integers
+// in which each replica ID appears exactly once. If `is_uniform_sized` is
+// true, we also check that each group has the same size. If the operation has
+// `use_global_device_ids` set, the replica groups attribute cannot be empty.
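+//
+// For example, dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> describes two
+// groups of two replicas each. When non-uniform group sizes are allowed,
+// shorter groups are padded with -1, e.g. dense<[[0, 1, 2], [3, -1, -1]]>.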
+template <typename OpT>
+LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
+  DenseIntElementsAttr attr = op.replica_groups();
+  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
+  if (!replica_group_type || replica_group_type.getRank() != 2 ||
+      !replica_group_type.getElementType().isInteger(/*width=*/64))
+    return op.emitOpError(
+        "replica groups should be a rank 2 tensor of 64 bit integers");
+
+  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0})) {
+    if (op.use_global_device_ids()) {
+      return op.emitOpError(
+          "if `use_global_device_ids` is set, the replica groups cannot be "
+          "empty");
+    }
+    return success();
+  }
+
+  int64_t max_replica_id_seen = 0;
+  llvm::SmallSet<int64_t, 8> replica_seen;
+  for (int64_t id : attr.getValues<int64_t>()) {
+    // Replica groups are stored as a 2D tensor; when the op allows
+    // non-uniform group sizes, shorter groups are padded with -1.
+    if (id == -1) {
+      if (is_uniform_sized) {
+        return op.emitOpError("Invalid replica id -1");
+      }
+      continue;
+    }
+
+    if (!replica_seen.insert(id).second) {
+      return op.emitOpError("replica id #") << id << " seen more than once";
+    }
+    max_replica_id_seen = std::max(max_replica_id_seen, id);
+  }
+
+  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
+    if (!replica_seen.contains(id)) {
+      return op.emitOpError("replica id #")
+             << id << " not seen in replica groups";
+    }
+  }
+  return success();
+}
+
+template <typename OpT>
+static LogicalResult VerifyAllReduce(OpT op) {
+  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
+    return failure();
+
+  // AllReduce has variadic operands and results of the same size. Each operand
+  // must have the same type as the corresponding result.
+  for (auto it : llvm::enumerate(
+           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
+    Type operandType = std::get<0>(it.value());
+    Type resultType = std::get<1>(it.value());
+    if (operandType != resultType)
+      return op.emitOpError("requires operand #")
+             << it.index() << " (type: " << operandType << ") and result #"
+             << it.index() << " (type: " << resultType << ") to have same type";
+  }
+  return success();
+}
+
+}  // namespace lmhlo
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
index 42c97ac..4f6d407 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
@@ -29,6 +29,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
+#include "mlir-hlo/utils/lhlo_utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -61,6 +62,14 @@
 using mlir::hlo::parseWindowAttributes;
 using mlir::hlo::printWindowAttributes;
 
+//===----------------------------------------------------------------------===//
+// AllReduceStartOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceStartOp op) {
+  return lmhlo::VerifyAllReduce(op);
+}
+
 }  // namespace lmhlo_gpu
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 72be3a0..73e7985 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -33,6 +33,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
+#include "mlir-hlo/utils/lhlo_utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -86,46 +87,6 @@
 // AllToAllOp
 //===----------------------------------------------------------------------===//
 
-// Verifies replica groups attached to collective communication operations.
-// If the attribute is not empty, it must be a rank 2 tensor, and each replica
-// should appear exactly once. If `is_uniform_sized` is true, then we also check
-// that each group is of the same size. If the operation has
-// `use_global_device_id` set, then replica group cannot be empty.
-template <typename OpT>
-LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
-  DenseIntElementsAttr attr = op.replica_groups();
-  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
-  if (!replica_group_type || replica_group_type.getRank() != 2 ||
-      !replica_group_type.getElementType().isInteger(/*width=*/64))
-    return op.emitOpError(
-        "replica groups should be a rank 2 tensor of 64 bit integers");
-
-  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0}))
-    return success();
-
-  int64_t max_replica_id_seen = 0;
-  llvm::SmallSet<int64_t, 8> replica_seen;
-  for (int64_t id : attr.getValues<int64_t>()) {
-    if (is_uniform_sized && id == -1) {
-      return op.emitOpError("Invalid replica id -1");
-    }
-    if (id != -1) {
-      if (!replica_seen.insert(id).second) {
-        return op.emitOpError("replica id #") << id << " seen more than once";
-      }
-      max_replica_id_seen = std::max(max_replica_id_seen, id);
-    }
-  }
-
-  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
-    if (!replica_seen.contains(id)) {
-      return op.emitOpError("replica id #")
-             << id << " not seen in replica groups";
-    }
-  }
-  return success();
-}
-
 // TODO(jurahul): Add verification for output shape.
 static LogicalResult Verify(AllGatherOp op) {
   return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
@@ -140,24 +101,7 @@
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult Verify(AllReduceOp op) {
-  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
-    return failure();
-
-  // AllReduce has variadic operands and results that have the same size.
-  // Each member of the operand should have the same type as the corresponding
-  // member of the result.
-  for (auto it : llvm::enumerate(
-           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
-    Type operandType = std::get<0>(it.value());
-    Type resultType = std::get<1>(it.value());
-    if (operandType != resultType)
-      return op.emitOpError("requires operand #")
-             << it.index() << " (type: " << operandType << ") and result #"
-             << it.index() << " (type: " << resultType << ") to have same type";
-  }
-  return success();
-}
+static LogicalResult Verify(AllReduceOp op) { return VerifyAllReduce(op); }
 
 //===----------------------------------------------------------------------===//
 // CaseOp
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 2e9297e..f83cd56 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -272,6 +272,70 @@
 
 // -----
 
+HloModule AsyncAllReduce
+
+// Test async all-reduce
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_async_all_reduce
+// CHECK-SAME: [[INPUT:%.*]]: memref<8xf32>
+// CHECK-SAME: [[OUTPUT_BYTES:%.*]]: memref<32xi8>
+%test_async_all_reduce {
+  param0 = f32[8] parameter(0)
+  // CHECK:  [[OUTPUT:%.*]] = memref.view [[OUTPUT_BYTES]]{{.*}} : memref<32xi8> to memref<8xf32>
+  // CHECK:  "lmhlo_gpu.all_reduce_start"([[INPUT]], [[OUTPUT]])
+  // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+  // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+  // CHECK:    "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+  // CHECK:  }) {
+  // CHECK-SAME:  channel_id = {handle = 1 : i64, type = 0 : i64}
+  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+  // CHECK:  "lmhlo_gpu.all_reduce_done"([[INPUT]], [[OUTPUT]])
+  start = (f32[8], f32[8]) all-reduce-start(param0),
+      channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
+  ROOT done = f32[8] all-reduce-done(start)
+}
+
+// -----
+
+HloModule AsyncAllReduceTwoOperands
+
+// Test async all-reduce with two operands
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_async_all_reduce_two_operands
+// CHECK-SAME: [[INPUT0:%.*]]: memref<8xf32>
+// CHECK-SAME: [[INPUT1:%.*]]: memref<9xf32>
+// CHECK-SAME: [[OUTPUT1_BYTES:%.*]]: memref<36xi8>
+// CHECK-SAME: [[OUTPUT0_BYTES:%.*]]: memref<32xi8>
+%test_async_all_reduce_two_operands {
+  param0 = f32[8] parameter(0)
+  param1 = f32[9] parameter(1)
+  // CHECK:  [[OUTPUT0:%.*]] = memref.view [[OUTPUT0_BYTES]]{{.*}} : memref<32xi8> to memref<8xf32>
+  // CHECK:  [[OUTPUT1:%.*]] = memref.view [[OUTPUT1_BYTES]]{{.*}} : memref<36xi8> to memref<9xf32>
+  // CHECK:  "lmhlo_gpu.all_reduce_start"([[INPUT0]], [[INPUT1]], [[OUTPUT0]], [[OUTPUT1]])
+  // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+  // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+  // CHECK:    "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+  // CHECK:  }) {
+  // CHECK-SAME:  channel_id = {handle = 1 : i64, type = 0 : i64}
+  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+  // CHECK:  "lmhlo_gpu.all_reduce_done"([[INPUT0]], [[INPUT1]], [[OUTPUT0]], [[OUTPUT1]])
+  start = ((f32[8], f32[9]), (f32[8], f32[9])) all-reduce-start(param0, param1),
+      channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
+  ROOT done = (f32[8], f32[9]) all-reduce-done(start)
+}
+
+// -----
+
 HloModule ConvForward
 
 // CHECK-LABEL: func @main
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 71405ee..9f0ed97 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -281,6 +281,10 @@
       return EmitAllGatherOp(instr);
     case HloOpcode::kAllReduce:
       return EmitAllReduceOp(instr);
+    case HloOpcode::kAllReduceStart:
+      return EmitAllReduceStartOp(instr);
+    case HloOpcode::kAllReduceDone:
+      return EmitAllReduceDoneOp(instr);
     case HloOpcode::kAnd:
       return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
     case HloOpcode::kAtan2:
@@ -1194,6 +1198,43 @@
   return all_reduce_op;
 }
 
+StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
+    const HloInstruction* instr) {
+  llvm::SmallVector<Value, 4> operands;
+  for (const HloInstruction* operand : instr->operands()) {
+    TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
+  }
+  // Only include result index {1}. {0} always aliases the inputs.
+  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1}));
+
+  Location loc = getLocation(instr);
+  lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
+      builder_.create<lmhlo_gpu::AllReduceStartOp>(
+          loc, llvm::None, operands, llvm::ArrayRef<NamedAttribute>{});
+
+  auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
+  TF_RETURN_IF_ERROR(
+      SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
+  all_reduce_start_op.use_global_device_idsAttr(
+      builder_.getBoolAttr(all_reduce->use_global_device_ids()));
+  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
+      *instr->called_computations()[0], &all_reduce_start_op.computation(),
+      &builder_));
+  return all_reduce_start_op;
+}
+
+StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
+    const HloInstruction* instr) {
+  llvm::SmallVector<Value, 4> operands;
+  for (const HloInstruction* operand : instr->operands()) {
+    TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
+  }
+  // We don't need buffers for the outputs, as they always alias the inputs.
+  return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
+      getLocation(instr), llvm::None, operands,
+      llvm::ArrayRef<NamedAttribute>{});
+}
+
 StatusOr<lmhlo::CollectivePermuteOp>
 LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto permute_op,
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index a9dd4f3..efbf3a7 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -95,6 +95,10 @@
       const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
       const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
+      const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
+      const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
       const xla::HloInstruction* instr);
 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index f7abe20..c4ace6d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -194,6 +194,7 @@
             break;
           case HloOpcode::kAllReduce:
           case HloOpcode::kAllReduceScatter:
+          case HloOpcode::kAllReduceStart:
           case HloOpcode::kMap:
           case HloOpcode::kReduce:
           case HloOpcode::kReduceWindow: