[XLA:GPU] Add AllReduce{Start,Done} ops to the MLIR LHLO GPU dialect.
PiperOrigin-RevId: 378640706
Change-Id: Id6db9a5737cd43b5068b65c69057bb3dd4297099
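
The new ops model XLA's asynchronous all-reduce pair: `all_reduce_start`
kicks off the reduction and a matching `all_reduce_done` waits for it to
complete. As a rough sketch (attributes abbreviated; the exact printed form
follows the generated assembly format), the pair is used as:

  "lmhlo_gpu.all_reduce_start"(%input, %output) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>}
    : (memref<8xf32>, memref<8xf32>) -> ()
  "lmhlo_gpu.all_reduce_done"(%input, %output)
    : (memref<8xf32>, memref<8xf32>) -> ()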
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 5340d15..fae8798 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -534,6 +534,7 @@
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
+ "include/mlir-hlo/utils/lhlo_utils.h",
],
includes = ["include"],
deps = [
@@ -578,6 +579,7 @@
":hlo_ops_base_structs",
":hlo_ops_common",
":infer_fusibility_op_interface",
+ ":lhlo",
":lhlo_gpu_ops_enums",
":lhlo_gpu_ops_inc_gen",
":lhlo_gpu_ops_structs",
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index f087a99..2c0b18d 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -230,4 +230,31 @@
BoolAttr:$is_lower);
}
+def LHLOGPU_AllReduceStartOp :
+ LHLOGPU_Op<"all_reduce_start", [SameOperandsElementType, SameVariadicOperandSize]> {
+ let summary = "AllReduceStart operator";
+ let description = [{
+ Performs an asynchronous custom reduction across replicas. The op starts
+ the reduction and returns immediately; a matching `all_reduce_done` op
+ waits for its completion.
+ }];
+ let arguments = (ins
+ Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+ Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
+ I64ElementsAttr:$replica_groups,
+ DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
+ OptionalAttr<ChannelHandle>:$channel_id,
+ DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
+ );
+ let regions = (region SizedRegion<1>:$computation);
+ let verifier = [{ return Verify(*this); }];
+}
+
+def LHLOGPU_AllReduceDoneOp:
+ LHLOGPU_Op<"all_reduce_done", [SameVariadicOperandSize]> {
+ let summary = "AllReduceDone operator";
+ let arguments = (ins
+ Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+ Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results
+ );
+}
+
#endif // LHLO_GPU_OPS
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h
new file mode 100644
index 0000000..0dbbb18
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/lhlo_utils.h
@@ -0,0 +1,100 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+
+#include "llvm/ADT/SmallSet.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace lmhlo {
+
+// Verifies replica groups attached to collective communication operations.
+// If the attribute is not empty, it must be a rank 2 tensor, and each replica
+// should appear exactly once. If `is_uniform_sized` is true, then we also check
+// that all groups are the same size. If the operation has
+// `use_global_device_ids` set, then the replica groups cannot be empty.
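+// For example, with four replicas,
+//   replica_groups = dense<[[0, 2], [1, 3]]> : tensor<2x2xi64>
+// describes two groups, {0, 2} and {1, 3}. When non-uniform group sizes are
+// allowed, shorter groups are padded with -1, e.g.
+//   dense<[[0, 1, 2], [3, -1, -1]]> : tensor<2x3xi64>.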
+template <typename OpT>
+LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
+ DenseIntElementsAttr attr = op.replica_groups();
+ auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
+ if (!replica_group_type || replica_group_type.getRank() != 2 ||
+ !replica_group_type.getElementType().isInteger(/*width=*/64))
+ return op.emitOpError(
+ "replica groups should be a rank 2 tensor of 64 bit integers");
+
+ if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0})) {
+ if (op.use_global_device_ids()) {
+ return op.emitOpError(
+ "if `use_global_device_ids` is set, the replica groups cannot be "
+ "empty");
+ }
+ return success();
+ }
+
+ int64_t max_replica_id_seen = 0;
+ llvm::SmallSet<int64_t, 8> replica_seen;
+ for (int64_t id : attr.getValues<int64_t>()) {
+ // Replica groups are stored in a 2D tensor; when the op allows non-uniform
+ // group sizes, shorter groups are padded with -1 entries.
+ if (id == -1) {
+ if (is_uniform_sized) {
+ return op.emitOpError("Invalid replica id -1");
+ }
+ continue;
+ }
+
+ if (!replica_seen.insert(id).second) {
+ return op.emitOpError("replica id #") << id << " seen more than once";
+ }
+ max_replica_id_seen = std::max(max_replica_id_seen, id);
+ }
+
+ for (int64_t id = 0; id <= max_replica_id_seen; id++) {
+ if (!replica_seen.contains(id)) {
+ return op.emitOpError("replica id #")
+ << id << " not seen in replica groups";
+ }
+ }
+ return success();
+}
+
+template <typename OpT>
+static LogicalResult VerifyAllReduce(OpT op) {
+ if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
+ return failure();
+
+ // AllReduce has equally sized variadic operand and result lists; each
+ // operand must have the same type as the corresponding result.
+ for (auto it : llvm::enumerate(
+ llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
+ Type operandType = std::get<0>(it.value());
+ Type resultType = std::get<1>(it.value());
+ if (operandType != resultType)
+ return op.emitOpError("requires operand #")
+ << it.index() << " (type: " << operandType << ") and result #"
+ << it.index() << " (type: " << resultType << ") to have same type";
+ }
+ return success();
+}
+
+} // namespace lmhlo
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
index 42c97ac..4f6d407 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
@@ -29,6 +29,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
+#include "mlir-hlo/utils/lhlo_utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -61,6 +62,14 @@
using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;
+//===----------------------------------------------------------------------===//
+// AllReduceStartOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceStartOp op) {
+ return lmhlo::VerifyAllReduce(op);
+}
+
} // namespace lmhlo_gpu
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 72be3a0..73e7985 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -33,6 +33,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
+#include "mlir-hlo/utils/lhlo_utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -86,46 +87,6 @@
// AllToAllOp
//===----------------------------------------------------------------------===//
-// Verifies replica groups attached to collective communication operations.
-// If the attribute is not empty, it must be a rank 2 tensor, and each replica
-// should appear exactly once. If `is_uniform_sized` is true, then we also check
-// that each group is of the same size. If the operation has
-// `use_global_device_id` set, then replica group cannot be empty.
-template <typename OpT>
-LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
- DenseIntElementsAttr attr = op.replica_groups();
- auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
- if (!replica_group_type || replica_group_type.getRank() != 2 ||
- !replica_group_type.getElementType().isInteger(/*width=*/64))
- return op.emitOpError(
- "replica groups should be a rank 2 tensor of 64 bit integers");
-
- if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0}))
- return success();
-
- int64_t max_replica_id_seen = 0;
- llvm::SmallSet<int64_t, 8> replica_seen;
- for (int64_t id : attr.getValues<int64_t>()) {
- if (is_uniform_sized && id == -1) {
- return op.emitOpError("Invalid replica id -1");
- }
- if (id != -1) {
- if (!replica_seen.insert(id).second) {
- return op.emitOpError("replica id #") << id << " seen more than once";
- }
- max_replica_id_seen = std::max(max_replica_id_seen, id);
- }
- }
-
- for (int64_t id = 0; id <= max_replica_id_seen; id++) {
- if (!replica_seen.contains(id)) {
- return op.emitOpError("replica id #")
- << id << " not seen in replica groups";
- }
- }
- return success();
-}
-
// TODO(jurahul): Add verification for output shape.
static LogicalResult Verify(AllGatherOp op) {
return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
@@ -140,24 +101,7 @@
// AllReduceOp
//===----------------------------------------------------------------------===//
-static LogicalResult Verify(AllReduceOp op) {
- if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
- return failure();
-
- // AllReduce has variadic operands and results that have the same size.
- // Each member of the operand should have the same type as the corresponding
- // member of the result.
- for (auto it : llvm::enumerate(
- llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
- Type operandType = std::get<0>(it.value());
- Type resultType = std::get<1>(it.value());
- if (operandType != resultType)
- return op.emitOpError("requires operand #")
- << it.index() << " (type: " << operandType << ") and result #"
- << it.index() << " (type: " << resultType << ") to have same type";
- }
- return success();
-}
+static LogicalResult Verify(AllReduceOp op) { return VerifyAllReduce(op); }
//===----------------------------------------------------------------------===//
// CaseOp
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 2e9297e..f83cd56 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -272,6 +272,70 @@
// -----
+HloModule AsyncAllReduce
+
+// Test asynchronous all-reduce.
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_async_all_reduce
+// CHECK-SAME: [[INPUT:%.*]]: memref<8xf32>
+// CHECK-SAME: [[OUTPUT_BYTES:%.*]]: memref<32xi8>
+%test_async_all_reduce {
+ param0 = f32[8] parameter(0)
+ // CHECK: [[OUTPUT:%.*]] = memref.view [[OUTPUT_BYTES]]{{.*}} : memref<32xi8> to memref<8xf32>
+ // CHECK: "lmhlo_gpu.all_reduce_start"([[INPUT]], [[OUTPUT]])
+ // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+ // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+ // CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+ // CHECK: }) {
+ // CHECK-SAME: channel_id = {handle = 1 : i64, type = 0 : i64}
+ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+ // CHECK: "lmhlo_gpu.all_reduce_done"([[INPUT]], [[OUTPUT]])
+ start = (f32[8], f32[8]) all-reduce-start(param0),
+ channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
+ ROOT done = f32[8] all-reduce-done(start)
+}
+
+// -----
+
+HloModule AsyncAllReduceTwoOperands
+
+// Test asynchronous all-reduce with two operands.
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+// CHECK-LABEL: func @test_async_all_reduce_two_operands
+// CHECK-SAME: [[INPUT0:%.*]]: memref<8xf32>
+// CHECK-SAME: [[INPUT1:%.*]]: memref<9xf32>
+// CHECK-SAME: [[OUTPUT1_BYTES:%.*]]: memref<36xi8>
+// CHECK-SAME: [[OUTPUT0_BYTES:%.*]]: memref<32xi8>
+%test_async_all_reduce_two_operands {
+ param0 = f32[8] parameter(0)
+ param1 = f32[9] parameter(1)
+ // CHECK: [[OUTPUT0:%.*]] = memref.view [[OUTPUT0_BYTES]]{{.*}} : memref<32xi8> to memref<8xf32>
+ // CHECK: [[OUTPUT1:%.*]] = memref.view [[OUTPUT1_BYTES]]{{.*}} : memref<36xi8> to memref<9xf32>
+ // CHECK: "lmhlo_gpu.all_reduce_start"([[INPUT0]], [[INPUT1]], [[OUTPUT0]], [[OUTPUT1]])
+ // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+ // CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+ // CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
+ // CHECK: }) {
+ // CHECK-SAME: channel_id = {handle = 1 : i64, type = 0 : i64}
+ // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+ // CHECK: "lmhlo_gpu.all_reduce_done"([[INPUT0]], [[INPUT1]], [[OUTPUT0]], [[OUTPUT1]])
+ start = ((f32[8], f32[9]), (f32[8], f32[9])) all-reduce-start(param0, param1),
+ channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
+ ROOT done = (f32[8], f32[9]) all-reduce-done(start)
+}
+
+// -----
+
HloModule ConvForward
// CHECK-LABEL: func @main
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 71405ee..9f0ed97 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -281,6 +281,10 @@
return EmitAllGatherOp(instr);
case HloOpcode::kAllReduce:
return EmitAllReduceOp(instr);
+ case HloOpcode::kAllReduceStart:
+ return EmitAllReduceStartOp(instr);
+ case HloOpcode::kAllReduceDone:
+ return EmitAllReduceDoneOp(instr);
case HloOpcode::kAnd:
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
case HloOpcode::kAtan2:
@@ -1194,6 +1198,43 @@
return all_reduce_op;
}
+StatusOr<lmhlo_gpu::AllReduceStartOp> LhloDialectEmitter::EmitAllReduceStartOp(
+ const HloInstruction* instr) {
+ llvm::SmallVector<Value, 4> operands;
+ for (const HloInstruction* operand : instr->operands()) {
+ TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
+ }
+ // Only materialize views for result index {1}; index {0} of the result
+ // tuple always aliases the inputs.
+ TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{1}));
+
+ Location loc = getLocation(instr);
+ lmhlo_gpu::AllReduceStartOp all_reduce_start_op =
+ builder_.create<lmhlo_gpu::AllReduceStartOp>(
+ loc, llvm::None, operands, llvm::ArrayRef<NamedAttribute>{});
+
+ auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
+ TF_RETURN_IF_ERROR(
+ SetupCommonCollectiveOpAttributes(all_reduce_start_op, instr, builder_));
+ all_reduce_start_op.use_global_device_idsAttr(
+ builder_.getBoolAttr(all_reduce->use_global_device_ids()));
+ TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion(
+ *instr->called_computations()[0], &all_reduce_start_op.computation(),
+ &builder_));
+ return all_reduce_start_op;
+}
+
+StatusOr<lmhlo_gpu::AllReduceDoneOp> LhloDialectEmitter::EmitAllReduceDoneOp(
+ const HloInstruction* instr) {
+ llvm::SmallVector<Value, 4> operands;
+ for (const HloInstruction* operand : instr->operands()) {
+ TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
+ }
+ // We don't need to add buffers for the outputs, as these always alias inputs.
+ return builder_.create<lmhlo_gpu::AllReduceDoneOp>(
+ getLocation(instr), llvm::None, operands,
+ llvm::ArrayRef<NamedAttribute>{});
+}
+
StatusOr<lmhlo::CollectivePermuteOp>
LhloDialectEmitter::EmitCollectivePermuteOp(const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto permute_op,
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index a9dd4f3..efbf3a7 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -95,6 +95,10 @@
const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo_gpu::AllReduceStartOp> EmitAllReduceStartOp(
+ const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo_gpu::AllReduceDoneOp> EmitAllReduceDoneOp(
+ const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::CollectivePermuteOp> EmitCollectivePermuteOp(
const xla::HloInstruction* instr);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index f7abe20..c4ace6d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -194,6 +194,7 @@
break;
case HloOpcode::kAllReduce:
case HloOpcode::kAllReduceScatter:
+ case HloOpcode::kAllReduceStart:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow: