Sort `tf.TPUReplicatedInput` ops by index attribute when forming `tf_device.replicate` in TPU cluster formation pass.
Padding map relies on index set in `tf.TPUReplicatedInput` ops. To make non negative indices ordering be deterministic, the ops can be sorted by index when added to the `tf_device.replicate` op. Ordering of `tf.TPUReplicatedInput` with an index not set (default -1) can be ignored.
PiperOrigin-RevId: 284790304
Change-Id: I830a77c03273cbe7c2cebc21762579c69478f4d6
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
index e2c81c9..86e6f1b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
@@ -344,7 +344,8 @@
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
-// CHECK-SAME: ([%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>, [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>)
+// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
+// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
// CHECK-SAME: n = 2 : i32
// CHECK-NEXT: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
@@ -357,6 +358,32 @@
// CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3
+// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute.
+// Non-negative `index` should preceed `index` of -1, and ordering of ops with
+// `index` of -1 does not matter.
+// CHECK-LABEL: func @sort_replicated_input
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>)
+func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>) {
+ %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %5 = "tf.TPUReplicatedInput"(%arg5, %arg5) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ "tf.opA"(%0, %1, %2, %3, %4, %5) {_tpu_replicate = "replicate", device = "device"} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
+ "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+ return
+}
+
+// CHECK: tf_device.replicate
+// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
+// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
+// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
+// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
+
+
// -----
@@ -441,3 +468,44 @@
%0:2 = "tf.TPUReplicatedOutput"(%arg0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>)
return
}
+
+
+// -----
+
+
+// Test bad TPUReplicatedInput positive `index` attribute.
+func @bad_positive_index_input(%arg0: tensor<i1>) {
+ // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
+ %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+ "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+ return
+}
+
+
+// -----
+
+
+// Test bad TPUReplicatedInput negative `index` attribute.
+func @bad_negative_index_input(%arg0: tensor<i1>) {
+ // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
+ %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+ "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+ return
+}
+
+
+// -----
+
+
+// Test TPUReplicatedInput with conflicting `index` attribute. This will result
+// in gaps in the TPUReplicatedInput ordering.
+func @input_index_gaps(%arg0: tensor<i1>) {
+ // expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
+ %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+ "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
+ "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+ return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 3fb311f..7ac179d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -59,6 +59,7 @@
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
+constexpr char kIndexAttr[] = "index";
constexpr char kBadTPUReplicateAttrMsg[] =
"requires '_tpu_replicate' string attribute";
@@ -259,6 +260,40 @@
for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op);
}
+// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
+// of -1 are always after ops with a non negative `index`, and an arbitrary
+// ordering is used as there are no dependencies on their relative ordering.
+LogicalResult SortTPUReplicatedInputsByIndex(
+ llvm::ArrayRef<Operation*> inputs,
+ llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
+ const int input_size = inputs.size();
+ sorted_inputs->resize(input_size, nullptr);
+ int last_index = input_size - 1;
+
+ for (Operation* input : inputs) {
+ int64_t index = -1;
+ if (auto index_attr = input->getAttrOfType<IntegerAttr>(kIndexAttr))
+ index = index_attr.getInt();
+
+ if (index >= input_size || index < -1)
+ return input->emitError() << "'" << input->getName().getStringRef()
+ << "' index is not in range [-1, " << input_size
+ << "), got " << index;
+
+ if (index == -1)
+ (*sorted_inputs)[last_index--] = input;
+ else
+ (*sorted_inputs)[index] = input;
+ }
+
+ if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
+ return inputs.front()->emitError()
+ << "failed to sort '" << inputs.front()->getName().getStringRef()
+ << "' ops, gap(s) found in indices";
+
+ return success();
+}
+
// Creates a `tf_device.replicate` to represent replication for the cluster, if
// necessary.
LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
@@ -270,14 +305,18 @@
return launch_op.emitError() << "requires '" << kNumReplicasAttr
<< "' int attribute to be at least 1";
- // Collect all used TPUReplicatedInput ops.
- llvm::SmallSetVector<Operation*, 8> replicated_input_ops;
+ // Collect all used TPUReplicatedInput ops and sort by `index`.
+ llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
mlir::visitUsedValuesDefinedAbove(
launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
Operation* def = operand->get()->getDefiningOp();
if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
- replicated_input_ops.insert(def);
+ unique_replicated_input_ops.insert(def);
});
+ llvm::SmallVector<Operation*, 8> replicated_input_ops;
+ if (failed(SortTPUReplicatedInputsByIndex(
+ unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
+ return failure();
// Check if number of operands of each used TPUReplicatedInput op matches
// `num_replicas`. Collect all their operands and associated type for creating