Add a tflite canonicalization pass that replaces PackOps with Reshape
PackOps with only a single operand can be replace with a Reshape.
PiperOrigin-RevId: 380617913
Change-Id: I135085f06538dac2fe8ff07a334db329e55813a8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 5a57e47..aa6d0fd 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1737,9 +1737,40 @@
}
};
+// Replace PackOp with a reshape when there is only one operand.
+struct ReplacePackWithReshape : public RewritePattern {
+ explicit ReplacePackWithReshape(MLIRContext *context)
+ : RewritePattern(PackOp::getOperationName(), 2, context) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ TFL::PackOp pack_op = cast<TFL::PackOp>(op);
+ if (pack_op.getNumOperands() != 1) return failure();
+
+ Location loc = pack_op.getLoc();
+ auto output_type = pack_op.getType().cast<ShapedType>();
+ if (!output_type.hasStaticShape()) return failure();
+
+ // This is to workaround the unnecessary cast i64 -> i32.
+ SmallVector<int32_t, 4> new_shape_array;
+ for (auto size : output_type.getShape()) {
+ new_shape_array.push_back(static_cast<int32_t>(size));
+ }
+
+ auto new_shape = rewriter.create<TFL::ConstOp>(
+ loc, DenseIntElementsAttr::get(
+ RankedTensorType::get(new_shape_array.size(),
+ rewriter.getIntegerType(32)),
+ new_shape_array));
+
+ rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
+ pack_op.getOperand(0), new_shape);
+ return success();
+ }
+};
+
void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<RemoveRedundantUnpackPack>(context);
+ results.insert<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index 7e8a633..7f4f0a9 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -112,6 +112,16 @@
// -----
+// CHECK-LABEL: @ReplacePackWithReshape
+func @ReplacePackWithReshape(%arg0: tensor<5xf32>) -> tensor<1x5xf32> {
+ %1 = "tfl.pack"(%arg0) {axis = 0 : i32, values_count = 1 : i32} : (tensor<5xf32>) -> (tensor<1x5xf32>)
+ // CHECK: reshape
+ // CHECK-NOT: pack
+ return %1: tensor<1x5xf32>
+}
+
+// -----
+
func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
%0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
%1 = "tfl.pseudo_const"() {value = dense<[1, 128, 32]> : tensor<3xi64>} : () -> tensor<3xi64>