Add lowering for tf.XlaReplicaId
PiperOrigin-RevId: 409003187
Change-Id: I1eba79fb8ea4c47244ab6ef25467b6fbf292e753
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 107f027..b70a4e8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -383,7 +383,6 @@
TypeID::get<TF::XdivyOp>(),
TypeID::get<TF::XlaAllReduceOp>(),
TypeID::get<TF::XlaGatherOp>(),
- TypeID::get<TF::XlaReplicaIdOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::ZerosLikeOp>(),
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc
index a03e5bc..76907a5 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc
@@ -167,6 +167,7 @@
TypeID::get<TF::XlaDotV2Op>(),
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaEinsumOp>(),
+ TypeID::get<TF::XlaReplicaIdOp>(),
TypeID::get<TF::XlaSortOp>(),
TypeID::get<TF::XlogyOp>(),
TypeID::get<TF::ZetaOp>(),