[XLA:GPU] Add support for PartitionId

Lower HloOpcode::kPartitionId to the new lmhlo.partition_id op and emit
a PartitionIdThunk for it. The thunk shares its ExecuteOnStream
implementation with ReplicaIdThunk through a common
ReplicaOrPartitionIdThunk base class, which looks up the
(replica id, partition id) pair for the executing device and writes the
one selected by the thunk kind into the destination buffer.
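
As a rough illustration, the selection the shared thunk performs boils
down to picking one element of that pair. The snippet below is a
standalone sketch, not XLA code: the Kind enum mirrors Thunk::Kind, the
LogicalIdsForDevice stand-in is a made-up device-to-id mapping, and only
the ternary in IdToWrite corresponds to the logic added in
replica_id_thunk.cc.

  #include <cstdint>
  #include <iostream>
  #include <utility>

  enum class Kind { kReplicaId, kPartitionId };

  // Stand-in for DeviceAssignment::LogicalIdsForDevice(): first is the
  // replica id, second is the partition id. The mapping below (two
  // partitions, devices interleaved) is purely illustrative.
  std::pair<int, int> LogicalIdsForDevice(int global_device_id) {
    return {global_device_id / 2, global_device_id % 2};
  }

  // Mirrors `kind() == Kind::kReplicaId ? logical_ids.first
  //                                     : logical_ids.second`.
  uint32_t IdToWrite(Kind kind, int global_device_id) {
    auto logical_ids = LogicalIdsForDevice(global_device_id);
    return kind == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
  }

  int main() {
    for (int device = 0; device < 4; ++device) {
      std::cout << "device " << device
                << ": replica=" << IdToWrite(Kind::kReplicaId, device)
                << " partition=" << IdToWrite(Kind::kPartitionId, device)
                << "\n";
    }
    return 0;
  }
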
PiperOrigin-RevId: 354599221
Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 57fdfb6..a948c32 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -637,6 +637,14 @@
   }];
 }
 
+class BASE_HLO_PartitionIdOp {
+  string summary = "PartitionId operator";
+
+  string description = [{
+    Returns the unique ID (uint32 scalar) of the partition.
+  }];
+}
+
 
 class BASE_HLO_AllReduceOp {
   string summary = "AllReduce operator";
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index ee39bfc..9706473 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -608,6 +608,10 @@
   let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
 }
 
+def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
+  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+}
+
 def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
                             BASE_HLO_TriangularSolveOp {
   let arguments = (ins
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index ef6ab6c..038775c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -323,6 +323,8 @@
       return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
     case HloOpcode::kOutfeed:
       return EmitOutfeedOp(instr);
     case HloOpcode::kPad:
       return EmitPadOp(instr);
+    case HloOpcode::kPartitionId:
+      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
     case HloOpcode::kPopulationCount:
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bef6c95..f440377 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2948,15 +2948,27 @@
   return Status::OK();
 }
 
+template <typename ThunkType, typename OpT>
+Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
+    MlirEmitterInput input) {
+  auto op = mlir::cast<OpT>(input.op);
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
+                      GetAllocationSliceForMlir(op.getOperand()));
+  AddThunkToThunkSequence(
+      absl::make_unique<ThunkType>(input.thunk_info, result_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
-  auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
-                      GetAllocationSliceForMlir(replica_id_op.getOperand()));
-  AddThunkToThunkSequence(
-      absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
+  return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
+                                          mlir::lmhlo::ReplicaIdOp>(input);
+}
 
-  return Status::OK();
+Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
+                                          mlir::lmhlo::PartitionIdOp>(input);
 }
 
 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index d6192b9..9c29bed 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -205,7 +205,12 @@
   Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleAfterAll(HloInstruction* after_all) override;
+
+  template <typename ThunkType, typename OpT>
+  Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
   Status HandleReplicaId(HloInstruction* hlo) override;
+  Status HandlePartitionId(HloInstruction* hlo) override;
+
   Status HandleCollectivePermute(HloInstruction* hlo) override;
 
   Status EmitOp(MlirEmitterInput mlir_input);
diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
index 3d09a19..8f0b2ff 100644
--- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc
@@ -18,11 +18,7 @@
 namespace xla {
 namespace gpu {
 
-ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
-                               const BufferAllocation::Slice& dest)
-    : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
-
-Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
+Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
 
@@ -30,9 +26,10 @@
 
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      params.device_assn->ReplicaIdForDevice(global_device_id));
-  params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
+  TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
+                                            global_device_id));
+  int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
+  params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
index 80aee41..b16a7a1 100644
--- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h
@@ -23,17 +23,32 @@
 namespace xla {
 namespace gpu {
 
-// Thunk that implements the ReplicaId HLO.
-class ReplicaIdThunk : public Thunk {
- public:
-  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
-
+// Thunk that implements the ReplicaId or PartitionId HLO (selected by kind).
+class ReplicaOrPartitionIdThunk : public Thunk {
+ public:
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
+ protected:
+  ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
+                            const BufferAllocation::Slice& dest)
+      : Thunk(kind, thunk_info), dest_(dest) {}
+
  private:
   const BufferAllocation::Slice dest_;
 };
 
+class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
+};
+
+class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
+};
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc
index db4b2ff..5c49b9d 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk.cc
@@ -72,6 +72,8 @@
       return "kOutfeed";
     case Thunk::kReplicaId:
       return "kReplicaId";
+    case Thunk::kPartitionId:
+      return "kPartitionId";
     case Thunk::kSequential:
       return "kSequential";
     case Thunk::kTriangularSolve:
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 9166ee1..7a0e031 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -64,6 +64,7 @@
     kNcclAllToAll,
     kOutfeed,
     kReplicaId,
+    kPartitionId,
     kSequential,
     kTriangularSolve,
     kTuple,