Only handle static shapes in optimize_op_order.cc
Also improve readability by making variable names clearer.
PiperOrigin-RevId: 394593093
Change-Id: I397805356a43d97e25d6084e6951c3667e18706c
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
index c2a7a16..a46ce59 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
@@ -29,3 +29,14 @@
// CHECK-NEXT: tfl.dequantize
// CHECK-NEXT: tfl.unpack
}
+
+// CHECK-LABEL: no_pushdown_dynamic_shape
+func @no_pushdown_dynamic_shape(%arg0: tensor<?x1000x1000x!quant.uniform<i8:f32, 7.812500e-03>>, %arg1: tensor<1x1xi32>) -> tensor<?x1x1000xf32> {
+ %0 = "tfl.dequantize"(%arg0) : (tensor<?x1000x1000x!quant.uniform<i8:f32, 7.812500e-03>>) -> tensor<?x1000x1000xf32>
+ %1 = "tfl.gather"(%0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32}: (tensor<?x1000x1000xf32>, tensor<1x1xi32>) -> tensor<?x1x1000xf32>
+ return %1 : tensor<?x1x1000xf32>
+
+// CHECK-NEXT: tfl.dequantize
+// CHECK-NEXT: tfl.gather
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
index abfd688..e9613a4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@@ -26,49 +27,57 @@
namespace {
// Dequantize ops will produce 3x larger tensors, so we want to move it after
-// some passthrought ops to reduce the memory consumption.
+// some passthrough ops to reduce the memory consumption.
struct PushDownDequantize : public OpRewritePattern<DequantizeOp> {
explicit PushDownDequantize(MLIRContext* context)
: OpRewritePattern<DequantizeOp>(context) {}
- LogicalResult matchAndRewrite(DequantizeOp op,
+ LogicalResult matchAndRewrite(DequantizeOp dequantize_op,
PatternRewriter& rewriter) const override {
- if (!op->hasOneUse()) return failure();
+ if (!dequantize_op->hasOneUse()) return failure();
- auto use = op->use_begin();
- Operation* user = use->getOwner();
+ auto use = dequantize_op->use_begin();
+ Operation* passthrough_op = use->getOwner();
unsigned operand_index = use->getOperandNumber();
- if (user->hasTrait<OpTrait::IsTerminator>()) return failure();
+ if (passthrough_op->hasTrait<OpTrait::IsTerminator>()) return failure();
- auto get_num_elements = [](Value value) {
- return value.getType().cast<TensorType>().getNumElements();
+ auto get_num_elements = [](RankedTensorType tensor) {
+ return tensor.getNumElements();
};
// If the op is the pass-through op with (3x) smaller output, the dequantize
// op can be pushed down to the single result of this op.
- if (!llvm::dyn_cast<mlir::SameScalesOpInterface>(user) ||
- user->getNumResults() != 1 ||
- get_num_elements(user->getOperand(operand_index)) <=
- get_num_elements(user->getResult(0))) {
+ if (!llvm::dyn_cast<mlir::SameScalesOpInterface>(passthrough_op) ||
+ passthrough_op->getNumResults() != 1) {
+ return failure();
+ }
+ // Only push down the dequantize op when both shapes are static and the
+ // output has fewer elements, so that memory usage is reduced.
+ auto input_type =
+ dequantize_op.output().getType().dyn_cast<RankedTensorType>();
+ auto output_type =
+ passthrough_op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ if (!input_type || !output_type || !input_type.hasStaticShape() ||
+ !output_type.hasStaticShape() ||
+ get_num_elements(input_type) <= get_num_elements(output_type)) {
return failure();
}
// Set the output type of the dequantize op and push it down.
- Type result_type = user->getResult(0).getType();
- op.output().setType(result_type);
- user->replaceAllUsesWith(op);
+ dequantize_op.output().setType(output_type);
+ passthrough_op->replaceAllUsesWith(dequantize_op);
- // Set the input type of the pass through op and pull it up.
- Type user_new_type =
- QuantizedType::getQuantizedElementType(op.input().getType())
- .castFromExpressedType(result_type);
- user->getResult(0).setType(user_new_type);
- user->setOperand(operand_index, op.input());
+ // Set the input type of the passthrough op and pull it up.
+ Type new_output_type =
+ QuantizedType::getQuantizedElementType(dequantize_op.input().getType())
+ .castFromExpressedType(output_type);
+ passthrough_op->getResult(0).setType(new_output_type);
+ passthrough_op->setOperand(operand_index, dequantize_op.input());
- // Set the input of the dequantize to the result of the pass throught op.
+ // Set the input of the dequantize to the result of the passthrough op.
// And switch the order of the ops.
- op->setOperand(0, user->getResult(0));
- op->moveAfter(user);
+ dequantize_op->setOperand(0, passthrough_op->getResult(0));
+ dequantize_op->moveAfter(passthrough_op);
return success();
}
};