[TFLite/MLIR] Fix a bug when fusing mul into conv_2d.
PiperOrigin-RevId: 340705010
Change-Id: Iad3ce0c2bf5185183c7a12eca0d8a1fa57217638
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 336a465..8854034 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -1382,3 +1382,18 @@
// CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
}
+
+// CHECK-LABEL: fuseExpanded1DMulIntoConv2d
+func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x256xf32> {
+ %cst_0 = constant dense<1.4> : tensor<256x3x3x207xf32>
+ %cst_1 = constant dense<1.5> : tensor<256xf32>
+ %cst_2 = constant dense<2.0> : tensor<1x1x1x256xf32>
+ %0 = "tfl.conv_2d"(%arg0, %cst_0, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x8x8x207xf32>, tensor<256x3x3x207xf32>, tensor<256xf32>) -> tensor<1x8x8x256xf32>
+ %1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x8x8x256xf32>, tensor<1x1x1x256xf32>) -> tensor<1x8x8x256xf32>
+ return %1 : tensor<1x8x8x256xf32>
+
+// CHECK: %[[CST_0:.*]] = constant dense<2.800000e+00> : tensor<256x3x3x207xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<3.000000e+00> : tensor<1x1x1x256xf32>
+// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
+
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index e6a3fd8..8e4dd27 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -200,16 +200,16 @@
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
auto elements = a.dyn_cast<DenseElementsAttr>();
auto shape = elements.getType().getShape();
- if (shape.size() == 4) {
- return elements;
+ if (!shape.empty()) {
+    // Check the constant is effectively 1-D: all elements lie along the last dim.
+ assert(elements.getNumElements() == shape.back());
}
std::vector<int64_t> shape_data = {1, 1, 1, 1};
- if (shape.size() == 1 || shape.empty()) {
- if (is_depthwise)
- shape_data[3] = shape.empty() ? 1 : shape[0];
- else
- shape_data[0] = shape.empty() ? 1 : shape[0];
- }
+ const int vector_length = elements.getNumElements();
+ if (is_depthwise)
+ shape_data[3] = vector_length;
+ else
+ shape_data[0] = vector_length;
auto new_shape =
RankedTensorType::get(shape_data, elements.getType().getElementType());
return elements.reshape(new_shape);