Add folding support for quantized transpose ops
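
Transposes of quantized constants can now be folded away in the
post-quantize pass: a `tfl.transpose` whose input is a
`tfl.pseudo_qconst` (and whose permutation is constant) is replaced by
a new quantized constant holding the permuted values. A minimal sketch
of the rewrite, with shapes and quantization parameters abbreviated
(see the test below for the exact form):

  %w = "tfl.pseudo_qconst"() {...} : () -> tensor<3x3x16x3x!quant.uniform<...>>
  %t = "tfl.transpose"(%w, %perm) : (...) -> tensor<16x3x3x3x!quant.uniform<...>>

is rewritten to a single `tfl.pseudo_qconst` of type
tensor<16x3x3x3x!quant.uniform<...>>, with the values permuted at
compile time.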

PiperOrigin-RevId: 459202454
diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
index 8e6da84..12b9440 100644
--- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
@@ -152,3 +152,20 @@
 // QDQ-NEXT:  %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
 // QDQ-NEXT:  return %[[split]]#0 : tensor<2xf32>
 }
+
+// CHECK-LABEL: FoldTranspose
+func.func @FoldTranspose(%arg0: tensor<1x10x20x3xf32>) -> tensor<1x20x40x16xf32> {
+  %cst = arith.constant dense<[1, 20, 40, 16]> : tensor<4xi32>
+  %cst_0 = arith.constant dense<[2, 0, 1, 3]> : tensor<4xi32>
+  %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>
+  %1 = "tfl.pseudo_qconst"() {qtype = tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, value = dense<"0x0303040002010303FFFFFD0304020401FF0000FEFF0003FF01FD0203FF0202FEFE0003010201FD04FE0402030303000202FD0100FDFE0402FEFEFE01020101FD0204FEFDFC03FFFE0101FDFE02040002FDFFFE03FFFE0201FEFDFF00FFFDFEFD030201FD01FC01FF010003FF0401FCFD0101FC0000FE03FEFE010102000002FE02030100FE00FEFDFD0003FD000303000103FE01FF02000002FF0101FDFDFF02FFFF00000203FF0003030302FDFF03FFFF030001020102FD04FE0104FE030401030102FEFCFEFD03FD03FD000102FE02020001020000FE030202030103FFFC01FC000302000304FCFF03FD04FC00010400010100030303FC02FCFEFE01000303000100010003FE000303010301010102FEFC01FD020301FFFDFFFCFDFEFCFE030001FDFCFE000202FE020300FD00FD02FF0001FF0002FF01FD010102FDFE04FCFE0000FD01000101FF0402FF020103FC020301FF03010204FDFFFE0202FF0302FF02FFFF01FF01FF04FD0002FF00FC00FC0101010404FE03040300000301FD0001FE04FF040103FF01FD0301FF0002040403FF03FE04FDFD0103FCFE01FDFCFF03FC010200FDFE020200FF00FFFC03FE"> : tensor<3x3x16x3xi8>} : () -> tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+  %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>} : (tensor<1x10x20x3xf32>) -> tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>
+  %3 = "tfl.transpose"(%1, %cst_0) : (tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, tensor<4xi32>) -> tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+  %4 = "tfl.transpose_conv"(%cst, %3, %2, %0) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>, tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>) -> tensor<1x20x40x16x!quant.uniform<i8:f32, 0.047058823529411764>>
+  %5 = "tfl.dequantize"(%4) : (tensor<1x20x40x16x!quant.uniform<i8:f32, 0.047058823529411764>>) -> tensor<1x20x40x16xf32>
+  return %5 : tensor<1x20x40x16xf32>
+
+  // CHECK-NOT: "tfl.transpose"
+  // CHECK: "tfl.pseudo_qconst"() {qtype = tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, value = dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>} : () -> tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+  // CHECK-NEXT: "tfl.transpose_conv"
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 5d9b184..d42fb84 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -20,6 +20,7 @@
 
 #include "llvm/Support/Casting.h"
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
@@ -188,6 +189,91 @@
   }
 };
 
+// Folds transpose ops whose input is a constant quantized tensor.
+struct FoldTransposeOp : public OpRewritePattern<TransposeOp> {
+  explicit FoldTransposeOp(MLIRContext* context)
+      : OpRewritePattern<TransposeOp>(context, 1) {}
+
+  // Computes the permutation of a constant `input_tensor` according to
+  // `perm`. The function recursively traverses the dimensions of the output
+  // tensor in row-major order and writes the resulting values into
+  // `new_values`.
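+  // For example, with an input of shape [2, 3], perm = [1, 0], and output
+  // shape [3, 2], the traversal visits output indices (0,0), (0,1), (1,0),
+  // ..., reading input[j][i] for output[i][j]; an input {{1,2,3},{4,5,6}}
+  // yields new_values = {1,4,2,5,3,6} (the 3x2 transpose in row-major
+  // order).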
+  void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
+                          ArrayRef<int64_t> output_shape, int num_dimensions,
+                          int output_axis, std::vector<uint64_t>* input_indices,
+                          std::vector<Attribute>* new_values) const {
+    // This mirrors the implementation of the `Transpose` function in
+    // tensorflow/lite/kernels/internal/reference/reference_ops.h.
+    assert(output_axis < num_dimensions);
+    const int input_axis = perm[output_axis];
+    for (int i = 0; i < output_shape[output_axis]; ++i) {
+      // Update the input index along `input_axis`.
+      assert(static_cast<size_t>(input_axis) < input_indices->size());
+      (*input_indices)[input_axis] = static_cast<uint64_t>(i);
+      // On the last axis, copy the value from `input_tensor`; otherwise,
+      // recurse into the next axis.
+      const bool is_last_axis = output_axis == num_dimensions - 1;
+      if (is_last_axis) {
+        new_values->push_back(
+            input_tensor.getValues<Attribute>()[*input_indices]);
+      } else {
+        ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+                           output_axis + 1, input_indices, new_values);
+      }
+    }
+  }
+
+  LogicalResult matchAndRewrite(TransposeOp op,
+                                PatternRewriter& rewriter) const override {
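+    // Only fold when the transpose input comes from a quantized constant
+    // (`tfl.pseudo_qconst`).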
+    Operation* def_op = op.input().getDefiningOp();
+    auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
+    if (qconst_op == nullptr) return failure();
+
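+    // The permutation must be known at compile time.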
+    DenseIntElementsAttr perm_tensor;
+    if (!matchPattern(op.perm(), m_Constant(&perm_tensor))) return failure();
+
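+    // Only per-tensor uniform quantized results are folded; per-axis
+    // quantized constants are left to the runtime transpose.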
+    if (!getElementTypeOrSelf(op.output().getType())
+             .isa<quant::UniformQuantizedType>()) {
+      return failure();
+    }
+
+    ElementsAttr input_tensor = qconst_op.value();
+
+    assert(perm_tensor.getType().getRank() == 1);
+    const int num_dimensions = input_tensor.getType().getRank();
+    assert(perm_tensor.getType().getNumElements() == num_dimensions);
+
+    ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
+    auto output_type = op.output().getType().cast<ShapedType>();
+
+    SmallVector<int32_t, 4> perm;
+    SmallVector<int64_t, 4> output_shape;
+    for (int i = 0; i < num_dimensions; ++i) {
+      perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
+      output_shape.push_back(input_shape[perm[i]]);
+
+      // Check that the derived output shape matches the static shape.
+      assert(!output_type.hasStaticShape() ||
+             output_type.getShape()[i] == output_shape[i]);
+    }
+
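+    // Materialize the permuted values and replace the transpose with a new
+    // quantized constant of the output type.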
+    std::vector<Attribute> new_values;
+    new_values.reserve(input_tensor.getType().getNumElements());
+    std::vector<uint64_t> input_indices(num_dimensions);
+    ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+                       /*output_axis=*/0, &input_indices, &new_values);
+    auto result_type =
+        RankedTensorType::get(output_shape, output_type.getElementType());
+    auto values_type = RankedTensorType::get(
+        output_shape, output_type.getElementType()
+                          .cast<quant::UniformQuantizedType>()
+                          .getStorageType());
+    rewriter.replaceOpWithNewOp<QConstOp>(
+        op, TypeAttr::get(result_type),
+        DenseIntElementsAttr::get(values_type, new_values));
+    return success();
+  }
+};
+
 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
 // output.
 template <typename OpTy>
@@ -250,8 +336,9 @@
 
   RewritePatternSet phase_2_patterns(&getContext());
   TFL::populateWithGenerated(phase_2_patterns);
-  phase_2_patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
-                       RemoveVolatileOps<kPreserveInputsAndOutputs>>(ctx);
+  phase_2_patterns
+      .add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
+           RemoveVolatileOps<kPreserveInputsAndOutputs>, FoldTransposeOp>(ctx);
   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
 }