Only handle static shapes in optimize_op_order.cc

Some readability improvements by making variable names clearer.

PiperOrigin-RevId: 394593093
Change-Id: I397805356a43d97e25d6084e6951c3667e18706c
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
index c2a7a16..a46ce59 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize_op_order.mlir
@@ -29,3 +29,14 @@
 // CHECK-NEXT: tfl.dequantize
 // CHECK-NEXT: tfl.unpack
 }
+
+// CHECK-LABEL: no_pushdown_dynamic_shape
+func @no_pushdown_dynamic_shape(%arg0: tensor<?x1000x1000x!quant.uniform<i8:f32, 7.812500e-03>>, %arg1: tensor<1x1xi32>) -> tensor<?x1x1000xf32> {
+  %0 = "tfl.dequantize"(%arg0) : (tensor<?x1000x1000x!quant.uniform<i8:f32, 7.812500e-03>>) -> tensor<?x1000x1000xf32>
+  %1 = "tfl.gather"(%0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32}: (tensor<?x1000x1000xf32>, tensor<1x1xi32>) -> tensor<?x1x1000xf32>
+  return %1 : tensor<?x1x1000xf32>
+
+// CHECK-NEXT: tfl.dequantize
+// CHECK-NEXT: tfl.gather
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
index abfd688..e9613a4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
@@ -26,49 +27,57 @@
 namespace {
 
 // Dequantize ops will produce 3x larger tensors, so we want to move it after
-// some passthrought ops to reduce the memory consumption.
+// some passthrough ops to reduce the memory consumption.
 struct PushDownDequantize : public OpRewritePattern<DequantizeOp> {
   explicit PushDownDequantize(MLIRContext* context)
       : OpRewritePattern<DequantizeOp>(context) {}
 
-  LogicalResult matchAndRewrite(DequantizeOp op,
+  LogicalResult matchAndRewrite(DequantizeOp dequantize_op,
                                 PatternRewriter& rewriter) const override {
-    if (!op->hasOneUse()) return failure();
+    if (!dequantize_op->hasOneUse()) return failure();
 
-    auto use = op->use_begin();
-    Operation* user = use->getOwner();
+    auto use = dequantize_op->use_begin();
+    Operation* passthrough_op = use->getOwner();
     unsigned operand_index = use->getOperandNumber();
-    if (user->hasTrait<OpTrait::IsTerminator>()) return failure();
+    if (passthrough_op->hasTrait<OpTrait::IsTerminator>()) return failure();
 
-    auto get_num_elements = [](Value value) {
-      return value.getType().cast<TensorType>().getNumElements();
+    auto get_num_elements = [](RankedTensorType tensor) {
+      return tensor.getNumElements();
     };
 
     // If the op is the pass-through op with (3x) smaller output, the dequantize
     // op can be pushed down to the single result of this op.
-    if (!llvm::dyn_cast<mlir::SameScalesOpInterface>(user) ||
-        user->getNumResults() != 1 ||
-        get_num_elements(user->getOperand(operand_index)) <=
-            get_num_elements(user->getResult(0))) {
+    if (!llvm::dyn_cast<mlir::SameScalesOpInterface>(passthrough_op) ||
+        passthrough_op->getNumResults() != 1) {
+      return failure();
+    }
+    // Only push down the dequantize op when the output is smaller, so that it
+    // can have smaller memory usage.
+    auto input_type =
+        dequantize_op.output().getType().dyn_cast<RankedTensorType>();
+    auto output_type =
+        passthrough_op->getResult(0).getType().dyn_cast<RankedTensorType>();
+    if (!input_type || !output_type || !input_type.hasStaticShape() ||
+        !output_type.hasStaticShape() ||
+        get_num_elements(input_type) <= get_num_elements(output_type)) {
       return failure();
     }
 
     // Set the output type of the dequantize op and push it down.
-    Type result_type = user->getResult(0).getType();
-    op.output().setType(result_type);
-    user->replaceAllUsesWith(op);
+    dequantize_op.output().setType(output_type);
+    passthrough_op->replaceAllUsesWith(dequantize_op);
 
-    // Set the input type of the pass through op and pull it up.
-    Type user_new_type =
-        QuantizedType::getQuantizedElementType(op.input().getType())
-            .castFromExpressedType(result_type);
-    user->getResult(0).setType(user_new_type);
-    user->setOperand(operand_index, op.input());
+    // Set the input type of the passthrough op and pull it up.
+    Type new_output_type =
+        QuantizedType::getQuantizedElementType(dequantize_op.input().getType())
+            .castFromExpressedType(output_type);
+    passthrough_op->getResult(0).setType(new_output_type);
+    passthrough_op->setOperand(operand_index, dequantize_op.input());
 
-    // Set the input of the dequantize to the result of the pass throught op.
+    // Set the input of the dequantize to the result of the passthrough op.
     // And switch the order of the ops.
-    op->setOperand(0, user->getResult(0));
-    op->moveAfter(user);
+    dequantize_op->setOperand(0, passthrough_op->getResult(0));
+    dequantize_op->moveAfter(passthrough_op);
     return success();
   }
 };