Update lower static tensor list pass with the following changes:
* Don't reorder FuncOps manually (dialect conversion can automatically handle dependencies between SSA values).
* Apply conversion on the whole ModuleOp
* Rely on dialect conversion to automatically roll back changes to IR in case of legalization failure.
PiperOrigin-RevId: 354364707
Change-Id: Iac586eba16fde8c25d0b7da9484f747a1a3874e3
diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index 9ae152f..7adc6ef 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -1,4 +1,6 @@
-// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
+// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
+
+// -----
// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@@ -431,3 +433,29 @@
// CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32>
// CHECK: }
+
+// -----
+
+// CHECK-LABEL: tensorlistReserveWithDynamicShape
+func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
+ %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
+ %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+
+// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
+// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK: return %1 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: tensorlistConcat
+func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
+ %list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
+ %t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
+ return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
+
+// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
+// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
+// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 400df2a..7194e73 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -58,6 +58,7 @@
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
@@ -72,20 +73,21 @@
namespace mlir {
namespace {
-class TensorListPatternRewriter : public PatternRewriter {
- public:
- explicit TensorListPatternRewriter(FuncOp fn)
- : PatternRewriter(fn.getContext()) {}
-};
-
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
+ LowerStaticTensorListPass() = default;
+ LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
+
void runOnOperation() override;
- // Apply type and op changes within a function.
- LogicalResult RewriteFunction(FuncOp func,
- TensorListPatternRewriter *rewriter);
+ Option<bool> allow_tensorlist_pass_through{
+ *this, "allow-tensorlist-pass-through",
+ llvm::cl::desc(
+ "When specified to true, if the tensorlist ops can't be properly "
+ "legalized by this pass, then the IR won't be changed so that "
+ "tensorlist ops can pass through (default false)"),
+ llvm::cl::init(false)};
};
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@@ -335,7 +337,8 @@
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
- op.emitError(
+ rewriter.notifyMatchFailure(
+ op,
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
@@ -393,7 +396,8 @@
if (element_shape_acquired) break;
}
if (!element_shape_acquired) {
- op.emitError(
+ rewriter.notifyMatchFailure(
+ op,
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
return failure();
@@ -972,8 +976,7 @@
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
-LogicalResult LowerStaticTensorListPass::RewriteFunction(
- FuncOp func, TensorListPatternRewriter *rewriter) {
+void LowerStaticTensorListPass::runOnOperation() {
auto *context = &getContext();
// TensorFlow operations that doesn't have operands and results of type
@@ -996,7 +999,7 @@
TF::TensorListGetItemOp, TF::TensorListLengthOp,
TF::TensorListPushBackOp, TF::TensorListReserveOp,
TF::TensorListSetItemOp, TF::TensorListStackOp,
- TF::TensorListResizeOp>();
+ TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
// TODO(hinsu): Use TFLite constant op for constants.
target.addLegalOp<ConstantOp>();
target.addLegalOp<FuncOp>();
@@ -1016,29 +1019,10 @@
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
context);
- return applyPartialConversion(func, target, std::move(patterns));
-}
-
-void LowerStaticTensorListPass::runOnOperation() {
- // TODO(haoliang): currently we process the `main` function first, and the
- // remaining functions may be processed in arbitrary order. However, this will
- // have a potential issue when one function taking a `DT_VARIANT` is processed
- // before the function that produces the `DT_VARIANT`. We need to carefully
- // order the functions to be processed.
- std::vector<FuncOp> funcs_in_module;
- for (auto func : getOperation().getOps<FuncOp>()) {
- // Always place the main function to be the first in the list.
- if (func.getName() == "main") {
- funcs_in_module.insert(funcs_in_module.begin(), func);
- } else {
- funcs_in_module.push_back(func);
- }
- }
- for (auto func : funcs_in_module) {
- TensorListPatternRewriter rewriter(func);
- if (failed(RewriteFunction(func, &rewriter))) {
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ if (!allow_tensorlist_pass_through) {
signalPassFailure();
- return;
}
}
}