Add flex fallback for unsupported grouped convolution

PiperOrigin-RevId: 413825071
Change-Id: Ia98426ae165324680c8f93dfdec6375e3cc814c6
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
index 5f498a4..272ed57 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
+# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,2,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
 
 node {
   name: "input"
@@ -17,7 +17,7 @@
           size: 1
         }
         dim {
-          size: 8
+          size: 2
         }
         dim {
           size: 8
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 0c2d00b..7ccd4c3 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -694,4 +694,58 @@
   // CHECK: return %[[MATMUL]] : tensor<2x4xf32>
 }
 
+func @UnsupportedGroupConv(%arg0: tensor<?x128x24xf32>) -> (tensor<?x6x14xf32>) {
+  %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x2x14xf32>} : () -> tensor<3x2x14xf32>
+  %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<14xf32>} : () -> tensor<14xf32>
+  %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x24x14xf32>} : () -> tensor<3x24x14xf32>
+  %cst_2 = "tf.Const"() {value = dense<0.000000e+00> : tensor<14xf32>} : () -> tensor<14xf32>
+  %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %cst_4 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  %0 = "tf.ExpandDims"(%arg0, %cst_4) {device = ""} : (tensor<?x128x24xf32>, tensor<i32>) -> tensor<?x1x128x24xf32>
+  %1 = "tf.ExpandDims"(%cst, %cst_3) {device = ""} : (tensor<3x2x14xf32>, tensor<i32>) -> tensor<1x3x2x14xf32>
+  %2 = "tf.ExpandDims"(%cst_1, %cst_3) {device = ""} : (tensor<3x24x14xf32>, tensor<i32>) -> tensor<1x3x24x14xf32>
+  %3 = "tf.Conv2D"(%0, %2) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x128x24xf32>, tensor<1x3x24x14xf32>) -> tensor<?x1x26x14xf32>
+  %4 = "tf.Squeeze"(%3) {device = "", squeeze_dims = [-3]} : (tensor<?x1x26x14xf32>) -> tensor<?x26x14xf32>
+  %5 = "tf.BiasAdd"(%4, %cst_2) {data_format = "NHWC", device = ""} : (tensor<?x26x14xf32>, tensor<14xf32>) -> tensor<?x26x14xf32>
+  %6 = "tf.ExpandDims"(%5, %cst_4) {device = ""} : (tensor<?x26x14xf32>, tensor<i32>) -> tensor<?x1x26x14xf32>
+  %7 = "tf.Conv2D"(%6, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x14xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+  %8 = "tf.Squeeze"(%7) {device = "", squeeze_dims = [-3]} : (tensor<?x1x6x14xf32>) -> tensor<?x6x14xf32>
+  %9 = "tf.BiasAdd"(%8, %cst_0) {data_format = "NHWC", device = ""} : (tensor<?x6x14xf32>, tensor<14xf32>) -> tensor<?x6x14xf32>
+  %10 = "tf.Identity"(%9) {device = ""} : (tensor<?x6x14xf32>) -> tensor<?x6x14xf32>
+  %11 = "tf.Identity"(%10) {device = ""} : (tensor<?x6x14xf32>) -> tensor<?x6x14xf32>
+  return %11 : tensor<?x6x14xf32>
+
+  // CHECK-LABEL: UnsupportedGroupConv
+  // CHECK: "tfl.conv_2d"
+  // CHECK-NOT: "tfl.conv_2d"
+  // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_UnrankedTensorType(%arg0: tensor<*xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+  return %0 : tensor<?x1x6x14xf32>
+
+  // CHECK-LABEL: UnsupportedGroupConv_UnrankedTensorType
+  // CHECK-NOT: "tfl.conv_2d"
+  // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_DynamicDimAtInputDimThree(%arg0: tensor<?x1x26x?xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x?xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+  return %0 : tensor<?x1x6x14xf32>
+
+  // CHECK-LABEL: UnsupportedGroupConv_DynamicDimAtInputDimThree
+  // CHECK-NOT: "tfl.conv_2d"
+  // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_MultipleGroup(%arg0: tensor<?x1x26x14xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x14xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+  return %0 : tensor<?x1x6x14xf32>
+
+  // CHECK-LABEL: UnsupportedGroupConv_MultipleGroup
+  // CHECK-NOT: "tfl.conv_2d"
+  // CHECK: "tf.Conv2D"
+}
+
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 48ff038..f322d27 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -225,6 +225,21 @@
         !filter_type.hasStaticShape())
       return failure();
 
+    Value input = tf_op.input();
+    RankedTensorType input_type =
+        input.getType().template dyn_cast<RankedTensorType>();
+    // Safeguard for skipping grouped convolution legalization.
+    // Only rank-four input will be available here, as guaranteed by the
+    // tf.Conv2D operator verification.
+    if (!input_type || input_type.isDynamicDim(3)) {
+      return failure();
+    }
+    // Check whether the given op is an unsupported grouped convolution.
+    // A zero dim size is ruled out by the tf.Conv2D operator verification.
+    if (input_type.getDimSize(3) / filter_type.getDimSize(2) != 1) {
+      return failure();
+    }
+
     // TensorFlow convolution op only has two inputs, while the TFLite one has
     // three, with the bias vector marked as optional. However, TOCO has a
     // dedicated pass, EnsureBiasVectors, to create default bias vectors for all
@@ -243,7 +258,6 @@
     auto bias =
         rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
 
-    auto input = tf_op.input();
     if (op->getAttrOfType<StringAttr>("padding").getValue() == "EXPLICIT") {
       // Add Const op for padding value.
       ArrayRef<Attribute> padding_attr_array =