Add folding support for quantized reshape ops in TFL quantization passes

PiperOrigin-RevId: 459442303
diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
index 12b9440..e1e4036 100644
--- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
@@ -169,3 +169,15 @@
   // CHECK: "tfl.pseudo_qconst"() {qtype = tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, value = dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>} : () -> tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
   // CHECK-NEXT: "tfl.transpose_conv"
 }
+
+// CHECK-LABEL: FoldReshape
+func.func @FoldReshape(%arg0: tensor<4xi32>, %arg1: tensor<1x48x80x16x!quant.uniform<i8:f32, 0.047054948993757659:-128>>, %arg2: tensor<1x!quant.uniform<i32:f32, 0.0010538385465422978>>) -> tensor<1x96x160x1x!quant.uniform<i8:f32, 0.37102097156001074:-14>> {
+  %cst = arith.constant dense<[1, 2, 2, 16]> : tensor<4xi32>
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x1x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26]], [[47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]]], [[[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52]], [[40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<2x2x1x16xi8>} : () -> tensor<2x2x1x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>
+  %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x1x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>, tensor<4xi32>) -> tensor<1x2x2x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>
+  %2 = "tfl.transpose_conv"(%arg0, %1, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x2x2x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>, tensor<1x48x80x16x!quant.uniform<i8:f32, 0.047054948993757659:-128>>, tensor<1x!quant.uniform<i32:f32, 0.0010538385465422978>>) -> tensor<1x96x160x1x!quant.uniform<i8:f32, 0.37102097156001074:-14>>
+  return %2 : tensor<1x96x160x1x!quant.uniform<i8:f32, 0.37102097156001074:-14>>
+  // CHECK-NOT: "tfl.reshape"
+  // CHECK{LITERAL}: "tfl.pseudo_qconst"() {qtype = tensor<1x2x2x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26], [47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]], [[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52], [40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<1x2x2x16xi8>} : () -> tensor<1x2x2x16x!quant.uniform<i8<-127:127>:f32, 0.022395913056501255>>
+  // CHECK-NEXT: "tfl.transpose_conv"
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index d42fb84..ad145c1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -274,6 +274,60 @@
   }
 };
 
+// Fold constant quantized Reshape ops.
+struct FoldReshapeOp : public OpRewritePattern<ReshapeOp> {
+  // Does not take ownership of context, which must refer to a valid value that
+  // outlives this object.
+  explicit FoldReshapeOp(MLIRContext* context)
+      : OpRewritePattern<ReshapeOp>(context, /*benefit=*/1) {}
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter& rewriter) const override {
+    Operation* def_op = op.input().getDefiningOp();
+    auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
+    if (qconst_op == nullptr) return failure();
+
+    auto dense_elements =
+        qconst_op.value().dyn_cast_or_null<DenseElementsAttr>();
+    if (dense_elements == nullptr) return failure();
+
+    // Handle per tensor cases only.
+    if (!(getElementTypeOrSelf(op.getType()))
+             .isa<quant::UniformQuantizedType>()) {
+      return failure();
+    }
+
+    // Remove identity reshape with both static result and input shape.
+    auto result_type = op.getType().cast<ShapedType>();
+    auto input_type = op.input().getType().cast<ShapedType>();
+
+    // Constant folding
+    // If the result type isn't static, tries to derive the result type from
+    // the #2 operand.
+    if (!result_type.hasStaticShape()) {
+      DenseIntElementsAttr shape_elements;
+      if (!matchPattern(op.shape(), m_Constant(&shape_elements)))
+        return failure();
+
+      SmallVector<int64_t, 4> shape_data;
+      for (const APInt& it : shape_elements.getValues<APInt>()) {
+        shape_data.push_back(it.getSExtValue());
+      }
+      result_type =
+          RankedTensorType::get(shape_data, input_type.getElementType());
+    }
+    auto values_type = RankedTensorType::get(
+        result_type.getShape(), result_type.getElementType()
+                                    .cast<quant::UniformQuantizedType>()
+                                    .getStorageType());
+
+    DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type);
+    rewriter.replaceOpWithNewOp<QConstOp>(op, TypeAttr::get(result_type),
+                                          reshaped_elements);
+    return success();
+  }
+};
+
 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
 // output.
 template <typename OpTy>
@@ -336,9 +390,9 @@
 
   RewritePatternSet phase_2_patterns(&getContext());
   TFL::populateWithGenerated(phase_2_patterns);
-  phase_2_patterns
-      .add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
-           RemoveVolatileOps<kPreserveInputsAndOutputs>, FoldTransposeOp>(ctx);
+  phase_2_patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
+                       RemoveVolatileOps<kPreserveInputsAndOutputs>,
+                       FoldTransposeOp, FoldReshapeOp>(ctx);
   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
 }