Fix a flaky mlir test.
PiperOrigin-RevId: 354570431
Change-Id: I0cde67a9487888e7e285462b17fe371f2231a736
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
index fd3519a..2ae5ebc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
@@ -9,7 +9,7 @@
// CHECK-SAME: body = @while_body
// CHECK-SAME: cond = @while_cond
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.remote_run "/job:worker/replica:0/task:1" @_job_worker_replica_0_task_1(%[[ARG_1]])
+// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.remote_run "/job:worker/replica:0/task:1" @[[MAIN_PARTITION_0:.*]](%[[ARG_1]])
// CHECK-NEXT: return %[[RESULT_0]], %[[RESULT_1]]
func @main(%arg0: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"}) -> (tensor<i32>, tensor<i32>) {
%1 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant, device="/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>) -> (tensor<i32>)
@@ -40,9 +40,9 @@
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Const"() {value = dense<16> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_device.send %[[RESULT_2]] "key-0" "/job:worker/replica:0/task:1/device:CPU:0"
// CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-// CHECK-NEXT: tf_device.remote_run "/job:worker/replica:0/task:1" @_job_worker_replica_0_task_1_0()
+// CHECK-NEXT: tf_device.remote_run "/job:worker/replica:0/task:1" @[[BODY_PARTITION_0:.*]]() : () -> ()
// CHECK-NEXT: tf_device.send %[[RESULT_2]]
-// CHECK-NEXT: tf_device.remote_run "/job:worker/replica:0/task:2" @_job_worker_replica_0_task_2()
+// CHECK-NEXT: tf_device.remote_run "/job:worker/replica:0/task:2" @[[BODY_PARTITION_1:.*]]() : () -> ()
// TODO(tf-runtime): Allow while body having remote inputs and outputs.
func @while_body(%arg0: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<i32>) {
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
@@ -59,20 +59,20 @@
return %1 : tensor<i32>
}
// Subgraph of @main function that is placed on worker:1
-// CHECK: func @_job_worker_replica_0_task_1(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
+// CHECK: func @[[MAIN_PARTITION_0]](%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
// CHECK-SAME: host = "/job:worker/replica:0/task:1"
// CHECK-NEXT: %[[RESULT_0:.*]] = "tf.AddV2"(%[[ARG_0]], %[[ARG_0]])
// CHECK-SAME: device = "/job:worker/replica:0/task:1/device:CPU:0"
// CHECK-NEXT: return %[[RESULT_0]]
// Subgraph of @while_body function that is placed on worker:1
-// CHECK: func @_job_worker_replica_0_task_1_0() attributes {host = "/job:worker/replica:0/task:1"}
+// CHECK: func @[[BODY_PARTITION_0]]() attributes {host = "/job:worker/replica:0/task:1"}
// CHECK-NEXT: %[[RESULT_0:.*]] = tf_device.receive "key-0"
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.AddV2"(%[[RESULT_0]], %[[RESULT_0]])
// CHECK-SAME: device = "/job:worker/replica:0/task:1/device:CPU:0"
// Subgraph of @while_body function that is placed on worker:2
-// CHECK: func @_job_worker_replica_0_task_2() attributes {host = "/job:worker/replica:0/task:2"}
+// CHECK: func @[[BODY_PARTITION_1]]() attributes {host = "/job:worker/replica:0/task:2"}
// CHECK-NEXT: %[[RESULT_0:.*]] = tf_device.receive "key-1"
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.AddV2"(%[[RESULT_0]], %[[RESULT_0]])
// CHECK-SAME: device = "/job:worker/replica:0/task:2/device:CPU:0"