[MLIR] Fix incorrect tfl.while canonicalization
- Move operand -> result forwarding to the correct place where we know we
are dealing with a pass through operand
PiperOrigin-RevId: 319310847
Change-Id: If29e9051b19866abcf089bde20bae83e0a7e09b8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index d5f3cf6..853c641 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -2213,8 +2213,7 @@
LogicalResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
- // Replace values simply passed through the body with extern values
- // (in both body and condition regions as well as while result). The
+ // Replace values simply passed through the body with extern values. The
// block arguments of body and while match and so the corresponding cond
// argument can be easily found.
bool unchanged = true;
@@ -2222,23 +2221,18 @@
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
for (auto ba : body_block.getArguments()) {
- int arg_no = ba.getArgNumber();
- if (ba == yield.getOperand(arg_no)) {
+ if (ba == yield.getOperand(ba.getArgNumber())) {
unchanged = false;
- auto value = while_op.getOperand(arg_no);
+ auto value = while_op.getOperand(ba.getArgNumber());
ba.replaceAllUsesWith(value);
- cond_block.getArgument(arg_no).replaceAllUsesWith(value);
-
- // This could be relaxed and casts inserted.
- if (while_op.getResult(arg_no).getType() == value.getType())
- while_op.getResult(arg_no).replaceAllUsesWith(value);
+ cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
}
}
// The While ops operands and result types need to match
SmallVector<Value, 4> new_operands;
SmallVector<Value, 4> new_body_yield;
- SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
+ SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
llvm::SmallVector<Type, 4> types;
new_operands.reserve(while_op.getNumOperands());
new_body_yield.reserve(while_op.getNumOperands());
@@ -2252,13 +2246,15 @@
auto value = while_op.getOperand(while_index);
if (body_block.getArgument(arg_index).use_empty() &&
cond_block.getArgument(arg_index).use_empty() &&
- while_op.getResult(arg_index).use_empty()) {
+ // This could be relaxed and casts inserted.
+ while_op.getResult(while_index).getType() == value.getType()) {
unchanged = false;
body_block.eraseArgument(arg_index);
cond_block.eraseArgument(arg_index);
- // Mark operand for removal.
- removed_operand[while_index] = true;
+ // Mark operand as constant and replace all uses with input to while.
+ while_op.getResult(while_index).replaceAllUsesWith(value);
+ const_operand[while_index] = true;
} else {
new_operands.push_back(value);
new_body_yield.push_back(yield.getOperand(while_index));
@@ -2280,7 +2276,7 @@
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
int new_index = 0;
for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
- if (removed_operand[op_index]) continue;
+ if (const_operand[op_index]) continue;
op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
++new_index;
}
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index c95d37b..b9a24a6 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -111,27 +111,3 @@
// CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
// CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
}
-
-// -----
-
-// CHECK-LABEL: @WhileCanonicalizeBug
-// Make sure that second output of the tf.while is not incorrectly inferred as
-// pass through just because the corresponding input is not used in either
-// condition or body. The tensor<f32> result of the loop can be either %arg1
-// (if the body never executes, or 22.0 if the body executes atleast once).
-func @WhileCanonicalizeBug(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
- %0:2 = "tfl.while"(%arg0, %arg1) ( {
- ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
- %limit = constant dense<100> : tensor<i32>
- %test = "tfl.less"(%arg0, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
- "tfl.yield"(%test) : (tensor<i1>) -> ()
- }, {
- ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
- %cst = constant dense<22.0> : tensor<f32>
- %stride = constant dense<1> : tensor<i32>
- %inc = tfl.add %arg2, %stride {fused_activation_function = "NONE"} : tensor<i32>
- "tfl.yield"(%inc, %cst) : (tensor<i32>, tensor<f32>) -> ()
- }) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
- // CHECK: return %0#1 : tensor<f32>
- return %0#1 : tensor<f32>
-}