Accept default attributes in quantized composite functions
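Previously, the quantize-composite-functions pass emitted an error and failed whenever an attr_map entry on an op inside a quantized composite function referenced an identifier that was not captured from the float function. Such entries now only emit a warning and keep the default attribute value already set on the op, so composite functions that map just a subset of the attributes can still be quantized. Call sites are now checked against the per-call copies of the quantized functions (e.g. @quantized_conv2d_relu6_fn_0), and the test RUN line additionally runs -symbol-dce, presumably to remove the library functions left unreferenced after those copies are made.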
PiperOrigin-RevId: 432158031
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
index d598435..c127028 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
@@ -275,10 +275,10 @@
}
if (identifier_to_attr.count(
llvm::StringRef(std::string(key_and_value_pair[1]))) == 0) {
- float_func.emitError(
- absl::StrCat("Couldn't find the attribute corresponding to ",
- key_and_value_pair[1]));
- return failure();
+ float_func.emitWarning(absl::StrCat("Using the default value for the '",
+ key_and_value_pair[0],
+ "' attribute"));
+ continue;
}
inner_op.setAttr(llvm::StringRef(std::string(key_and_value_pair[0])),
identifier_to_attr[llvm::StringRef(
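For illustration, a minimal standalone sketch of the new fallback behavior follows. The identifier_to_attr contents mirror the test case added below, but the plain std::map modeling, the main() driver, and the exact attr_map string on the quantized op are assumptions for readability, not the pass's actual MLIR data structures:

    // Standalone sketch (plain C++, not the pass's MLIR types) of the new
    // attribute-transfer fallback. `identifier_to_attr` models the map built
    // from the float function's attr_map ("0:strides,2:padding,4:dilations");
    // the attr_map string below is a hypothetical "name:identifier" map as it
    // would appear on an op inside a quantized composite function.
    #include <iostream>
    #include <map>
    #include <sstream>
    #include <string>

    int main() {
      // Collected from the float function: identifier -> attribute value.
      std::map<std::string, std::string> identifier_to_attr = {
          {"0", "[1, 1, 2, 1]"},  // strides
          {"2", "\"VALID\""},     // padding
          {"4", "[1, 1, 1, 1]"},  // dilations
      };
      // Hypothetical attr_map on the quantized op: attribute name -> identifier.
      const std::string attr_map =
          "strides:0,use_cudnn_on_gpu:1,padding:2,explicit_paddings:3,"
          "dilations:4";

      std::stringstream ss(attr_map);
      std::string entry;
      while (std::getline(ss, entry, ',')) {
        const size_t colon = entry.find(':');
        const std::string attr_name = entry.substr(0, colon);
        const std::string identifier = entry.substr(colon + 1);
        if (identifier_to_attr.count(identifier) == 0) {
          // New behavior: warn and keep the op's default attribute value
          // instead of failing the whole pass.
          std::cerr << "warning: using the default value for the '"
                    << attr_name << "' attribute\n";
          continue;
        }
        std::cout << attr_name << " = " << identifier_to_attr[identifier]
                  << "\n";
      }
      return 0;
    }

Running the sketch transfers strides, padding, and dilations, and warns that use_cudnn_on_gpu and explicit_paddings fall back to their defaults.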
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
index 547269b..72892b5 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions | FileCheck %s
+// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions -symbol-dce | FileCheck %s
module {
func @add(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
@@ -102,7 +102,7 @@
// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]], %[[w_quant]], %[[b_quant]],
// CHECK-SAME: %[[in_scale]], %[[in_zp]], %[[w_scale]], %[[w_zp]],
// CHECK-SAME: %[[b_scale]], %[[w_zp]], %[[out_scale]], %[[out_zp]])
-// CHECK-SAME: f = @quantized_conv2d_relu6_fn
+// CHECK-SAME: f = @quantized_conv2d_relu6_fn_0
// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[conv_quant]], %[[out_scale]], %[[out_zp]])
// CHECK-SAME: f = @dequantize_i8
@@ -119,3 +119,38 @@
// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"
// CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}
}
+
+// -----
+
+module {
+ func @conv_with_default_attributes(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) {
+ %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32>
+ %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
+ %0 = "quant.qcast"(%cst) : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2x!quant.uniform<i8<-127:127>:f32:3, {4.000000e-03,5.000000e-03}>>
+ %1 = "quant.dcast"(%0) : (tensor<2x2x3x2x!quant.uniform<i8<-127:127>:f32:3, {4.000000e-03,5.000000e-03}>>) -> tensor<*xf32>
+ %2 = "quant.qcast"(%arg0) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3x!quant.uniform<i8:f32, 8.000000e-03>>
+ %3 = "quant.dcast"(%2) : (tensor<1x2x2x3x!quant.uniform<i8:f32, 8.000000e-03>>) -> tensor<*xf32>
+ %4 = "tf.PartitionedCall"(%3, %1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @fused_conv2d_relu6_fn_1} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+ %5 = "quant.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform<i8:f32, 5.000000e-02:-1>>
+ %6 = "quant.dcast"(%5) : (tensor<*x!quant.uniform<i8:f32, 5.000000e-02:-1>>) -> tensor<*xf32>
+ return %6 : tensor<*xf32>
+ }
+ func private @fused_conv2d_relu6_fn_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.fused_function} {
+ %0 = "tf._FusedConv2D"(%arg0, %arg1, %arg2) {attr_map = "0:strides,2:padding,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+// CHECK-LABEL: func @conv_with_default_attributes
+
+// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0
+// CHECK-SAME: f = @quantize_i8
+// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]]
+// CHECK-SAME: f = @quantized_conv2d_relu6_fn_0
+// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[conv_quant]]
+// CHECK-SAME: f = @dequantize_i8
+// CHECK: return %[[dequantize]]
+
+// CHECK-LABEL: func private @quantized_conv2d_relu6_fn_0
+// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"
+// CHECK-SAME: {dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 2, 1]}
+}