Update lower static tensor list pass with the following changes:
* Don't reorder FuncOps manually (dialect conversion can automatically handle dependencies between SSA values).
* Apply conversion on the whole ModuleOp
* Rely on dialect conversion to automatically roll back changes to IR in case of legalization failure.

PiperOrigin-RevId: 354364707
Change-Id: Iac586eba16fde8c25d0b7da9484f747a1a3874e3
diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index 9ae152f..7adc6ef 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -1,4 +1,6 @@
-// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
+// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
+
+// -----
 
 // CHECK-LABEL: tensorlistConst
 func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@@ -431,3 +433,29 @@
 // CHECK:  [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
 // CHECK:  return [[RESULT]] : tensor<?x10xf32>
 // CHECK:  }
+
+// -----
+
+// CHECK-LABEL: tensorlistReserveWithDynamicShape
+func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
+  %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
+  %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+
+// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
+// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK: return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: tensorlistConcat
+func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
+  %list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
+  %t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
+  return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
+
+// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
+// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
+// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 400df2a..7194e73 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -58,6 +58,7 @@
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -72,20 +73,21 @@
 namespace mlir {
 namespace {
 
-class TensorListPatternRewriter : public PatternRewriter {
- public:
-  explicit TensorListPatternRewriter(FuncOp fn)
-      : PatternRewriter(fn.getContext()) {}
-};
-
 /// Lower TensorList ops in functions for subsequent legalization.
 struct LowerStaticTensorListPass
     : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
+  LowerStaticTensorListPass() = default;
+  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
+
   void runOnOperation() override;
 
-  // Apply type and op changes within a function.
-  LogicalResult RewriteFunction(FuncOp func,
-                                TensorListPatternRewriter *rewriter);
+  Option<bool> allow_tensorlist_pass_through{
+      *this, "allow-tensorlist-pass-through",
+      llvm::cl::desc(
+          "When specified to true, if the tensorlist ops can't be properly "
+          "legalized by this pass, then the IR won't be changed so that "
+          "tensorlist ops can pass through (default false)"),
+      llvm::cl::init(false)};
 };
 
 Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@@ -335,7 +337,8 @@
     if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
           dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
           dtype.isInteger(32) || dtype.isInteger(64))) {
-      op.emitError(
+      rewriter.notifyMatchFailure(
+          op,
           "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
           "integer or 16-bit/32-bit/64-bit float type during TF Lite "
           "transformation pass");
@@ -393,7 +396,8 @@
           if (element_shape_acquired) break;
         }
         if (!element_shape_acquired) {
-          op.emitError(
+          rewriter.notifyMatchFailure(
+              op,
               "requires element_shape to be 1D tensor during TF Lite "
               "transformation pass");
           return failure();
@@ -972,8 +976,7 @@
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
 
-LogicalResult LowerStaticTensorListPass::RewriteFunction(
-    FuncOp func, TensorListPatternRewriter *rewriter) {
+void LowerStaticTensorListPass::runOnOperation() {
   auto *context = &getContext();
 
   // TensorFlow operations that doesn't have operands and results of type
@@ -996,7 +999,7 @@
                       TF::TensorListGetItemOp, TF::TensorListLengthOp,
                       TF::TensorListPushBackOp, TF::TensorListReserveOp,
                       TF::TensorListSetItemOp, TF::TensorListStackOp,
-                      TF::TensorListResizeOp>();
+                      TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
   // TODO(hinsu): Use TFLite constant op for constants.
   target.addLegalOp<ConstantOp>();
   target.addLegalOp<FuncOp>();
@@ -1016,29 +1019,10 @@
                   ConvertTensorListSetItem, ConvertTensorListStack,
                   ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
       context);
-  return applyPartialConversion(func, target, std::move(patterns));
-}
-
-void LowerStaticTensorListPass::runOnOperation() {
-  // TODO(haoliang): currently we process the `main` function first, and the
-  // remaining functions may be processed in arbitrary order. However, this will
-  // have a potential issue when one function taking a `DT_VARIANT` is processed
-  // before the function that produces the `DT_VARIANT`. We need to carefully
-  // order the functions to be processed.
-  std::vector<FuncOp> funcs_in_module;
-  for (auto func : getOperation().getOps<FuncOp>()) {
-    // Always place the main function to be the first in the list.
-    if (func.getName() == "main") {
-      funcs_in_module.insert(funcs_in_module.begin(), func);
-    } else {
-      funcs_in_module.push_back(func);
-    }
-  }
-  for (auto func : funcs_in_module) {
-    TensorListPatternRewriter rewriter(func);
-    if (failed(RewriteFunction(func, &rewriter))) {
+  if (failed(applyPartialConversion(getOperation(), target,
+                                    std::move(patterns)))) {
+    if (!allow_tensorlist_pass_through) {
       signalPassFailure();
-      return;
     }
   }
 }