Eliminate explicitly captured/returned pass through resource operands of WhileRegion.

PiperOrigin-RevId: 363217235
Change-Id: I638c5aecba05da551fff6b182c1afdefebe32ac9
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index c2f67df..65215f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -858,6 +858,9 @@
 // -----
 
 // Test that the pass can lift resources out of WhileRegion
+
+!tf_ref = type tensor<*x!tf.resource<tensor<f32>>>
+
 // CHECK-LABEL: func @cluster_with_whileregion
 func @cluster_with_whileregion() -> () {
   // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
@@ -866,16 +869,17 @@
   // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"()
   // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]])
   %0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
-  %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_ref
+  %pass_through = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> !tf_ref
+  %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v3"} : () -> !tf_ref
   "tf_device.cluster"() ( {
-    %2:3 = "tf.WhileRegion"(%0, %1, %unused) ({
+    %2:4 = "tf.WhileRegion"(%0, %1, %pass_through, %unused) ({
             // CHECK: (%[[CARG0:.+]]: tensor<i32>, %[[CARG1:.+]]: tensor<f32>):
             // CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]])
             // CHECK: "tf.Less"(%[[CARG0]], %[[CAST]])
             // CHECK: "tf.Yield"
-            ^bb0(%carg0: tensor<i32>, %carg1:tensor<*x!tf.resource<tensor<f32>>>, %carg2: tensor<*x!tf.resource<tensor<f32>>>):
-               %read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+            ^bb0(%carg0: tensor<i32>, %carg1: !tf_ref, %carg2: !tf_ref, %carg3: !tf_ref):
+               %read0 = "tf.ReadVariableOp"(%carg1) : (!tf_ref) -> tensor<f32>
                %cast = "tf.Cast"(%read0) : (tensor<f32>) -> tensor<i32>
                %cond = "tf.Less"(%carg0, %cast) : (tensor<i32>, tensor<i32>) -> tensor<i1>
                "tf.Yield"(%cond) : (tensor<i1>) -> ()
@@ -886,20 +890,20 @@
             // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
             // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
             // CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]])
-            ^bb1(%barg0: tensor<i32>, %barg1:tensor<*x!tf.resource<tensor<f32>>>, %barg2: tensor<*x!tf.resource<tensor<f32>>>):
-              %read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+            ^bb1(%barg0: tensor<i32>, %barg1: !tf_ref, %barg2: !tf_ref, %barg3: !tf_ref):
+              %read0 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
               %add0 = "tf.AddV2"(%read0, %read0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-              "tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
-              %read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+              "tf.AssignVariableOp"(%barg1, %add0) : (!tf_ref, tensor<f32>) -> ()
+              %read1 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
               %add1 = "tf.AddV2"(%read1, %read1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-              "tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+              "tf.AssignVariableOp"(%barg1, %add1) : (!tf_ref, tensor<f32>) -> ()
               %constant = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
               %add2 = "tf.AddV2"(%barg0, %constant) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-              %id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
-              "tf.Yield"(%add2, %barg1, %id) : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) -> ()
+              %id = "tf.Identity"(%barg3) : (!tf_ref) -> !tf_ref
+              "tf.Yield"(%add2, %barg1, %pass_through, %id) : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref) -> ()
             }) {device = "", is_stateless = false}
-         : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-         -> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
+         : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
+         -> (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
     tf_device.return
   }) {cluster_attr = "cluster_attr"} : () -> ()
   // CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
index 14e5d2b..b89c873 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
@@ -377,20 +377,21 @@
   for (OpResult result : llvm::reverse(op.getResults())) {
     if (!IsResource(result)) continue;
     int result_idx = result.getResultNumber();
-    auto body_arg = body.front()
-                        .getTerminator()
-                        ->getOperand(result_idx)
-                        .dyn_cast<BlockArgument>();
-    if (!body_arg || body_arg.getArgNumber() != result_idx) {
+    Operation *yield_op = body.front().getTerminator();
+    Value yield_operand = yield_op->getOperand(result_idx);
+    Value while_operand = op.getOperand(result_idx);
+    Value body_arg = body.getArgument(result_idx);
+    Value cond_arg = cond.getArgument(result_idx);
+    if (yield_operand != body_arg && yield_operand != while_operand) {
       return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
                                         << result_idx << " of the body";
     }
-    body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
-    cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
+    body_arg.replaceAllUsesWith(while_operand);
+    cond_arg.replaceAllUsesWith(while_operand);
+    result.replaceAllUsesWith(while_operand);
     body.front().getTerminator()->eraseOperand(result_idx);
     body.eraseArgument(result_idx);
     cond.eraseArgument(result_idx);
-    result.replaceAllUsesWith(op.getOperand(result_idx));
     op.getOperation()->eraseOperand(result_idx);
     can_eliminate.set(result_idx);
   }