Add tf lite pattern to transform tfl.squeeze to tfl.reshape op.

PiperOrigin-RevId: 282654714
Change-Id: Id4483f0e12883fabbe8c5b82b41f14f1a45689bd
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index b7df64a..49e9820 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -577,3 +577,13 @@
 // CHECK: tfl.hard_swish
 // CHECK: tfl.depthwise_conv_2d
 }
+
+// CHECK-LABEL: squeezeToReshape
+func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
+  %0 = "tfl.squeeze"(%arg0) : (tensor<1x1x2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+
+  // CHECK: [[cst:.*]] = constant dense<2> : tensor<1xi32>
+  // CHECK: %0 = "tfl.reshape"(%[[arg:.*]], %[[cst:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
+  // CHECK: return %0
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 4adf024..e1e336f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -122,6 +122,20 @@
   return ExpandTo4DForConvImpl(a, true);
 }
 
+DenseElementsAttr GetShape(Value *output_val) {
+  auto output_type = output_val->getType().cast<RankedTensorType>();
+  auto shape_vector = output_type.getShape();
+  std::vector<int32_t> shape(shape_vector.size());
+  for (int i = 0; i < shape_vector.size(); ++i) {
+    shape[i] = shape_vector[i];
+  }
+  return mlir::DenseElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int>(shape.size())},
+          mlir::IntegerType::get(32, output_val->getContext())),
+      llvm::makeArrayRef(shape));
+}
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
 
 // Fuse Add with proceeding FullyConnected.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index bb00a7e..7c4c015 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -260,3 +260,10 @@
 
 foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
   in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
+
+def GetShape: NativeCodeCall<"GetShape($0)">;
+
+def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
+          (TFL_ReshapeOp $input,
+           (ConstantOp (GetShape $squeeze_op))),
+          [(AnyStaticShapeTensor $squeeze_op)]>;