Add a tflite canonicalization pass that replaces PackOps with Reshape

PackOps with only a single operand can be replace with a Reshape.

PiperOrigin-RevId: 380617913
Change-Id: I135085f06538dac2fe8ff07a334db329e55813a8
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 5a57e47..aa6d0fd 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1737,9 +1737,40 @@
   }
 };
 
+// Replace PackOp with a reshape when there is only one operand.
+struct ReplacePackWithReshape : public RewritePattern {
+  explicit ReplacePackWithReshape(MLIRContext *context)
+      : RewritePattern(PackOp::getOperationName(), 2, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    TFL::PackOp pack_op = cast<TFL::PackOp>(op);
+    if (pack_op.getNumOperands() != 1) return failure();
+
+    Location loc = pack_op.getLoc();
+    auto output_type = pack_op.getType().cast<ShapedType>();
+    if (!output_type.hasStaticShape()) return failure();
+
+    // This is to workaround the unnecessary cast i64 -> i32.
+    SmallVector<int32_t, 4> new_shape_array;
+    for (auto size : output_type.getShape()) {
+      new_shape_array.push_back(static_cast<int32_t>(size));
+    }
+
+    auto new_shape = rewriter.create<TFL::ConstOp>(
+        loc, DenseIntElementsAttr::get(
+                 RankedTensorType::get(new_shape_array.size(),
+                                       rewriter.getIntegerType(32)),
+                 new_shape_array));
+
+    rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
+                                           pack_op.getOperand(0), new_shape);
+    return success();
+  }
+};
+
 void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
-  results.insert<RemoveRedundantUnpackPack>(context);
+  results.insert<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index 7e8a633..7f4f0a9 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -112,6 +112,16 @@
 
 // -----
 
+// CHECK-LABEL: @ReplacePackWithReshape
+func @ReplacePackWithReshape(%arg0: tensor<5xf32>) -> tensor<1x5xf32> {
+  %1 = "tfl.pack"(%arg0) {axis = 0 : i32, values_count = 1 : i32} : (tensor<5xf32>) -> (tensor<1x5xf32>)
+  // CHECK: reshape
+  // CHECK-NOT: pack
+  return %1: tensor<1x5xf32>
+}
+
+// -----
+
 func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
   %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
   %1 = "tfl.pseudo_const"() {value = dense<[1, 128, 32]> : tensor<3xi64>} : () -> tensor<3xi64>