Sort `tf.TPUReplicatedInput` ops by index attribute when forming `tf_device.replicate` in TPU cluster formation pass.

The padding map relies on the index set in `tf.TPUReplicatedInput` ops. To make the ordering of ops with non-negative indices deterministic, the ops are sorted by index when added to the `tf_device.replicate` op. The relative ordering of `tf.TPUReplicatedInput` ops whose index is not set (default -1) can be ignored.

PiperOrigin-RevId: 284790304
Change-Id: I830a77c03273cbe7c2cebc21762579c69478f4d6
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
index e2c81c9..86e6f1b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
@@ -344,7 +344,8 @@
 // CHECK:      %[[OP_B:[0-9]*]] = "tf.opB"
 // CHECK:      %[[OP_C:[0-9]*]] = "tf.opC"
 // CHECK:      %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
-// CHECK-SAME: ([%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>, [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>)
+// CHECK-DAG:  [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
+// CHECK-DAG:  [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
 // CHECK-SAME: n = 2 : i32
 // CHECK-NEXT:   %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
 // CHECK:          %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
@@ -357,6 +358,32 @@
 // CHECK:      return %[[REPLICATE]]#0, %[[REPLICATE]]#3
 
 
+// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute.
+// Non-negative `index` should precede `index` of -1, and ordering of ops with
+// `index` of -1 does not matter.
+// CHECK-LABEL: func @sort_replicated_input
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>)
+func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>) {
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %5 = "tf.TPUReplicatedInput"(%arg5, %arg5) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opA"(%0, %1, %2, %3, %4, %5) {_tpu_replicate = "replicate", device = "device"} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK:      tf_device.replicate
+// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
+// CHECK-DAG:  [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
+// CHECK-DAG:  [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
+// CHECK-DAG:  [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
+
+
 // -----
 
 
@@ -441,3 +468,44 @@
   %0:2 = "tf.TPUReplicatedOutput"(%arg0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>)
   return
 }
+
+
+// -----
+
+
+// Test bad TPUReplicatedInput positive `index` attribute.
+func @bad_positive_index_input(%arg0: tensor<i1>) {
+  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+
+// -----
+
+
+// Test bad TPUReplicatedInput negative `index` attribute.
+func @bad_negative_index_input(%arg0: tensor<i1>) {
+  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+
+// -----
+
+
+// Test TPUReplicatedInput with conflicting `index` attribute. This will result
+// in gaps in the TPUReplicatedInput ordering.
+func @input_index_gaps(%arg0: tensor<i1>) {
+  // expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 3fb311f..7ac179d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -59,6 +59,7 @@
 constexpr char kDeviceAttr[] = "device";
 constexpr char kNameAttr[] = "name";
 constexpr char kNumReplicasAttr[] = "num_replicas";
+constexpr char kIndexAttr[] = "index";
 
 constexpr char kBadTPUReplicateAttrMsg[] =
     "requires '_tpu_replicate' string attribute";
@@ -259,6 +260,40 @@
   for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op);
 }
 
+// Sorts `tf.TPUReplicatedInput` ops by `index`: op with `index` i >= 0 goes to
+// slot i; ops with no `index` or -1 fill the tail in arbitrary order. Fails if
+// an `index` is outside [-1, inputs.size()) or a slot is left unfilled.
+LogicalResult SortTPUReplicatedInputsByIndex(
+    llvm::ArrayRef<Operation*> inputs,
+    llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
+  const int input_size = inputs.size();
+  sorted_inputs->resize(input_size, nullptr);  // nullptr = unfilled slot
+  int last_index = input_size - 1;  // next tail slot for an `index` of -1
+
+  for (Operation* input : inputs) {
+    int64_t index = -1;  // used when the `index` attribute is absent
+    if (auto index_attr = input->getAttrOfType<IntegerAttr>(kIndexAttr))
+      index = index_attr.getInt();
+
+    if (index >= input_size || index < -1)
+      return input->emitError() << "'" << input->getName().getStringRef()
+                                << "' index is not in range [-1, " << input_size
+                                << "), got " << index;
+
+    if (index == -1)
+      (*sorted_inputs)[last_index--] = input;
+    else
+      (*sorted_inputs)[index] = input;
+  }
+
+  if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
+    return inputs.front()->emitError()
+           << "failed to sort '" << inputs.front()->getName().getStringRef()
+           << "' ops, gap(s) found in indices";
+
+  return success();
+}
+
 // Creates a `tf_device.replicate` to represent replication for the cluster, if
 // necessary.
 LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
@@ -270,14 +305,18 @@
     return launch_op.emitError() << "requires '" << kNumReplicasAttr
                                  << "' int attribute to be at least 1";
 
-  // Collect all used TPUReplicatedInput ops.
-  llvm::SmallSetVector<Operation*, 8> replicated_input_ops;
+  // Collect all used TPUReplicatedInput ops and sort by `index`.
+  llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
   mlir::visitUsedValuesDefinedAbove(
       launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
         Operation* def = operand->get()->getDefiningOp();
         if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
-          replicated_input_ops.insert(def);
+          unique_replicated_input_ops.insert(def);
       });
+  llvm::SmallVector<Operation*, 8> replicated_input_ops;
+  if (failed(SortTPUReplicatedInputsByIndex(
+          unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
+    return failure();
 
   // Check if number of operands of each used TPUReplicatedInput op matches
   // `num_replicas`. Collect all their operands and associated type for creating