`TPUReshardVariables` is inserted even when there is outside compilation. Before this change, every model containing a `parallel_execute` was skipped by this pass, because `TPUReshardVariables` does not work with model parallelism. Model parallelism detection is now precise: it is detected only when there is more than one `TPUExecuteAndUpdateVariables` op.
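
As a sketch of the new rule (condensed from the tpu_variable_runtime_reformatting.cc change below; the standalone function `GetSingleExecuteLaunch` and the exact include set are illustrative, not the verbatim pass code):

  #include "llvm/Support/Casting.h"
  #include "mlir/IR/Visitors.h"
  #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
  #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

  // Returns the unique launch wrapping a TPUExecuteAndUpdateVariables op,
  // or null. Zero matches: no TPU computation. Two or more matches: model
  // parallelism, which TPUReshardVariables does not support.
  mlir::tf_device::LaunchOp GetSingleExecuteLaunch(
      mlir::tf_device::ReplicateOp replicate) {
    mlir::tf_device::LaunchOp execute_launch;
    replicate.walk([&](mlir::tf_device::LaunchOp launch) {
      if (!launch.WrapsSingleOp() ||
          !llvm::isa<mlir::TF::TPUExecuteAndUpdateVariablesOp>(
              launch.GetBody().front()))
        return mlir::WalkResult::advance();
      if (execute_launch) {
        execute_launch = nullptr;  // Second match: bail out early.
        return mlir::WalkResult::interrupt();
      }
      execute_launch = launch;
      return mlir::WalkResult::advance();
    });
    return execute_launch;
  }

Using replicate.walk (rather than iterating only the immediate body) is what lets the check see `TPUExecuteAndUpdateVariables` ops nested inside a `parallel_execute`.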

PiperOrigin-RevId: 463647960
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
index 4d85456..5594918 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
@@ -160,7 +160,8 @@
 // -----
 
 // Tests that the pass does not format variables when model parallelism is
-// present.
+// present. Model parallelism is detected when there is more than one
+// TPUExecuteAndUpdateVariables op in a parallel_execute.
 
 !tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
 !tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
@@ -168,7 +169,8 @@
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
   // CHECK-LABEL: func @main
   // CHECK-NOT: TPUReshardVariables
-  func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
+  func.func @main(
+             %arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
              %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
              %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
              %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
@@ -197,7 +199,8 @@
             "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
             tf_device.return
           }) {device = "/device:CPU:0"} : () -> ()
-          %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
+          %rep:2 = tf_device.replicate(
+                              [%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
                               [%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
                   {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
             %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
@@ -210,13 +213,120 @@
               }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
               tf_device.return
             }, {
+              "tf_device.launch"() ({
+                "tf.TPUExecuteAndUpdateVariables"(%compile#1)
+                      {device_var_reads_indices = [], device_var_updates_indices = []}
+                        : (tensor<2x!tf_type.string>) -> ()
+                tf_device.return
+              }) {device = "TPU_REPLICATED_CORE_1"} : () -> ()
+              tf_device.return
+            }) {} : () -> ()
+            %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+            tf_device.return %ret : tensor<i32>
+          }
+          "tf.Yield"(%b1) :  (tensor<i32>) -> ()
+      }) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
+    func.return
+  }
+}
+
+// -----
+
+// Tests that the pass formats variables when there is a parallel_execute
+// but only one TPUExecuteAndUpdateVariables. This can happen when there is
+// outside compilation and no model parallelism.
+
+!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
+!tf_res_md_f32 = tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+  // CHECK-LABEL: func @main
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
+  // CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf_type.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
+  // CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
+  // CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
+  func.func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
+             %arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
+             %arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
+             %arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
+
+    %0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
+    // CHECK-SAME: device = "/device:TPU:0"
+    // CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
+    // CHECK-SAME: device = "/device:TPU:1"
+    // CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
+    %1 = "tf.WhileRegion"(%0) ({
+       // Condition region
+       // CHECK: ^bb
+       // CHECK: "tf.Yield"
+       ^bb0(%carg0: tensor<i32>):
+          %c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+          %c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+          "tf.Yield"(%c1) : (tensor<i1>) -> ()
+      }, {
+       // Body region
+       // CHECK: ^bb0
+       ^bb0(%barg0: tensor<i32>):
+          %b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+          %b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+          // CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
+          // CHECK-NEXT: "tf._TPUCompileMlir"()
+          %compile:2 = "tf_device.launch"() ({
+            %b2:2 = "tf._TPUCompileMlir"() {
+              NumDynamicShapes = 0 : i64,
+              // The metadata encodes 2 parameters and 2 return values.
+              metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+              mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
+            tf_device.return %b2#0, %b2#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
+          }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
+          "tf_device.launch"() ({
+            "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf_type.string>) -> ()
+            tf_device.return
+          }) {device = "/device:CPU:0"} : () -> ()
+          // CHECK: tf_device.replicate
+          // CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
+          // CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
+          // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
+          // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
+          %rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf_type.resource<tensor<f32>>>,
+                              [%arg2, %arg3] as %arg31: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>)
+                  {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
+            // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
+            %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
+            // CHECK: "tf_device.parallel_execute"
+            // CHECK: "tf_device.launch"
+            // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
+            // CHECK-NEXT: tf_device.return
+            // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
+            // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
+            "tf_device.parallel_execute"() ({
+              "tf_device.launch"() ({
+                "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
+                      {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
+                        : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf_type.string>) -> ()
+                tf_device.return
+              }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
+              tf_device.return
+            }, {
               tf_device.return
             }) {} : () -> ()
             %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
             tf_device.return %ret : tensor<i32>
           }
+          // CHECK: "tf.Yield"
           "tf.Yield"(%b1) :  (tensor<i32>) -> ()
       }) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
+    // CHECK: %[[DEFAULT:.*]] = "tf.Const"()
+    // CHECK:  tf_device.replicate
+    // CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf_type.resource<tensor<f32>>>,
+    // CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>,
+    // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
+    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
+    // CHECK: "tf_device.launch"
+    // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
+    // CHECK-NEXT: tf_device.return
+    // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
     func.return
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 08dd3e0..462b93e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -336,23 +336,26 @@
                        tf_device::ReplicateOp replicate) {
   int64_t num_replicas = replicate.n();
   if (num_replicas == 1) return;
+
+  // Set execute_launch when there is exactly one LaunchOp that wraps a
+  // TPUExecuteAndUpdateVariablesOp. More than one such launch means model
+  // parallelism, which TPUReshardVariables does not support; none means
+  // there is no TPU computation.
   tf_device::LaunchOp execute_launch;
-  for (auto execute_launch_op :
-       replicate.GetBody().getOps<tf_device::LaunchOp>()) {
+  replicate.walk([&](tf_device::LaunchOp execute_launch_op) {
     if (!execute_launch_op.WrapsSingleOp() ||
         !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
             execute_launch_op.GetBody().front()))
-      continue;
-
+      return WalkResult::advance();
     if (execute_launch == nullptr) {
       execute_launch = execute_launch_op;
-    } else {
-      // We only support one execute op inside replicate.
-      execute_launch = nullptr;
-      break;
+      return WalkResult::advance();
     }
-  }
+    execute_launch = nullptr;
+    return WalkResult::interrupt();
+  });
   if (!execute_launch) return;
+
   auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
       execute_launch.GetBody().front());
   auto compile =
@@ -485,11 +488,7 @@
       replicate = nullptr;
       return WalkResult::interrupt();
     });
-    // Model parallelism is not supported, and can be detected when a
-    // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
-    if (replicate &&
-        replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
-      HandleReplicateOp(while_op, replicate);
+    if (replicate) HandleReplicateOp(while_op, replicate);
   });
 }
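
The "exactly one match, else bail early" shape of the detection is worth calling out on its own. Here is the same pattern distilled into a self-contained sketch in plain C++ (no MLIR dependency; all names here are illustrative):

  #include <optional>
  #include <vector>

  // Returns the unique element satisfying `pred`, or std::nullopt when
  // there are zero or multiple matches. A second match ends the scan
  // immediately, mirroring WalkResult::interrupt() in the pass.
  template <typename T, typename Pred>
  std::optional<T> FindUnique(const std::vector<T>& items, Pred pred) {
    std::optional<T> found;
    for (const T& item : items) {
      if (!pred(item)) continue;
      if (found) return std::nullopt;  // More than one match: give up.
      found = item;
    }
    return found;
  }

  // Example: FindUnique(launches, is_execute_launch) stands in for the
  // single-TPUExecuteAndUpdateVariables check performed by the pass.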