Add tf lite pattern to transform tfl.squeeze to tfl.reshape op.
PiperOrigin-RevId: 282654714
Change-Id: Id4483f0e12883fabbe8c5b82b41f14f1a45689bd
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index b7df64a..49e9820 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -577,3 +577,13 @@
// CHECK: tfl.hard_swish
// CHECK: tfl.depthwise_conv_2d
}
+
+// CHECK-LABEL: squeezeToReshape
+func @squeezeToReshape(%arg0: tensor<1x1x2xf32>) -> tensor<2xf32> {
+ %0 = "tfl.squeeze"(%arg0) : (tensor<1x1x2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+
+ // CHECK: [[cst:.*]] = constant dense<2> : tensor<1xi32>
+ // CHECK: %0 = "tfl.reshape"(%[[arg:.*]], %[[cst:.*]]) : (tensor<1x1x2xf32>, tensor<1xi32>) -> tensor<2xf32>
+ // CHECK: return %0
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 4adf024..e1e336f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -122,6 +122,20 @@
return ExpandTo4DForConvImpl(a, true);
}
+DenseElementsAttr GetShape(Value *output_val) {
+ auto output_type = output_val->getType().cast<RankedTensorType>();
+ auto shape_vector = output_type.getShape();
+ std::vector<int32_t> shape(shape_vector.size());
+ for (int i = 0; i < shape_vector.size(); ++i) {
+ shape[i] = shape_vector[i];
+ }
+ return mlir::DenseElementsAttr::get(
+ RankedTensorType::get(
+ {static_cast<int>(shape.size())},
+ mlir::IntegerType::get(32, output_val->getContext())),
+ llvm::makeArrayRef(shape));
+}
+
#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
// Fuse Add with proceeding FullyConnected.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index bb00a7e..7c4c015 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -260,3 +260,10 @@
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
+
+def GetShape: NativeCodeCall<"GetShape($0)">;
+
+def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
+ (TFL_ReshapeOp $input,
+ (ConstantOp (GetShape $squeeze_op))),
+ [(AnyStaticShapeTensor $squeeze_op)]>;