[TFLite/MLIR] Fix a bug when fusing mul into conv_2d.

ExpandTo4DForConvImpl used to return an already-4-D multiplier constant
unchanged, so a per-channel multiplier that had been expanded to shape
1x1x1xN was never reshaped onto the filter's output-channel axis and the
fusion could not fold it into the weights correctly. The multiplier is
now required to be essentially 1-D (all of its elements in the trailing
dimension) and is always reshaped to Nx1x1x1 for conv_2d, or 1x1x1xN for
depthwise_conv_2d.

PiperOrigin-RevId: 340705010
Change-Id: Iad3ce0c2bf5185183c7a12eca0d8a1fa57217638
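
For reference, the new test below feeds a conv_2d (filter splat 1.4, bias
splat 1.5) into a mul by a 2.0 splat of shape 1x1x1x256; after the rewrite
the mul is gone and the constants fold to 1.4 * 2.0 = 2.8 for the filter
and 1.5 * 2.0 = 3.0 for the bias.
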
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 336a465..8854034 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -1382,3 +1382,18 @@
   // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
   // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
 }
+
+// CHECK-LABEL: fuseExpanded1DMulIntoConv2d
+func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x256xf32> {
+  %cst_0 = constant dense<1.4> : tensor<256x3x3x207xf32>
+  %cst_1 = constant dense<1.5> : tensor<256xf32>
+  %cst_2 = constant dense<2.0> : tensor<1x1x1x256xf32>
+  %0 = "tfl.conv_2d"(%arg0, %cst_0, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x8x8x207xf32>, tensor<256x3x3x207xf32>, tensor<256xf32>) -> tensor<1x8x8x256xf32>
+  %1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x8x8x256xf32>, tensor<1x1x1x256xf32>) -> tensor<1x8x8x256xf32>
+  return %1 : tensor<1x8x8x256xf32>
+
+// CHECK: %[[CST_0:.*]] = constant dense<2.800000e+00> : tensor<256x3x3x207xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<3.000000e+00> : tensor<1x1x1x256xf32>
+// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
+
+}
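
(Note: in the CHECK lines above, the fused bias keeps the multiplier's
1x1x1x256 shape; this is presumably the broadcasted constant fold of the
tensor<256xf32> bias with the tensor<1x1x1x256xf32> multiplier.)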
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index e6a3fd8..8e4dd27 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -200,16 +200,16 @@
 ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
   auto elements = a.dyn_cast<DenseElementsAttr>();
   auto shape = elements.getType().getShape();
-  if (shape.size() == 4) {
-    return elements;
+  if (!shape.empty()) {
+    // Check that the constant is essentially 1-D: all elements in the trailing dim.
+    assert(elements.getNumElements() == shape.back());
   }
   std::vector<int64_t> shape_data = {1, 1, 1, 1};
-  if (shape.size() == 1 || shape.empty()) {
-    if (is_depthwise)
-      shape_data[3] = shape.empty() ? 1 : shape[0];
-    else
-      shape_data[0] = shape.empty() ? 1 : shape[0];
-  }
+  const int vector_length = elements.getNumElements();
+  if (is_depthwise)
+    shape_data[3] = vector_length;
+  else
+    shape_data[0] = vector_length;
   auto new_shape =
       RankedTensorType::get(shape_data, elements.getType().getElementType());
   return elements.reshape(new_shape);
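
For illustration, a minimal standalone sketch of the shape rule the new
code implements (the helper name and free-function form are assumptions
for the sketch; the real pass operates on a DenseElementsAttr):

#include <cassert>
#include <cstdint>
#include <vector>

// Given the multiplier's original shape, compute the 4-D shape it is
// expanded to before being folded into the convolution filter.
std::vector<int64_t> ExpandTo4DShapeSketch(const std::vector<int64_t>& shape,
                                           bool is_depthwise) {
  int64_t num_elements = 1;
  for (int64_t d : shape) num_elements *= d;
  // Mirrors the new assert: the constant must be essentially 1-D, i.e. all
  // of its elements live in the trailing dimension ({}, {N}, {1,1,1,N}, ...).
  if (!shape.empty()) assert(num_elements == shape.back());
  std::vector<int64_t> shape_data = {1, 1, 1, 1};
  if (is_depthwise)
    shape_data[3] = num_elements;  // channel axis of depthwise filters (1xHxWxC)
  else
    shape_data[0] = num_elements;  // output axis of conv_2d filters (OxHxWxI)
  return shape_data;
}

// E.g. ExpandTo4DShapeSketch({1, 1, 1, 256}, /*is_depthwise=*/false) now
// yields {256, 1, 1, 1}; the old code returned the 4-D shape unchanged.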