Add folding support for quantized transpose ops
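
Constant-folds `tfl.transpose` when its input comes from a
`tfl.pseudo_qconst` and its result type is uniform-quantized, so the
permutation is applied to the constant values at compile time rather than
at runtime. Schematically (types and attributes abbreviated), a pair such
as

  %w = "tfl.pseudo_qconst"() {qtype = ..., value = ...} : () -> ...
  %t = "tfl.transpose"(%w, %perm) : (..., tensor<4xi32>) -> ...

is replaced by a single `tfl.pseudo_qconst` that already holds the
permuted values.
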
PiperOrigin-RevId: 459202454
diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
index 8e6da84..12b9440 100644
--- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir
@@ -152,3 +152,22 @@
// QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) {num_splits = 4 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
// QDQ-NEXT: return %[[split]]#0 : tensor<2xf32>
}
+
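+// Tests that a quantized tfl.transpose of a tfl.pseudo_qconst is folded away
+// and the permuted values end up in a new tfl.pseudo_qconst.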
+// CHECK-LABEL: FoldTranspose
+func.func @FoldTranspose(%arg0: tensor<1x10x20x3xf32>) -> tensor<1x20x40x16xf32> {
+ %cst = arith.constant dense<[1, 20, 40, 16]> : tensor<4xi32>
+ %cst_0 = arith.constant dense<[2, 0, 1, 3]> : tensor<4xi32>
+ %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>
+ %1 = "tfl.pseudo_qconst"() {qtype = tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, value = dense<"0x0303040002010303FFFFFD0304020401FF0000FEFF0003FF01FD0203FF0202FEFE0003010201FD04FE0402030303000202FD0100FDFE0402FEFEFE01020101FD0204FEFDFC03FFFE0101FDFE02040002FDFFFE03FFFE0201FEFDFF00FFFDFEFD030201FD01FC01FF010003FF0401FCFD0101FC0000FE03FEFE010102000002FE02030100FE00FEFDFD0003FD000303000103FE01FF02000002FF0101FDFDFF02FFFF00000203FF0003030302FDFF03FFFF030001020102FD04FE0104FE030401030102FEFCFEFD03FD03FD000102FE02020001020000FE030202030103FFFC01FC000302000304FCFF03FD04FC00010400010100030303FC02FCFEFE01000303000100010003FE000303010301010102FEFC01FD020301FFFDFFFCFDFEFCFE030001FDFCFE000202FE020300FD00FD02FF0001FF0002FF01FD010102FDFE04FCFE0000FD01000101FF0402FF020103FC020301FF03010204FDFFFE0202FF0302FF02FFFF01FF01FF04FD0002FF00FC00FC0101010404FE03040300000301FD0001FE04FF040103FF01FD0301FF0002040403FF03FE04FDFD0103FCFE01FDFCFF03FC010200FDFE020200FF00FFFC03FE"> : tensor<3x3x16x3xi8>} : () -> tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+ %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>} : (tensor<1x10x20x3xf32>) -> tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>
+ %3 = "tfl.transpose"(%1, %cst_0) : (tensor<3x3x16x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, tensor<4xi32>) -> tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+ %4 = "tfl.transpose_conv"(%cst, %3, %2, %0) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, tensor<1x10x20x3x!quant.uniform<i8:f32, 3.9215686274509805E-9:-1>>, tensor<16x!quant.uniform<i32:f32, 1.8527095877721169E-10>>) -> tensor<1x20x40x16x!quant.uniform<i8:f32, 0.047058823529411764>>
+ %5 = "tfl.dequantize"(%4) : (tensor<1x20x40x16x!quant.uniform<i8:f32, 0.047058823529411764>>) -> tensor<1x20x40x16xf32>
+ return %5 : tensor<1x20x40x16xf32>
+
+ // CHECK-NOT: "tfl.transpose"
+ // CHECK: "tfl.pseudo_qconst"() {qtype = tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>, value = dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>} : () -> tensor<16x3x3x3x!quant.uniform<i8<-127:127>:f32, 0.047244094488188976>>
+ // CHECK-NEXT: "tfl.transpose_conv"
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 5d9b184..d42fb84 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -20,6 +20,7 @@
#include "llvm/Support/Casting.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@@ -188,6 +189,100 @@
}
};
+// Folds a quantized TransposeOp with a constant input into a QConstOp.
+struct FoldTransposeOp : public OpRewritePattern<TransposeOp> {
+ explicit FoldTransposeOp(MLIRContext* context)
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
+
+ // Computes the permutation of a constant `input_tensor` according to `perm`.
+  // The function recursively traverses the dimensions of the output tensor
+  // in row-major order and writes each output element's value into
+  // `new_values`.
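+  // For example, with an input of shape [2, 3] holding [[1, 2, 3],
+  // [4, 5, 6]] and perm = [1, 0], traversing the [3, 2] output in row-major
+  // order emits the values 1, 4, 2, 5, 3, 6.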
+ void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
+ ArrayRef<int64_t> output_shape, int num_dimensions,
+ int output_axis, std::vector<uint64_t>* input_indices,
+ std::vector<Attribute>* new_values) const {
+    // Refer to the implementation of the `Transpose` function in
+    // tensorflow/lite/kernels/internal/reference/reference_ops.h
+ assert(output_axis < num_dimensions);
+ const int input_axis = perm[output_axis];
+ for (int i = 0; i < output_shape[output_axis]; ++i) {
+ // Update the input indices on `input_axis`.
+ assert(input_axis < input_indices->size());
+      (*input_indices)[input_axis] = static_cast<uint64_t>(i);
+      // Write the value from `input_tensor` if this is the last axis;
+      // otherwise recurse into the next axis.
+ const bool is_last_axis = output_axis == num_dimensions - 1;
+ if (is_last_axis) {
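+        // Look up the element at the current multi-dimensional input index.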
+ new_values->push_back(
+ input_tensor.getValues<Attribute>()[*input_indices]);
+ } else {
+ ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+ output_axis + 1, input_indices, new_values);
+ }
+ }
+ }
+
+ LogicalResult matchAndRewrite(TransposeOp op,
+ PatternRewriter& rewriter) const override {
+ Operation* def_op = op.input().getDefiningOp();
+ auto qconst_op = llvm::dyn_cast_or_null<QConstOp>(def_op);
+ if (qconst_op == nullptr) return failure();
+
+ DenseIntElementsAttr perm_tensor;
+ if (!matchPattern(op.perm(), m_Constant(&perm_tensor))) return failure();
+
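+    // The folded result is materialized as a `tfl.pseudo_qconst`, so the
+    // pattern only applies when the transpose result is uniform-quantized.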
+    if (!getElementTypeOrSelf(op.output().getType())
+             .isa<quant::UniformQuantizedType>())
+      return failure();
+
+ ElementsAttr input_tensor = qconst_op.value();
+
+ assert(perm_tensor.getType().getRank() == 1);
+ const int num_dimensions = input_tensor.getType().getRank();
+ assert(perm_tensor.getType().getNumElements() == num_dimensions);
+
+ ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
+ auto output_type = op.output().getType().cast<ShapedType>();
+
+ SmallVector<int32_t, 4> perm;
+ SmallVector<int64_t, 4> output_shape;
+ for (int i = 0; i < num_dimensions; ++i) {
+ perm.push_back(perm_tensor.getValues<IntegerAttr>()[i].getInt());
+ output_shape.push_back(input_shape[perm[i]]);
+
+ // Check that the derived output shape matches the static shape.
+ assert(!output_type.hasStaticShape() ||
+ output_type.getShape()[i] == output_shape[i]);
+ }
+
+ std::vector<Attribute> new_values;
+ new_values.reserve(input_tensor.getType().getNumElements());
+ std::vector<uint64_t> input_indices(num_dimensions);
+ ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
+ /*output_axis=*/0, &input_indices, &new_values);
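+    // The op result keeps the quantized element type, while the value
+    // attribute is created with the bare storage type (e.g. i8), since
+    // DenseIntElementsAttr requires an integer element type.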
+ auto result_type =
+ RankedTensorType::get(output_shape, output_type.getElementType());
+ auto values_type = RankedTensorType::get(
+ output_shape, output_type.getElementType()
+ .cast<quant::UniformQuantizedType>()
+ .getStorageType());
+ rewriter.replaceOpWithNewOp<QConstOp>(
+ op, TypeAttr::get(result_type),
+ DenseIntElementsAttr::get(values_type, new_values));
+ return success();
+ }
+};
+
// Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
// output.
template <typename OpTy>
@@ -250,8 +345,9 @@
RewritePatternSet phase_2_patterns(&getContext());
TFL::populateWithGenerated(phase_2_patterns);
- phase_2_patterns.add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
- RemoveVolatileOps<kPreserveInputsAndOutputs>>(ctx);
+ phase_2_patterns
+ .add<quant::FoldTrivalRequantizeOp<QuantizeOp>,
+ RemoveVolatileOps<kPreserveInputsAndOutputs>, FoldTransposeOp>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}