[XLA:GPU] Add support for PartitionId
PiperOrigin-RevId: 354599221
Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 57fdfb6..a948c32 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -637,6 +637,14 @@
}];
}
+class BASE_HLO_PartitionIdOp {
+ string summary = "PartitionId operator";
+
+ string description = [{
+    Returns the unique ID (uint32 scalar) of the partition.
+ }];
+}
+
class BASE_HLO_AllReduceOp {
string summary = "AllReduce operator";
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index ee39bfc..9706473 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -608,6 +608,10 @@
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
+def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
+ let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+}
+
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index ef6ab6c..038775c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -323,6 +323,8 @@
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
case HloOpcode::kOutfeed:
return EmitOutfeedOp(instr);
+ case HloOpcode::kPartitionId:
+ return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
case HloOpcode::kPad:
return EmitPadOp(instr);
case HloOpcode::kPopulationCount:
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bef6c95..f440377 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2948,15 +2948,27 @@
return Status::OK();
}
+template <typename ThunkType, typename OpT>
+Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
+ MlirEmitterInput input) {
+ auto op = mlir::cast<OpT>(input.op);
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
+ GetAllocationSliceForMlir(op.getOperand()));
+ AddThunkToThunkSequence(
+ absl::make_unique<ThunkType>(input.thunk_info, result_slice));
+ return Status::OK();
+}
+
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
- auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
- GetAllocationSliceForMlir(replica_id_op.getOperand()));
- AddThunkToThunkSequence(
- absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
+ return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
+ mlir::lmhlo::ReplicaIdOp>(input);
+}
- return Status::OK();
+Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
+ TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+ return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
+ mlir::lmhlo::PartitionIdOp>(input);
}
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index d6192b9..9c29bed 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -205,7 +205,12 @@
Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleAfterAll(HloInstruction* after_all) override;
+
+ template <typename ThunkType, typename OpT>
+ Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
Status HandleReplicaId(HloInstruction* hlo) override;
+ Status HandlePartitionId(HloInstruction* hlo) override;
+
Status HandleCollectivePermute(HloInstruction* hlo) override;
Status EmitOp(MlirEmitterInput mlir_input);
diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
index 3d09a19..8f0b2ff 100644
--- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
@@ -18,11 +18,7 @@
namespace xla {
namespace gpu {
-ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
- const BufferAllocation::Slice& dest)
- : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
-
-Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
+Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
@@ -30,9 +26,10 @@
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
- TF_ASSIGN_OR_RETURN(int replica_id,
- params.device_assn->ReplicaIdForDevice(global_device_id));
- params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
+ TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
+ global_device_id));
+ int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
+ params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
index 80aee41..b16a7a1 100644
--- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
@@ -23,17 +23,31 @@
namespace xla {
namespace gpu {
-// Thunk that implements the ReplicaId HLO.
-class ReplicaIdThunk : public Thunk {
- public:
- ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
-
+// Thunk that implements the ReplicaId or PartitionId HLO; which one is selected by the thunk kind (kReplicaId vs kPartitionId).
+class ReplicaOrPartitionIdThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
+ protected:
+ ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
+ const BufferAllocation::Slice& dest)
+ : Thunk(kind, thunk_info), dest_(dest) {}
+
private:
const BufferAllocation::Slice dest_;
};
+class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+ ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+ : ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
+};
+
+class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+ PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+ : ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
+};
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc
index db4b2ff..5c49b9d 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk.cc
@@ -72,6 +72,8 @@
return "kOutfeed";
case Thunk::kReplicaId:
return "kReplicaId";
+ case Thunk::kPartitionId:
+ return "kPartitionId";
case Thunk::kSequential:
return "kSequential";
case Thunk::kTriangularSolve:
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 9166ee1..7a0e031 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -64,6 +64,7 @@
kNcclAllToAll,
kOutfeed,
kReplicaId,
+ kPartitionId,
kSequential,
kTriangularSolve,
kTuple,