[MLIR] Fix incorrect tfl.while canonicalization

- Move operand -> result forwarding to the point where the operand is known
  to be a pass-through operand

PiperOrigin-RevId: 319310847
Change-Id: If29e9051b19866abcf089bde20bae83e0a7e09b8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index d5f3cf6..853c641 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -2213,8 +2213,7 @@
 
   LogicalResult matchAndRewrite(WhileOp while_op,
                                 PatternRewriter &rewriter) const override {
-    // Replace values simply passed through the body with extern values
-    // (in both body and condition regions as well as while result). The
+    // Replace values simply passed through the body with extern values. The
     // block arguments of body and while match and so the corresponding cond
     // argument can be easily found.
     bool unchanged = true;
@@ -2222,23 +2221,18 @@
     auto &cond_block = while_op.cond().front();
     auto &yield = *body_block.getTerminator();
     for (auto ba : body_block.getArguments()) {
-      int arg_no = ba.getArgNumber();
-      if (ba == yield.getOperand(arg_no)) {
+      if (ba == yield.getOperand(ba.getArgNumber())) {
         unchanged = false;
-        auto value = while_op.getOperand(arg_no);
+        auto value = while_op.getOperand(ba.getArgNumber());
         ba.replaceAllUsesWith(value);
-        cond_block.getArgument(arg_no).replaceAllUsesWith(value);
-
-        // This could be relaxed and casts inserted.
-        if (while_op.getResult(arg_no).getType() == value.getType())
-          while_op.getResult(arg_no).replaceAllUsesWith(value);
+        cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
       }
     }
 
     // The While ops operands and result types need to match
     SmallVector<Value, 4> new_operands;
     SmallVector<Value, 4> new_body_yield;
-    SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
+    SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
     llvm::SmallVector<Type, 4> types;
     new_operands.reserve(while_op.getNumOperands());
     new_body_yield.reserve(while_op.getNumOperands());
@@ -2252,13 +2246,15 @@
       auto value = while_op.getOperand(while_index);
       if (body_block.getArgument(arg_index).use_empty() &&
           cond_block.getArgument(arg_index).use_empty() &&
-          while_op.getResult(arg_index).use_empty()) {
+          // This could be relaxed and casts inserted.
+          while_op.getResult(while_index).getType() == value.getType()) {
         unchanged = false;
         body_block.eraseArgument(arg_index);
         cond_block.eraseArgument(arg_index);
 
-        // Mark operand for removal.
-        removed_operand[while_index] = true;
+        // Mark operand as constant and replace all uses with input to while.
+        while_op.getResult(while_index).replaceAllUsesWith(value);
+        const_operand[while_index] = true;
       } else {
         new_operands.push_back(value);
         new_body_yield.push_back(yield.getOperand(while_index));
@@ -2280,7 +2276,7 @@
     for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
     int new_index = 0;
     for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
-      if (removed_operand[op_index]) continue;
+      if (const_operand[op_index]) continue;
       op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
       ++new_index;
     }
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index c95d37b..b9a24a6 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -111,27 +111,3 @@
 // CHECK:  [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
 // CHECK:  [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
 }
-
-// -----
-
-// CHECK-LABEL: @WhileCanonicalizeBug
-// Make sure that second output of the tf.while is not incorrectly inferred as
-// pass through just because the corresponding input is not used in either
-// condition or body. The tensor<f32> result of the loop can be either %arg1
-// (if the body never executes, or 22.0 if the body executes atleast once).
-func @WhileCanonicalizeBug(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
-  %0:2 = "tfl.while"(%arg0, %arg1) ( {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
-    %limit = constant dense<100> : tensor<i32>
-    %test = "tfl.less"(%arg0, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "tfl.yield"(%test) : (tensor<i1>) -> ()
-  },  {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
-    %cst = constant dense<22.0> : tensor<f32>
-    %stride = constant dense<1> : tensor<i32>
-    %inc = tfl.add %arg2, %stride {fused_activation_function = "NONE"} : tensor<i32>
-    "tfl.yield"(%inc, %cst) : (tensor<i32>, tensor<f32>) -> ()
-  }) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
-  // CHECK: return %0#1 : tensor<f32>
-  return %0#1 : tensor<f32>
-}