Eliminate explicitly captured/returned pass through resource operands of WhileRegion.
PiperOrigin-RevId: 363217235
Change-Id: I638c5aecba05da551fff6b182c1afdefebe32ac9
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index c2f67df..65215f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -858,6 +858,9 @@
// -----
// Test that the pass can lift resources out of WhileRegion
+
+!tf_ref = type tensor<*x!tf.resource<tensor<f32>>>
+
// CHECK-LABEL: func @cluster_with_whileregion
func @cluster_with_whileregion() -> () {
// CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
@@ -866,16 +869,17 @@
// CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"()
// CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]])
%0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
- %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_ref
+ %pass_through = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> !tf_ref
+ %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v3"} : () -> !tf_ref
"tf_device.cluster"() ( {
- %2:3 = "tf.WhileRegion"(%0, %1, %unused) ({
+ %2:4 = "tf.WhileRegion"(%0, %1, %pass_through, %unused) ({
// CHECK: (%[[CARG0:.+]]: tensor<i32>, %[[CARG1:.+]]: tensor<f32>):
// CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]])
// CHECK: "tf.Less"(%[[CARG0]], %[[CAST]])
// CHECK: "tf.Yield"
- ^bb0(%carg0: tensor<i32>, %carg1:tensor<*x!tf.resource<tensor<f32>>>, %carg2: tensor<*x!tf.resource<tensor<f32>>>):
- %read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+ ^bb0(%carg0: tensor<i32>, %carg1: !tf_ref, %carg2: !tf_ref, %carg3: !tf_ref):
+ %read0 = "tf.ReadVariableOp"(%carg1) : (!tf_ref) -> tensor<f32>
%cast = "tf.Cast"(%read0) : (tensor<f32>) -> tensor<i32>
%cond = "tf.Less"(%carg0, %cast) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
@@ -886,20 +890,20 @@
// CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
// CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
// CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]])
- ^bb1(%barg0: tensor<i32>, %barg1:tensor<*x!tf.resource<tensor<f32>>>, %barg2: tensor<*x!tf.resource<tensor<f32>>>):
- %read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+ ^bb1(%barg0: tensor<i32>, %barg1: !tf_ref, %barg2: !tf_ref, %barg3: !tf_ref):
+ %read0 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
%add0 = "tf.AddV2"(%read0, %read0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
- "tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
- %read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+ "tf.AssignVariableOp"(%barg1, %add0) : (!tf_ref, tensor<f32>) -> ()
+ %read1 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
%add1 = "tf.AddV2"(%read1, %read1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
- "tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+ "tf.AssignVariableOp"(%barg1, %add1) : (!tf_ref, tensor<f32>) -> ()
%constant = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%add2 = "tf.AddV2"(%barg0, %constant) : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
- "tf.Yield"(%add2, %barg1, %id) : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) -> ()
+ %id = "tf.Identity"(%barg3) : (!tf_ref) -> !tf_ref
+ "tf.Yield"(%add2, %barg1, %pass_through, %id) : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref) -> ()
}) {device = "", is_stateless = false}
- : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
- -> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
+ : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
+ -> (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
// CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
index 14e5d2b..b89c873 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc
@@ -377,20 +377,21 @@
for (OpResult result : llvm::reverse(op.getResults())) {
if (!IsResource(result)) continue;
int result_idx = result.getResultNumber();
- auto body_arg = body.front()
- .getTerminator()
- ->getOperand(result_idx)
- .dyn_cast<BlockArgument>();
- if (!body_arg || body_arg.getArgNumber() != result_idx) {
+ Operation *yield_op = body.front().getTerminator();
+ Value yield_operand = yield_op->getOperand(result_idx);
+ Value while_operand = op.getOperand(result_idx);
+ Value body_arg = body.getArgument(result_idx);
+ Value cond_arg = cond.getArgument(result_idx);
+ if (yield_operand != body_arg && yield_operand != while_operand) {
return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
<< result_idx << " of the body";
}
- body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
- cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
+ body_arg.replaceAllUsesWith(while_operand);
+ cond_arg.replaceAllUsesWith(while_operand);
+ result.replaceAllUsesWith(while_operand);
body.front().getTerminator()->eraseOperand(result_idx);
body.eraseArgument(result_idx);
cond.eraseArgument(result_idx);
- result.replaceAllUsesWith(op.getOperand(result_idx));
op.getOperation()->eraseOperand(result_idx);
can_eliminate.set(result_idx);
}