Add shape inference for tfl conv2d.
PiperOrigin-RevId: 358973892
Change-Id: I6ebe4f44d612c6981bd2633f9cbb394741268c12
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index ebb91df..2cd581b 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -259,6 +259,7 @@
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+ "//tensorflow/core:framework",
"//tensorflow/lite/schema:schema_fbs",
"//third_party/eigen3",
"@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index e55fee4..0891eed 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,7 @@
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
namespace mlir {
namespace TFL {
@@ -948,6 +949,108 @@
// results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
}
+// Thin wrapper over tensorflow::GetWindowedOutputSizeVerboseV2 that forwards
+// `output_size` to it, ignores the two extra out-values it also fills in, and
+// maps the returned Status to a LogicalResult (failure when it is not ok).
+static LogicalResult ComputeConvWindowedOutputSize(
+    int64_t input_size, int64_t filter_size, int64_t dilation_rate,
+    int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
+  int64_t unused_pad_before;
+  int64_t unused_pad_after;
+  const tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
+      input_size, filter_size, dilation_rate, stride, padding, output_size,
+      &unused_pad_before, &unused_pad_after);
+  return success(status.ok());
+}
+
+// Infers the result type of a tfl.conv_2d op from its operands and attributes.
+// An unranked input or filter gives an unranked result; otherwise the result
+// is rank 4 with batch taken from input dim 0, channels from filter dim 0,
+// and height/width computed only when the matching input and filter dims are
+// static. Every failure path reports through emitOptionalError(location, ...).
+LogicalResult Conv2DOp::inferReturnTypes(
+    MLIRContext *, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attr, RegionRange,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Conv2DOpAdaptor op(operands, attr);
+  const Value input = op.input();
+  const Value filter = op.filter();
+
+  // Null when the operand type is not a ranked tensor.
+  const RankedTensorType input_ty =
+      input.getType().dyn_cast_or_null<RankedTensorType>();
+  const RankedTensorType filter_ty =
+      filter.getType().dyn_cast_or_null<RankedTensorType>();
+
+  // A non-null RankedTensorType always has a rank, so only the rank value
+  // needs checking. This runs before the unranked early return so that a
+  // single ranked operand with a bad rank is still rejected.
+  if ((input_ty && input_ty.getRank() != 4) ||
+      (filter_ty && filter_ty.getRank() != 4)) {
+    return emitOptionalError(location, "Invalid ranks");
+  }
+
+  // An unranked input or filter produces an unranked result.
+  if (!input_ty || !filter_ty) {
+    const Type result_type = UnrankedTensorType::get(
+        input.getType().cast<ShapedType>().getElementType());
+    inferredReturnTypes.assign({result_type});
+    return success();
+  }
+
+  const auto stride_h = op.stride_h().getInt();
+  const auto stride_w = op.stride_w().getInt();
+  const auto dilation_h = op.dilation_h_factor().getInt();
+  const auto dilation_w = op.dilation_w_factor().getInt();
+
+  // We don't have EXPLICIT PADDING in TfLite.
+  tensorflow::Padding padding;
+  const auto padding_status =
+      GetPaddingFromString(op.padding().getValue().str(), &padding);
+  if (!padding_status.ok()) {
+    return emitOptionalError(location, "invalid padding format provided");
+  }
+
+  // Output always has rank 4. All dimensions are initialized to
+  // dynamic size and can be partially inferred.
+  // TFL's conv2d is always NHWC format & the filter is OHWI.
+  SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
+  return_shape[0] = input_ty.getDimSize(0);
+  return_shape[3] = filter_ty.getDimSize(0);
+
+  // Height: inferred only when the input and filter heights are both static.
+  if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
+    int64_t output_height;
+    if (failed(ComputeConvWindowedOutputSize(
+            input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
+            stride_h, padding, &output_height))) {
+      return emitOptionalError(location, "failed to infer output height");
+    }
+    return_shape[1] = output_height;
+  }
+
+  // Width: inferred only when the input and filter widths are both static.
+  if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
+    int64_t output_width;
+    if (failed(ComputeConvWindowedOutputSize(
+            input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
+            stride_w, padding, &output_width))) {
+      return emitOptionalError(location, "failed to infer output width");
+    }
+    return_shape[2] = output_width;
+  }
+
+  inferredReturnTypes.assign(
+      {mlir::RankedTensorType::get(return_shape, input_ty.getElementType())});
+  return success();
+}
+
+// Compatible iff each side holds exactly one type and their shapes agree.
+bool Conv2DOp::isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
+  return lhs.size() == 1 && rhs.size() == 1 &&
+         succeeded(mlir::verifyCompatibleShape(lhs[0], rhs[0]));
+}
+
//===----------------------------------------------------------------------===//
// DepthwiseConv2DO
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index b65a1da..5300649 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -478,9 +478,13 @@
string customOption = ?;
}
-class TFL_ConvOp<string mnemonic, string opSummary, int index> :
- TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
- AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>, TFL_SparseOp]> {
+class TFL_ConvOp<string mnemonic, string opSummary, int index,
+ list<OpTrait> additional_traits = []> :
+ TFL_Op<mnemonic,[NoSideEffect,
+ AccumulatorUniformScale<2, 0, 1>,
+ AffineQuantizedOpInterface,
+ AffineOpCoefficient<index, 1>,
+ TFL_SparseOp] # additional_traits> {
let summary = opSummary # " operator";
let description = [{
@@ -863,7 +867,8 @@
let results = (outs AnyTensor:$output);
}
-def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
+def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0,
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
@@ -874,6 +879,9 @@
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
+
+ // Returns whether the return types are compatible.
+ static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r);
}];
}
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index fd33655..7473c5d 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -565,7 +565,7 @@
// -----
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
- // expected-error @+1 {{attribute 'padding' failed to satisfy constraint: padding enum}}
+ // expected-error @+1 {{invalid padding format provided}}
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SOMETHING", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
}
diff --git a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir
new file mode 100644
index 0000000..1a1ade7
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir
@@ -0,0 +1,76 @@
+// RUN: tf-opt -split-input-file -verify-diagnostics --tf-shape-inference %s | FileCheck %s
+// VALID padding, dilation 2, stride 1: 112x80 input is refined to 108x76.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeValidPadding
+func @testConv2dShapeValidPadding(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x108x76x128xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+// SAME padding, dilation 1, stride 1: spatial dims stay 112x80.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceSamePadding
+func @testConv2dShapeInferenceSamePadding(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+// SAME padding, dilation 2, stride 1: spatial dims stay 112x80.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceDilation
+func @testConv2dShapeInferenceDilation(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+// SAME padding, dilation 1, stride 2: 112x80 input is refined to 56x40.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceStrides
+func @testConv2dShapeInferenceStrides(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x56x40x128xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+// All operands unranked: the result type stays tensor<*xf32>.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceUnranked
+func @testConv2dShapeInferenceUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+}
+
+// -----
+// Dynamic input height/width: the result's spatial dims stay dynamic.
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceDynamic
+func @testConv2dShapeInferenceDynamic(%arg0: tensor<1x?x?x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+// Rank-3 input: the op must be rejected with the "Invalid ranks" diagnostic.
+module attributes {tf.versions = {producer = 888 : i32}} {
+func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+ // expected-error @+1 {{Invalid ranks}}
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+ return %0 : tensor<1x?x?x128xf32>
+}
+}
+