Add flex fallback for unsupported grouped convolution
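
Grouped convolutions (input channels / filter input channels > 1) cannot
be represented by the built-in tfl.conv_2d op. Legalization now bails out
when the group count is not one, or when the input is unranked or has a
dynamic channel dimension, so the op stays as tf.Conv2D and can be
exported through the Flex (Select TF ops) fallback.

A minimal conversion sketch; the model and shapes are hypothetical,
mirroring the new MLIR tests (14 input channels / filter depth 2 = 7
groups), and assume a TensorFlow build with grouped tf.nn.conv2d support:

  import tensorflow as tf

  @tf.function(input_signature=[tf.TensorSpec([1, 1, 26, 14], tf.float32)])
  def grouped_conv(x):
    # HWIO filter <1x3x2x14>: 14 input channels / depth 2 = 7 groups.
    filt = tf.zeros([1, 3, 2, 14])
    return tf.nn.conv2d(x, filt, strides=[1, 1, 5, 1], padding="SAME")

  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [grouped_conv.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # use built-in TFLite ops when legal
      tf.lite.OpsSet.SELECT_TF_OPS,    # Flex fallback for the grouped conv
  ]
  tflite_model = converter.convert()
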
PiperOrigin-RevId: 413825071
Change-Id: Ia98426ae165324680c8f93dfdec6375e3cc814c6
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
index 5f498a4..272ed57 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/conv_2d_nchw.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
+# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,2,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
node {
name: "input"
@@ -17,7 +17,7 @@
size: 1
}
dim {
- size: 8
+ size: 2
}
dim {
size: 8
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 0c2d00b..7ccd4c3 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -694,4 +694,58 @@
// CHECK: return %[[MATMUL]] : tensor<2x4xf32>
}
+func @UnsupportedGroupConv(%arg0: tensor<?x128x24xf32>) -> (tensor<?x6x14xf32>) {
+ %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x2x14xf32>} : () -> tensor<3x2x14xf32>
+ %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<14xf32>} : () -> tensor<14xf32>
+ %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x24x14xf32>} : () -> tensor<3x24x14xf32>
+ %cst_2 = "tf.Const"() {value = dense<0.000000e+00> : tensor<14xf32>} : () -> tensor<14xf32>
+ %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %cst_4 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+ %0 = "tf.ExpandDims"(%arg0, %cst_4) {device = ""} : (tensor<?x128x24xf32>, tensor<i32>) -> tensor<?x1x128x24xf32>
+ %1 = "tf.ExpandDims"(%cst, %cst_3) {device = ""} : (tensor<3x2x14xf32>, tensor<i32>) -> tensor<1x3x2x14xf32>
+ %2 = "tf.ExpandDims"(%cst_1, %cst_3) {device = ""} : (tensor<3x24x14xf32>, tensor<i32>) -> tensor<1x3x24x14xf32>
+ %3 = "tf.Conv2D"(%0, %2) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x128x24xf32>, tensor<1x3x24x14xf32>) -> tensor<?x1x26x14xf32>
+ %4 = "tf.Squeeze"(%3) {device = "", squeeze_dims = [-3]} : (tensor<?x1x26x14xf32>) -> tensor<?x26x14xf32>
+ %5 = "tf.BiasAdd"(%4, %cst_2) {data_format = "NHWC", device = ""} : (tensor<?x26x14xf32>, tensor<14xf32>) -> tensor<?x26x14xf32>
+ %6 = "tf.ExpandDims"(%5, %cst_4) {device = ""} : (tensor<?x26x14xf32>, tensor<i32>) -> tensor<?x1x26x14xf32>
+ %7 = "tf.Conv2D"(%6, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x14xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+ %8 = "tf.Squeeze"(%7) {device = "", squeeze_dims = [-3]} : (tensor<?x1x6x14xf32>) -> tensor<?x6x14xf32>
+ %9 = "tf.BiasAdd"(%8, %cst_0) {data_format = "NHWC", device = ""} : (tensor<?x6x14xf32>, tensor<14xf32>) -> tensor<?x6x14xf32>
+ %10 = "tf.Identity"(%9) {device = ""} : (tensor<?x6x14xf32>) -> tensor<?x6x14xf32>
+ %11 = "tf.Identity"(%10) {device = ""} : (tensor<?x6x14xf32>) -> tensor<?x6x14xf32>
+ return %11 : tensor<?x6x14xf32>
+
+ // CHECK-LABEL: UnsupportedGroupConv
+ // CHECK: "tfl.conv_2d"
+ // CHECK-NOT: "tfl.conv_2d"
+ // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_UnrankedTensorType(%arg0: tensor<*xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+ %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+ return %0 : tensor<?x1x6x14xf32>
+
+ // CHECK-LABEL: UnsupportedGroupConv_UnrankedTensorType
+ // CHECK-NOT: "tfl.conv_2d"
+ // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_DynamicDimAtInputDimThree(%arg0: tensor<?x1x26x?xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+ %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x?xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+ return %0 : tensor<?x1x6x14xf32>
+
+ // CHECK-LABEL: UnsupportedGroupConv_DynamicDimAtInputDimThree
+ // CHECK-NOT: "tfl.conv_2d"
+ // CHECK: "tf.Conv2D"
+}
+
+func @UnsupportedGroupConv_MultipleGroup(%arg0: tensor<?x1x26x14xf32>, %arg1: tensor<1x3x2x14xf32>) -> (tensor<?x1x6x14xf32>) {
+ %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 5, 1], use_cudnn_on_gpu = true} : (tensor<?x1x26x14xf32>, tensor<1x3x2x14xf32>) -> tensor<?x1x6x14xf32>
+ return %0 : tensor<?x1x6x14xf32>
+
+ // CHECK-LABEL: UnsupportedGroupConv_MultipleGroup
+ // CHECK-NOT: "tfl.conv_2d"
+ // CHECK: "tf.Conv2D"
+}
+
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 48ff038..f322d27 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -225,6 +225,21 @@
!filter_type.hasStaticShape())
return failure();
+ Value input = tf_op.input();
+ RankedTensorType input_type =
+ input.getType().template dyn_cast<RankedTensorType>();
+ // Safeguard: skip legalization of grouped convolutions, which tfl.conv_2d
+ // does not support. tf.Conv2D operator verification guarantees a rank-four
+ // input, so only unranked inputs and a dynamic channel dim are checked here.
+ if (!input_type || input_type.isDynamicDim(3)) {
+ return failure();
+ }
+ // Bail out on grouped convolutions (group count != 1), which are
+ // unsupported; zero dim sizes are rejected by tf.Conv2D operator
+ // verification, so the division below cannot divide by zero.
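+ // E.g. an NHWC input <?x1x26x14> with an HWIO filter <1x3x2x14> gives a
+ // group count of 14 / 2 = 7, so legalization fails (see prepare-tf.mlir).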
+ if (input_type.getDimSize(3) / filter_type.getDimSize(2) != 1) {
+ return failure();
+ }
+
// TensorFlow convolution op only has two inputs, while the TFLite one has
// three, with the bias vector marked as optional. However, TOCO has a
// dedicated pass, EnsureBiasVectors, to create default bias vectors for all
@@ -243,7 +258,6 @@
auto bias =
rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
- auto input = tf_op.input();
if (op->getAttrOfType<StringAttr>("padding").getValue() == "EXPLICIT") {
// Add Const op for padding value.
ArrayRef<Attribute> padding_attr_array =