Add shape inference for tfl conv2d.

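The output spatial dimensions are computed with
tensorflow::GetWindowedOutputSizeVerboseV2, so they follow the standard
windowed-output formulas. For reference, with VALID padding, input height
112, filter height 3, dilation 2, and stride 1, the effective filter
height is (3 - 1) * 2 + 1 = 5, giving an output height of
112 - 5 + 1 = 108 (see the new shape-inference.mlir test).
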
PiperOrigin-RevId: 358973892
Change-Id: I6ebe4f44d612c6981bd2633f9cbb394741268c12
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index ebb91df..2cd581b 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -259,6 +259,7 @@
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/core:framework",
         "//tensorflow/lite/schema:schema_fbs",
         "//third_party/eigen3",
         "@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index e55fee4..0891eed 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,7 @@
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/core/framework/kernel_shape_util.h"
 
 namespace mlir {
 namespace TFL {
@@ -948,6 +949,108 @@
   // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
 }
 
+static LogicalResult ComputeConvWindowedOutputSize(
+    int64_t input_size, int64_t filter_size, int64_t dilation_rate,
+    int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
+  int64_t pad_low;
+  int64_t pad_high;
+
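+  // tensorflow::GetWindowedOutputSizeVerboseV2 implements the standard
+  // windowed-output formulas:
+  //   VALID: output = ceil((input - (filter - 1) * dilation) / stride)
+  //   SAME:  output = ceil(input / stride)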
+  tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
+      input_size, filter_size, dilation_rate, stride, padding, output_size,
+      &pad_low, &pad_high);
+  // Return failure if the output size could not be calculated.
+  if (!status.ok()) return failure();
+  return success();
+}
+
+LogicalResult Conv2DOp::inferReturnTypes(
+    MLIRContext *, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attr, RegionRange,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Conv2DOpAdaptor op(operands, attr);
+
+  const Value input = op.input();
+  const Value filter = op.filter();
+
+  const RankedTensorType input_ty =
+      input.getType().dyn_cast_or_null<RankedTensorType>();
+  const RankedTensorType filter_ty =
+      filter.getType().dyn_cast_or_null<RankedTensorType>();
+  // If the input or the filter type is ranked, verify that its rank is 4.
+  if ((input_ty && input_ty.hasRank() && input_ty.getRank() != 4) ||
+      (filter_ty && filter_ty.hasRank() && filter_ty.getRank() != 4)) {
+    return emitOptionalError(location, "Invalid ranks");
+  }
+
+  // If either the input or the filter is unranked, return an unranked output
+  // shape.
+  if (!input_ty || !filter_ty || !input_ty.hasRank() || !filter_ty.hasRank()) {
+    const Type result_type = UnrankedTensorType::get(
+        input.getType().cast<ShapedType>().getElementType());
+    inferredReturnTypes.assign({result_type});
+    return success();
+  }
+
+  auto stride_h = op.stride_h().getInt();
+  auto stride_w = op.stride_w().getInt();
+  auto dilation_h = op.dilation_h_factor().getInt();
+  auto dilation_w = op.dilation_w_factor().getInt();
+
+  // TFLite has no EXPLICIT padding mode, only SAME and VALID.
+  auto paddings = op.padding().getValue();
+  tensorflow::Padding padding;
+  auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
+  if (!padding_is_valid.ok()) {
+    return emitOptionalError(location, "invalid padding format provided");
+  }
+
+  // The output always has rank 4. All dimensions are initialized to a
+  // dynamic size and can be partially inferred.
+  // TFL's conv2d is always in NHWC format and the filter is in OHWI format.
+  SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
+  return_shape[0] = input_ty.getDimSize(0);
+  return_shape[3] = filter_ty.getDimSize(0);
+
+  // A spatial dimension can be inferred only when the corresponding input and
+  // filter dimensions are both static.
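+  // For example, with SAME padding, input height 112, and stride 2, the
+  // output height is ceil(112 / 2) = 56.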
+
+  // Height.
+  if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
+    int64_t output_height;
+    if (failed(ComputeConvWindowedOutputSize(
+            input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
+            stride_h, padding, &output_height))) {
+      return failure();
+    }
+    return_shape[1] = output_height;
+  }
+
+  // Width.
+  if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
+    int64_t output_width;
+    if (failed(ComputeConvWindowedOutputSize(
+            input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
+            stride_w, padding, &output_width))) {
+      return failure();
+    }
+    return_shape[2] = output_width;
+  }
+
+  auto result_type =
+      mlir::RankedTensorType::get(return_shape, input_ty.getElementType());
+
+  inferredReturnTypes.assign({result_type});
+  return success();
+}
+
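+// Two return types are compatible when verifyCompatibleShape succeeds, i.e.
+// the shapes match up to dynamic dimensions. This allows the inferred type to
+// refine a partially-known return type.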
+bool Conv2DOp::isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
+  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
+  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // DepthwiseConv2DOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index b65a1da..5300649 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -478,9 +478,13 @@
   string customOption = ?;
 }
 
-class TFL_ConvOp<string mnemonic, string opSummary, int index> :
-    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
-    AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>, TFL_SparseOp]> {
+class TFL_ConvOp<string mnemonic, string opSummary, int index,
+                 list<OpTrait> additional_traits = []> :
+    TFL_Op<mnemonic, [NoSideEffect,
+                      AccumulatorUniformScale<2, 0, 1>,
+                      AffineQuantizedOpInterface,
+                      AffineOpCoefficient<index, 1>,
+                      TFL_SparseOp] # additional_traits> {
   let summary = opSummary # " operator";
 
   let description = [{
@@ -863,7 +867,8 @@
   let results = (outs AnyTensor:$output);
 }
 
-def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
+def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0,
+      [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
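+  // DeclareOpInterfaceMethods<InferTypeOpInterface> declares inferReturnTypes
+  // so shape inference can derive the result type from operands and attrs.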
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
@@ -874,6 +879,9 @@
     std::vector<int> GetSparseOperands() { return {1}; }
     std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
     std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
+
+    // Returns whether the return types are compatible.
+    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r);
   }];
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index fd33655..7473c5d 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -565,7 +565,7 @@
 // -----
 
 func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
-  // expected-error @+1 {{attribute 'padding' failed to satisfy constraint: padding enum}}
+  // expected-error @+1 {{invalid padding format provided}}
   %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SOMETHING", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
   return %0 : tensor<256x30x30x16xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir
new file mode 100644
index 0000000..1a1ade7
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir
@@ -0,0 +1,76 @@
+// RUN: tf-opt -split-input-file -verify-diagnostics --tf-shape-inference %s | FileCheck %s
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeValidPadding
+func @testConv2dShapeValidPadding(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x108x76x128xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceSamePadding
+func @testConv2dShapeInferenceSamePadding(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceDilation
+func @testConv2dShapeInferenceDilation(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x112x80x128xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceStrides
+func @testConv2dShapeInferenceStrides(%arg0: tensor<1x112x80x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x56x40x128xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x80x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceUnranked
+func @testConv2dShapeInferenceUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+// CHECK-LABEL: testConv2dShapeInferenceDynamic
+func @testConv2dShapeInferenceDynamic(%arg0: tensor<1x?x?x128xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x?x?x128xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}} {
+func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> {
+  // expected-error @+1 {{Invalid ranks}}
+  %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32>
+  return %0 : tensor<1x?x?x128xf32>
+}
+}
+