[mhlo] Verifier for mhlo.ConvOp
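
Adds a verifier for mhlo.ConvOp that checks:
 1. the input and kernel types (equal ranks >= 2, when ranked),
 2. the convolution attributes (dimension-numbers, feature_group_count,
    batch_group_count),
 3. the window attributes (strides, padding, dilations), and
 4. the actual return shape against the inferred one.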

PiperOrigin-RevId: 448608547
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 58014bd..d9b27a6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1535,6 +1535,7 @@
 
   let results = (outs HLO_Tensor);
   let hasCustomHLOConverter = 1;
+  let hasVerifier = 1;
 
   code extraClassDeclaration = [{
     bool hasWindowReversal() {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index e561d7d..4959759 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1792,6 +1792,311 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Checks:
+//  P1. Same sizes for input, kernel, and output spatial_dims.
+//  P2. Spatial and non-spatial dimensions (for input, kernel, and output)
+//      should be unique and in range [0, num_dims), where num_dims = rank of
+//      input (lhs/rhs) tensors.
+//
+//  Note that the spatial + non-spatial dimensions may not cover all the
+//  dimensions in the range [0, num_dims) because of the presence of
+//  'unknown' dimensions (ref. cl/415132294).
+LogicalResult isSpatialDimensionsValid(ConvOp op) {
+  auto input_spatial_dimensions =
+      op.dimension_numbers().getInputSpatialDimensions();
+  auto kernel_spatial_dimensions =
+      op.dimension_numbers().getKernelSpatialDimensions();
+  auto output_spatial_dimensions =
+      op.dimension_numbers().getOutputSpatialDimensions();
+
+  // P1.
+  if ((input_spatial_dimensions.size() != kernel_spatial_dimensions.size()) ||
+      (input_spatial_dimensions.size() != output_spatial_dimensions.size()))
+    return op.emitOpError() << "expects the same size for input, kernel and "
+                               "output spatial-dimensions, but got "
+                            << input_spatial_dimensions.size() << ", "
+                            << kernel_spatial_dimensions.size() << ", and "
+                            << output_spatial_dimensions.size() << " resp.";
+
+  // P2.
+  SmallVector<int64_t> input_dnums(input_spatial_dimensions.size() + 2);
+  input_dnums[0] = op.dimension_numbers().getInputBatchDimension();
+  input_dnums[1] = op.dimension_numbers().getInputFeatureDimension();
+  std::copy(input_spatial_dimensions.begin(), input_spatial_dimensions.end(),
+            input_dnums.begin() + 2);
+
+  SmallVector<int64_t> window_dnums(kernel_spatial_dimensions.size() + 2);
+  window_dnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
+  window_dnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
+  std::copy(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
+            window_dnums.begin() + 2);
+
+  SmallVector<int64_t> output_dnums(output_spatial_dimensions.size() + 2);
+  output_dnums[0] = op.dimension_numbers().getOutputBatchDimension();
+  output_dnums[1] = op.dimension_numbers().getOutputFeatureDimension();
+  std::copy(output_spatial_dimensions.begin(), output_spatial_dimensions.end(),
+            output_dnums.begin() + 2);
+
+  auto num_dims = op.lhs().getType().cast<RankedTensorType>().getRank();
+  const auto in_range = [num_dims](int64_t i) {
+    return 0 <= i && i < num_dims;
+  };
+
+  if (!llvm::all_of(input_dnums, in_range) ||
+      !llvm::all_of(window_dnums, in_range) ||
+      !llvm::all_of(output_dnums, in_range))
+    return op.emitOpError() << "expects input, kernel, and output "
+                               "dimension-numbers to be in-range [0, "
+                            << num_dims << ").";
+
+  const auto has_duplicates = [](SmallVector<int64_t>& dnums) {
+    std::sort(dnums.begin(), dnums.end());
+    auto last = std::unique(dnums.begin(), dnums.end());
+    return last != dnums.end();
+  };
+
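+  // Note: has_duplicates() mutates its argument in place (std::sort +
+  // std::unique), so the diagnostics below print the mutated list rather
+  // than the original dimension order.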
+  if (has_duplicates(input_dnums))
+    return op.emitOpError()
+           << "expects input dimension-numbers to be unique, got {"
+           << input_dnums << "}.";
+
+  if (has_duplicates(window_dnums))
+    return op.emitOpError()
+           << "expects kernel dimension-numbers to be unique, got {"
+           << window_dnums << "}.";
+
+  if (has_duplicates(output_dnums))
+    return op.emitOpError()
+           << "expects output dimension-numbers to be unique, got {"
+           << output_dnums << "}.";
+
+  return success();
+}
+
+// Verifies the following properties:
+//  P1. The input, kernel, and output spatial-dimensions are valid.
+//  P2. Given,
+//          input-dimensions: b * input-spatial-dims * f
+//          kernel-dimensions: kernel-spatial-dims * i * o
+//          output-dimensions: b' * out-spatial-dims * f'
+//            where b = input-batch-dims
+//            where f = input-feature-dims
+//            where i = kernel-input-feature-dims
+//            where o = kernel-output-feature-dims
+//            where b' = output-batch-dims
+//            where f' = output-feature-dims
+//      Check the following properties w.r.t. feature_group_count (fgc) and
+//      batch_group_count (bgc):
+//        fgc > 0, bgc > 0, and !(fgc > 1 && bgc > 1)
+//        b % bgc == 0
+//        f % fgc == 0 and i = f / fgc
+//        o (or f') % bgc == 0 and o (or f') % fgc == 0
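+//
+//      For example, for a depthwise convolution with f = 32 and fgc = 32,
+//      the kernel input-feature dimension must satisfy i = 32 / 32 = 1.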
+LogicalResult verifyConvolutionAttributes(ConvOp op) {
+  // P1.
+  if (failed(isSpatialDimensionsValid(op))) return failure();
+
+  // P2.
+  const int64_t feature_group_count = op.feature_group_count();
+  const int64_t batch_group_count = op.batch_group_count();
+
+  if (feature_group_count <= 0)
+    return op.emitOpError()
+           << "expects feature_group_count to be a positive number, got "
+           << feature_group_count << ".";
+
+  if (batch_group_count <= 0)
+    return op.emitOpError()
+           << "expects batch_group_count to be a positive number, got "
+           << batch_group_count << ".";
+
+  if (batch_group_count > 1 && feature_group_count > 1)
+    return op.emitOpError()
+           << "expects batch_group_count and feature_group_count not to be "
+              "both greater than 1. Got "
+           << batch_group_count << " and " << feature_group_count << " resp.";
+
+  auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+  const int64_t input_features =
+      lhs_type.getShape()[op.dimension_numbers().getInputFeatureDimension()];
+  const int64_t input_batch =
+      lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+
+  auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+  const int64_t kernel_input_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
+  const int64_t kernel_output_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+  if (!isDynamicDimSize(kernel_output_features)) {
+    if (kernel_output_features % batch_group_count != 0)
+      return op.emitOpError() << "expects output feature dimension size ("
+                              << kernel_output_features
+                              << ") to be a multiple of "
+                                 "batch_group_count. Got batch_group_count = "
+                              << batch_group_count << ".";
+
+    if (kernel_output_features % feature_group_count != 0)
+      return op.emitOpError()
+             << "expects kernel output feature dimension ("
+             << kernel_output_features
+             << ") to be divisible by "
+                "feature_group_count. Got feature_group_count = "
+             << feature_group_count << ".";
+  }
+
+  if (!isDynamicDimSize(input_features)) {
+    if (input_features % feature_group_count != 0)
+      return op.emitOpError()
+             << "expects input feature dimension (" << input_features
+             << ") to be a multiple of "
+                "feature_group_count. Got feature_group_count = "
+             << feature_group_count << ".";
+
+    if (!isDynamicDimSize(kernel_input_features) &&
+        input_features / feature_group_count != kernel_input_features)
+      return op.emitOpError()
+             << "expects input feature dimension (" << input_features
+             << ") / "
+                "feature_group_count = kernel input feature dimension ("
+             << kernel_input_features
+             << "). Got feature_group_count = " << feature_group_count << ".";
+  }
+
+  if (!isDynamicDimSize(input_batch) && input_batch % batch_group_count != 0)
+    return op.emitOpError() << "expects input batch dimension (" << input_batch
+                            << ") to be divisible by "
+                               "batch_group_count. Got batch_group_count = "
+                            << batch_group_count << ".";
+
+  return success();
+}
+
+// Infer the return-shape of ConvOp.
+// Precondition:
+//  1. Input args to ConvOp 'op' are RankedTypes.
+//  2. rank-of(input-type) == rank-of(output-type)
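+// For example (cf. the tests below): lhs = tensor<1x8x8x207xf32>,
+// rhs = tensor<3x3x207x16xf32>, strides [1, 1], padding [[1, 1], [1, 1]],
+// and dilations [1, 1] infer the output shape [1, 8, 8, 16].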
+SmallVector<int64_t> inferConvOpReturnShape(
+    ConvOp op, const ArrayRef<WindowDimension> window) {
+  // We keep the 'unknown' dimensions (cl/415132294) as-is in the
+  // output-shape. To do that, we initialize the output dimensions with the
+  // shape of the return-type and update only the spatial + non-spatial
+  // dimensions. Precondition 2 ensures that size of output-shape == size of
+  // input-shape.
+  SmallVector<int64_t> output_dimensions =
+      to_vector(op.getResult().getType().cast<ShapedType>().getShape());
+
+  // Infer the output spatial dimensions.
+  auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+  auto input_spatial_dims = op.dimension_numbers().getInputSpatialDimensions();
+  auto num_spatial_dims = input_spatial_dims.size();
+  SmallVector<int64_t> input_spatial_dim_vals(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i)
+    input_spatial_dim_vals[i] = lhs_type.getShape()[input_spatial_dims[i]];
+
+  auto window_output_shape =
+      inferWindowOutputShape(input_spatial_dim_vals, window);
+
+  for (int i = 0; i < window.size(); ++i)
+    output_dimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
+        window_output_shape[i];
+
+  // Infer the output-batch-dimension and output-feature-dimension.
+  auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+  const int64_t input_batch =
+      lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+  const int64_t kernel_output_features =
+      rhs_type
+          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+  output_dimensions[op.dimension_numbers().getOutputBatchDimension()] =
+      isDynamicDimSize(input_batch) ? ShapedType::kDynamicSize
+                                    : input_batch / op.batch_group_count();
+  output_dimensions[op.dimension_numbers().getOutputFeatureDimension()] =
+      kernel_output_features;
+
+  return output_dimensions;
+}
+}  // namespace
+
+/*
+ * We intend to verify the following properties:
+ *  P1. Verify the input, kernel types.
+ *  P2. Verify the convolution attributes.
+ *  P3. Verify and collect the window attributes.
+ *  P4. Verify the return shape.
+ *      TODO(b/232574102): Verify the element-type of return-value.
+ */
+LogicalResult ConvOp::verify() {
+  auto lhs_type = lhs().getType().dyn_cast<RankedTensorType>();
+  auto rhs_type = rhs().getType().dyn_cast<RankedTensorType>();
+
+  if (!lhs_type || !rhs_type) return success();
+
+  // P1.
+  int num_dims = lhs_type.getRank();
+  if (num_dims != rhs_type.getRank())
+    return emitOpError()
+           << "expects convolution arguments to have same number of "
+              "dimensions. Got: "
+           << lhs_type << " and " << rhs_type << ".";
+
+  if (num_dims < 2)
+    return emitOpError()
+           << "expects convolution arguments to have >= 2 dimensions. "
+              "Got: "
+           << lhs_type << " and " << rhs_type << ".";
+
+  // P2.
+  if (failed(verifyConvolutionAttributes(*this))) return failure();
+
+  // P3.
+  auto kernel_spatial_dimensions =
+      dimension_numbers().getKernelSpatialDimensions();
+  SmallVector<int64_t> window_dimensions(kernel_spatial_dimensions.size());
+  for (size_t i = 0; i < window_dimensions.size(); i++)
+    window_dimensions[i] = rhs_type.getShape()[kernel_spatial_dimensions[i]];
+
+  auto padding_or_err = convertNx2Attribute(this->padding(), getLoc());
+  if (failed(padding_or_err)) return failure();
+  SmallVector<std::pair<int64_t, int64_t>> padding = *padding_or_err;
+
+  auto window_or_err = verifyWindowAttributesAndInferWindowDimensions(
+      window_dimensions, convertDenseIntAttr(window_strides()), padding,
+      convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
+      getLoc());
+  if (failed(window_or_err)) return failure();
+
+  // P4.
+  auto actual_return_type = getResult().getType().cast<TensorType>();
+  auto actual_return_element_type = actual_return_type.getElementType();
+  if (!actual_return_type.hasRank()) return success();
+
+  auto actual_return_ranked_type = actual_return_type.cast<RankedTensorType>();
+  if (num_dims != actual_return_ranked_type.getRank())
+    return emitOpError() << "expects rank of convolution return-type to be "
+                            "equal to input-ranks ("
+                         << num_dims << "), but got "
+                         << actual_return_ranked_type.getRank() << ".";
+
+  auto expected_return_shape = inferConvOpReturnShape(*this, *window_or_err);
+  auto expected_return_type =
+      RankedTensorType::get(expected_return_shape, actual_return_element_type);
+  if (failed(verifyCompatibleShape(expected_return_type,
+                                   actual_return_ranked_type)))
+    return emitOpError()
+           << "has shape mismatch between the expected return-type ("
+           << expected_return_type << ") and actual return-type ("
+           << actual_return_ranked_type << ").";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // ConvertOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 0b25655..c2182e9 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2055,28 +2055,26 @@
           op, "non-one lhs- dialation unsupported yet");
     }
 
-    if (const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
-            op.dimension_numbers()) {
-      // Make sure that this is 2-D convolution.
-      const auto spatial_rank =
-          llvm::size(dimension_numbers.getInputSpatialDimensions());
-      if (spatial_rank != 2) {
-        return rewriter.notifyMatchFailure(op,
-                                           "only support 2-D cases for now");
-      }
+    const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
+        op.dimension_numbers();
+    // Make sure that this is 2-D convolution.
+    const auto spatial_rank =
+        llvm::size(dimension_numbers.getInputSpatialDimensions());
+    if (spatial_rank != 2) {
+      return rewriter.notifyMatchFailure(op, "only support 2-D cases for now");
+    }
 
-      // Make sure that this is depthwise convolution.
-      int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
-      int64_t input_feature_count =
-          op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
-      if (op.feature_group_count() != input_feature_count) {
-        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
-      }
+    // Make sure that this is depthwise convolution.
+    int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
+    int64_t input_feature_count =
+        op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
+    if (op.feature_group_count() != input_feature_count) {
+      return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
+    }
 
-      // Make sure that this convolution has a canonical form.
-      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
-        return rewriter.notifyMatchFailure(op, "does not have canonical form");
-      }
+    // Make sure that this convolution has a canonical form.
+    if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
+      return rewriter.notifyMatchFailure(op, "does not have canonical form");
     }
 
     DenseIntElementsAttr window_strides;
@@ -2127,10 +2125,38 @@
       return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
     };
 
-    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
+    int64_t kernel_input_feature_dimension =
+        dimension_numbers.getKernelInputFeatureDimension();
+    int64_t kernel_output_feature_dimension =
+        dimension_numbers.getKernelOutputFeatureDimension();
+    if (filter_dims[kernel_input_feature_dimension] *
+            filter_dims[kernel_output_feature_dimension] !=
+        op.feature_group_count()) {
       // For cases where channel multiplier != 1
+
+      // Reshape the filter shape from
+      //   [filter_height, filter_width, 1, kernel-output-feature]
+      // to
+      //   [filter_height, filter_width, feature_group_count,
+      //    kernel-output-feature / feature_group_count].
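+      // For example, a [3, 3, 1, 32] filter with feature_group_count = 32
+      // would be reshaped to [3, 3, 32, 1].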
+      SmallVector<int64_t> reshaped_filter_dims;
+      reshaped_filter_dims.assign(filter_dims.begin(), filter_dims.end());
+      auto reshaped_filter = filter;
+      if (filter_dims[kernel_input_feature_dimension] == 1) {
+        reshaped_filter_dims[kernel_input_feature_dimension] =
+            op.feature_group_count();
+        reshaped_filter_dims[kernel_output_feature_dimension] /=
+            op.feature_group_count();
+        auto reshaped_filter_type = RankedTensorType::get(
+            reshaped_filter_dims,
+            op.rhs().getType().cast<RankedTensorType>().getElementType());
+
+        reshaped_filter =
+            rewriter.create<mhlo::ReshapeOp>(loc, reshaped_filter_type, filter);
+      }
+
       auto output_dims = result_type.getShape();
-      auto channel_multiplier = filter_dims[3];
+      auto channel_multiplier = reshaped_filter_dims[3];
       SmallVector<int64_t> reshaped_output_dims;
       reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
       reshaped_output_dims.push_back(channel_multiplier);
@@ -2143,7 +2169,7 @@
       auto reshaped_output_type = RankedTensorType::get(
           reshaped_output_dims, result_type.getElementType());
       auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
+          loc, reshaped_output_type, ValueRange{input, reshaped_filter},
           ValueRange{zero_tensor}, window_strides, rhs_dilation,
           PruneAttributeList(op));
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
new file mode 100644
index 0000000..c4a2eaf
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
@@ -0,0 +1,809 @@
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// -----
+
+// Valid: Generic convolution
+
+func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32> {
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// Valid: Test convolution i8xi8 -> i32.
+
+func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>,
+    %arg1 : tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> {
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32>
+  func.return %result : tensor<100x28x28x1xi32>
+}
+
+// Valid: Empty spatial dimensions
+
+// CHECK: func @conv_empty_spatial_dimensions
+// CHECK: mhlo.convolution
+// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f]
+// CHECK-SAME: window = {stride = [], pad = [], lhs_dilate = [],
+// CHECK-SAME: rhs_dilate = [], reverse = []}
+func.func @conv_empty_spatial_dimensions(%arg0: tensor<3x2xf16>,
+    %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, f]x[i, o]->[b, f],
+         window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],
+           reverse = []}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         }
+       : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16>
+  %1 = "mhlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple<tensor<3x2xf16>>
+  func.return %1 : tuple<tensor<3x2xf16>>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<2x4x5x2xf32>,
+                     %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
+  // expected-error@+1 {{expects input dimension-numbers to be unique, got {0, 0}.}}
+  %1 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      kernel_input_feature_dimension = 2,
+      kernel_output_feature_dimension = 3,
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+    >,
+    feature_group_count = 2 : i64,
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) ->
+      tensor<2x3x4x6xf32>
+  func.return %1 : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects convolution arguments to have same number of dimensions. Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>)
+    -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects convolution arguments to have >= 2 dimensions. Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1xf32>, tensor<3xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 3, 2, and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, 2, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 4],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects kernel dimension-numbers to be unique, got {0, 2, 3, 3}.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 0],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+    %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects output dimension-numbers to be unique, got {0, 3, 3, 3}.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [0, 3]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{op expects batch_group_count to be a positive number, got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 0 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{op expects feature_group_count to be a positive number, got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 0 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects batch_group_count and feature_group_count not to be both greater than 1. Got 2 and 2 resp.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 2 : i64,
+           feature_group_count = 2 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 3 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input feature dimension (207) to be a multiple of feature_group_count. Got feature_group_count = 2.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 2 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). Got feature_group_count = 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. Got feature_group_count = 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 3 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<5x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects input batch dimension (5) to be divisible by batch_group_count. Got batch_group_count = 2.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 2 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<5x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 1.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         // expected-error@+1 {{Expected array with 2 elements, got 4 elements instead}}
+         window = {stride = [1, 1], pad = [[1, 1, 1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<6xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32> {
+  // expected-error@+1 {{expects the padding-entries to have even number of elements, but got 5 elements.}}
+  %result = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = #mhlo.conv<raw
+      input_batch_dimension = 0,
+      input_feature_dimension = 3,
+      input_spatial_dimensions = [1, 2],
+      kernel_input_feature_dimension = 3,
+      kernel_output_feature_dimension = 2,
+      kernel_spatial_dimensions = [0, 1],
+      output_batch_dimension = 0,
+      output_feature_dimension = 3,
+      output_spatial_dimensions = [1, 2]
+    >,
+    feature_group_count = 1 : i64,
+    lhs_dilation = dense<1> : tensor<2xi64>,
+    padding = dense<2> : tensor<5xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+    tensor<100x28x28x1xf32>
+  func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive stride for 1-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 0], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [0, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [0, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+// Invalid rank of output-type.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> {
+  // expected-error @+1 {{expects rank of convolution return-type to be equal to input-ranks (4), but got 3.}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32>
+  func.return %0 : tensor<1x8x16xf32>
+}
+
+// -----
+
+// Invalid batch dimension in output-type. Should be equal to
+// input-batch-dimension / batch_group_count.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<2x8x8x16xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32>
+  func.return %0 : tensor<2x8x8x16xf32>
+}
+
+// -----
+
+// Invalid feature dimension in output-type. Should be equal to the size of
+// the kernel_output_feature_dimension.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x8x8x32xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32>
+  func.return %0 : tensor<1x8x8x32xf32>
+}
+
+// -----
+
+// The following tests check the inferred output-type of ConvOp. We
+// deliberately put an invalid output-type in these tests so that the
+// inferred type is highlighted in the error message.
+
+// Dynamic input-batch-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<?x8x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<?x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-feature-dimension: No effect on output dimensions.
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x?xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-spatial-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>,
+    %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x?x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x?x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-input-feature-dimension: No effect on output dimensions.
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-output-feature-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x?xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-spatial-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+    %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x?x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+           lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {
+           batch_group_count = 1 : i64,
+           feature_group_count = 1 : i64,
+           precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+         } :
+       (tensor<1x8x8x207xf32>, tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32>
+  func.return %0 : tensor<1x1x1x1xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
index 08bb2f3..db142cd 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
@@ -473,10 +473,10 @@
 // -----
 
 // CHECK-LABEL: func @conv
-func.func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
-    -> tensor<3x5x5x4xf32> {
+func.func @conv(%input: tensor<3x2x4x3xf32>, %filter : tensor<2x2x3x4xf32>)
+    -> tensor<2x1x2x3xf32> {
   %c0 = arith.constant 0 : index
-  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<2x1x2x3xf32>
   // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
   // CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
   %out = "mhlo.convolution"(%filter, %input) {
@@ -496,8 +496,8 @@
     padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
     rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
     window_strides = dense<[2, 1]> : tensor<2xi64>
-  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
-  func.return %out : tensor<3x5x5x4xf32>
+  } : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x2x3xf32>
+  func.return %out : tensor<2x1x2x3xf32>
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index b150f5f..e7e353c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -2863,7 +2863,7 @@
       output_spatial_dimensions = [1]
     >,
     feature_group_count = 1 : i64,
-    padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+    padding = dense<[[0, 0]]> : tensor<1x2xi64>,
     rhs_dilation = dense<1> : tensor<1xi64>,
     window_strides = dense<1> : tensor<1xi64>,
     someattr
@@ -2890,7 +2890,7 @@
 // -----
 
 func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
-  -> tensor<?x2x3x?xf32> {
+  -> tensor<?x2x4x?xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv<raw
@@ -2908,8 +2908,8 @@
     padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>
-  } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
-  func.return %0 : tensor<?x2x3x?xf32>
+  } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x4x?xf32>
+  func.return %0 : tensor<?x2x4x?xf32>
 }
 // CHECK-LABEL: func @conv_2d_nhwc_hwcf
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
@@ -2918,14 +2918,14 @@
 // CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
 // CHECK:         %[[C3:.+]] = arith.constant 3 : index
 // CHECK:         %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
-// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 4, %[[DIM3]]]
 // CHECK:         %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]]
 // CHECK:         linalg.conv_2d_nhwc
 // CHECK-SAME:      {dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:       strides = dense<1> : tensor<2xi64>}
 // CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
-// CHECK-SAME:    outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
+// CHECK-SAME:    outs(%[[FILL]] : tensor<?x2x4x?xf32>) -> tensor<?x2x4x?xf32>
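Both spatial extents above are static even though batch and feature are dynamic: (4 + 0 - 3) / 1 + 1 = 2 and (5 + 0 - 2) / 1 + 1 = 4, which is why the result is ?x2x4x? rather than the old ?x2x3x?.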
 
 // -----
 
@@ -2945,7 +2945,7 @@
       output_spatial_dimensions = [1, 2, 3]
     >,
     feature_group_count = 1 : i64,
-    padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+    padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
     rhs_dilation = dense<1> : tensor<3xi64>,
     window_strides = dense<1> : tensor<3xi64>
   } : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
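Both padding fixes above follow the attribute's layout: one [low, high] pair per spatial dimension, i.e. tensor<Nx2xi64> for an N-D convolution, so the 1-D test needs tensor<1x2xi64> and the 3-D test tensor<3x2xi64>; the old attributes had the two axes transposed.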
@@ -3028,25 +3028,25 @@
 // CHECK-LABEL: func @linalg.conv_2D_padding_test2
 // CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>)
 func.func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>)
-  -> tensor<400x1024x1024x1xf16> {
-  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1024x1024x1xf16>)
-  func.return %0 : tensor<400x1024x1024x1xf16>
+  -> tensor<400x1040x1024x1xf16> {
+  %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1040x1024x1xf16>)
+  func.return %0 : tensor<400x1040x1024x1xf16>
 }
-// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1024, 1024, 1] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1040, 1024, 1] : tensor<400x1040x1024x1xf16>
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
-// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
 // CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0]  {
 // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
 // CHECK-NEXT:   tensor.yield %[[ZERO]] : f16
 // CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
-// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
+// CHECK-NEXT: return %[[RESULT]] : tensor<400x1040x1024x1xf16>
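The corrected result shape follows from the [[8, 8], [16, 16]] padding against the 1x33 kernel: height 1024 + 8 + 8 - 1 + 1 = 1040 and width 1024 + 16 + 16 - 33 + 1 = 1024, hence 400x1040x1024x1 (the old test silently reused the input shape).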
 
 // -----
 
 func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
-                     %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> {
+                     %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv<raw
@@ -3064,20 +3064,24 @@
     padding = dense<0> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>,
-    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
   func.return %0 : tensor<2x3x4x6xf32>
 }
 // CHECK:      func @depthwise_conv
 // CHECK-SAME:   %[[IN:[a-zA-Z0-9_]*]]
 // CHECK-SAME:   %[[FILTER:[a-zA-Z0-9_]*]]
-// CHECK:        %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
-// CHECK:        %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK:        %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
+
+// CHECK:       %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2, 3]] : tensor<2x2x1x6xf32> into tensor<24xf32>
+// CHECK:       %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<24xf32> to tensor<24xf32>
+// CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]] {{\[}}[0, 1, 2, 3]] : tensor<24xf32> into tensor<2x2x2x3xf32>
+// CHECK:       %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
+// CHECK:       %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
+// CHECK:       %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
 // CHECK-SAME:     {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME:     ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME:     ins(%[[IN]], %[[EXPAND]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK:        %{{.+}} = tensor.collapse_shape %[[OUT]]
+// CHECK:       %{{.+}} = tensor.collapse_shape %[[OUT]]
 // CHECK-SAME:     [0], [1], [2], [3, 4]
 // CHECK-SAME:     : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
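The filter rewrite above reflects the grouped-feature constraints the new verifier enforces: with feature_group_count = g (presumably 2 here, one group per input channel), the kernel's input-feature extent must equal input_features / g and its output-feature extent must be a multiple of g, so 2x2x2x3 becomes 2x2x1x6; the spatial extents (4 - 2) + 1 = 3 and (5 - 2) + 1 = 4 keep the 2x3x4x6 result unchanged. A hypothetical sketch of those checks (names are illustrative, not the verifier's identifiers):

  #include <cstdint>
  #include "mlir/Support/LogicalResult.h"

  // Grouped-feature consistency for a convolution kernel.
  mlir::LogicalResult checkFeatureGroups(int64_t inputFeatures,
                                         int64_t kernelInputFeatures,
                                         int64_t kernelOutputFeatures,
                                         int64_t featureGroupCount) {
    // Each group convolves inputFeatures / featureGroupCount channels.
    if (kernelInputFeatures * featureGroupCount != inputFeatures)
      return mlir::failure();
    // Output features are produced group by group.
    if (kernelOutputFeatures % featureGroupCount != 0)
      return mlir::failure();
    return mlir::success();
  }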
 
@@ -3085,7 +3089,7 @@
 
 func.func @depthwise_conv_with_padding(
     %arg0: tensor<2x4x5x2xf32>,
-    %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32> {
+    %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv<raw
@@ -3103,8 +3107,8 @@
     padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
     window_strides = dense<1> : tensor<2xi64>,
-    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32>
-  func.return %0 : tensor<2x3x6x6xf32>
+    someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32>
+  func.return %0 : tensor<2x3x6x4xf32>
 }
 // CHECK:      func @depthwise_conv_with_padding
 // CHECK-SAME:   %[[IN:[a-zA-Z0-9_]*]]
@@ -3113,17 +3117,24 @@
 // CHECK:        %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0]  {
 // CHECK:        ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
 // CHECK:          tensor.yield %[[ZERO]] : f32
-// CHECK         } : tensor<2x4x5x2xf32> to tensor<2x4x7x1xf32>
-// CHECK:        %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 3] : tensor<2x3x6x2x3xf32>
+// CHECK:        } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32>
+// CHECK:        %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME:    [0, 1, 2, 3]
+// CHECK-SAME:    : tensor<2x2x1x4xf32> into tensor<16xf32>
+// CHECK:       %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<16xf32> to tensor<16xf32>
+// CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]]
+// CHECK-SAME:   [0, 1, 2, 3]
+// CHECK-SAME:   tensor<16xf32> into tensor<2x2x2x2xf32>
+// CHECK:        %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 2] : tensor<2x3x6x2x2xf32>
 // CHECK:        %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
 // CHECK:        %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
 // CHECK-SAME:     {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME:     ins(%[[PAD]], %[[FILTER]] : tensor<2x4x7x2xf32>, tensor<2x2x2x3xf32>)
-// CHECK-SAME:     outs(%[[FILL]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK-SAME:     ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>)
+// CHECK-SAME:     outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
 // CHECK:        %{{.+}} = tensor.collapse_shape %[[OUT]]
 // CHECK-SAME:     [0], [1], [2], [3, 4]
-// CHECK-SAME:     : tensor<2x3x6x2x3xf32> into tensor<2x3x6x6xf32>
+// CHECK-SAME:     : tensor<2x3x6x2x2xf32> into tensor<2x3x6x4xf32>
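Same grouped-feature fix as above: with two input channels the filter's i-extent must be 1 and its o-extent a multiple of 2, hence 2x2x1x4; the padded width then gives (5 + 1 + 1 - 2) + 1 = 6 for the 2x3x6x4 result.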
 
 // -----
 
@@ -3166,7 +3177,7 @@
 
 func.func @depthwise_conv_multiplier_1_with_padding(
     %arg0: tensor<1x113x113x96xf32>,
-    %arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> {
+    %arg1: tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {
     batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv<raw
@@ -3183,8 +3194,8 @@
     feature_group_count = 96 : i64,
     padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>,
     rhs_dilation = dense<1> : tensor<2xi64>,
-    window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
-  func.return %0 : tensor<1x56x56x96xf32>
+    window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32>
+  func.return %0 : tensor<1x57x58x96xf32>
 }
 // CHECK:       func @depthwise_conv_multiplier_1_with_padding
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9_]*]]
@@ -3194,16 +3205,16 @@
 // CHECK:         ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
 // CHECK:           tensor.yield %[[ZERO]] : f32
 // CHECK          } : tensor<1x113x113x96xf32> to tensor<1x115x117x96xf32>
-// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 57, 58, 96] : tensor<1x57x58x96xf32>
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>
 // CHECK:         %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
 // CHECK-SAME:     [0], [1], [2, 3]
 // CHECK-SAME:     : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
 // CHECK:         %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
 // CHECK-SAME:       ins(%[[PAD]], %[[RESHAPED_FILTER]] : tensor<1x115x117x96xf32>, tensor<3x3x96xf32>)
-// CHECK-SAME:       outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK-SAME:       outs(%[[FILL]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>
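With stride 2 the corrected extents are floor((113 + 1 + 1 - 3) / 2) + 1 = 57 and floor((113 + 2 + 2 - 3) / 2) + 1 = 58, so the asymmetric [[1, 1], [2, 2]] padding yields 1x57x58x96 rather than the old 1x56x56x96.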
 
 // -----
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
index 1cf1109..5a11361 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
@@ -2541,13 +2541,13 @@
 // CHECK: mhlo.convolution
 // CHECK-SAME: dim_numbers = [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f]
 // CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
-func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> {
+func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32> {
   %0 = mhlo.convolution(%arg0, %arg1)
      dim_numbers = [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f],
      window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
-  : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
-  func.return %0 : tensor<3x5x5x4xf32>
+  : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32>
+  func.return %0 : tensor<2x1x1x3xf32>
 }
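With the permuted [b, 1, 0, f] input layout, spatial dim 0 is the lhs extent 3 and spatial dim 1 the lhs extent 2: (3 + 0 + 1 - 3) / 2 + 1 = 1 and, with the rhs-dilated window (2 - 1) * 2 + 1 = 3, (2 + 0 + 1 - 3) / 1 + 1 = 1; batch 2 and output features o = 3 give the 2x1x1x3 result.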
 
 // -----
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 4e3b81c..3f3a949 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1707,13 +1707,13 @@
 
 // CHECK-LABEL:   func @convert_depthwise_conv2d(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant dense<[3, 3, 207, 16]> : tensor<4xi64>
 // CHECK:           %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK:           %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-// CHECK:           return %[[VAL_3]] : tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
+// CHECK:           return %[[VAL_3]] : tensor<1x8x8x3312xf32>
 // CHECK:         }
-func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
   %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
     dimension_numbers = #mhlo.conv<raw
       input_batch_dimension = 0,
@@ -1726,8 +1726,8 @@
       output_feature_dimension = 3,
       output_spatial_dimensions = [1, 2]
     >, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
-       (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32>
-  func.return %0 : tensor<1x8x8x16xf32>
+       (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32>
+  func.return %0 : tensor<1x8x8x3312xf32>
 }
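With feature_group_count = 207 the result's feature extent is pinned to the kernel's o-extent, 3312 = 207 * 16, so the depthwise result becomes 1x8x8x3312 (the tf.DepthwiseConv2dNative CHECK lines are updated to match); the dense<1> padding keeps the 8x8 spatial extents, since 8 + 1 + 1 - 3 + 1 = 8.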
 
 // CHECK-LABEL:   func @convert_conv2d_to_resize(
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
index 915e5fb..9159aa8 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
@@ -4346,7 +4346,7 @@
 func.func @conv2d_backprop_filter(
     %input: tensor<100x28x28x1xf32>,
     %out_backprop: tensor<100x26x26x32xf32>
-  ) -> tensor<100x28x28x1xf32> {
+  ) -> tensor<3x3x1x32xf32> {
   // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
   // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
   // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4361,8 +4361,8 @@
     padding = "VALID",
     strides = [1, 1, 1, 1],
     use_cudnn_on_gpu = true
-  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
-  func.return %result : tensor<100x28x28x1xf32>
+  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+  func.return %result : tensor<3x3x1x32xf32>
 }
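Conv2DBackpropFilter returns the filter gradient, so the result type is the filter shape rather than the input shape: the lowered convolution's [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f] layout convolves the 28x28 input against the 26x26 out_backprop, giving (28 - 26) + 1 = 3 per spatial dimension and hence 3x3x1x32.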
 
 // -----
@@ -4391,7 +4391,7 @@
 
 
 // CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
   // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
   // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
   // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4399,8 +4399,8 @@
   // CHECK-SAME: feature_group_count = 1 : i64
   // CHECK: return %[[RESULT]]
   %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
-  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
-  func.return %result : tensor<2x8x8x8x1xf32>
+  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+  func.return %result : tensor<3x3x3x1x6xf32>
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 536125f..3bb549a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -4460,7 +4460,7 @@
 func.func @conv2d_backprop_filter(
     %input: tensor<100x28x28x1xf32>,
     %out_backprop: tensor<100x26x26x32xf32>
-  ) -> tensor<100x28x28x1xf32> {
+  ) -> tensor<3x3x1x32xf32> {
   // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
   // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
   // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4475,8 +4475,8 @@
     padding = "VALID",
     strides = [1, 1, 1, 1],
     use_cudnn_on_gpu = true
-  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
-  func.return %result : tensor<100x28x28x1xf32>
+  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+  func.return %result : tensor<3x3x1x32xf32>
 }
 
 // -----
@@ -4505,7 +4505,7 @@
 
 
 // CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
   // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
   // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
   // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4513,8 +4513,8 @@
   // CHECK-SAME: feature_group_count = 1 : i64
   // CHECK: return %[[RESULT]]
   %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
-  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
-  func.return %result : tensor<2x8x8x8x1xf32>
+  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+  func.return %result : tensor<3x3x3x1x6xf32>
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 731da33..13e70ce 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -256,15 +256,15 @@
 // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
 // implementations with attributes, etc.
 // CHECK-LABEL: func private @test_conv(
-// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<256x32x32x1xf32>) -> tuple<tensor<256x30x30x1xf32>> {
 %test_conv {
-  %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
+  %arg0.1 = f32[256,32,32,1]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
 
-  // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
-  %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
+  // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+  %copy.1 = f32[256,32,32,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
 
-  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
-  %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
+  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+  %reshape.2 = f32[256,32,32,1]{2,1,3,0} reshape(%copy.1)
 
   // Note that double brackets "[[" have to be escaped as they denote variables
   // in FileCheck. The only way to do so is to drop into regex with "{{"
@@ -276,15 +276,15 @@
   // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[44, 45], [60, 60]], lhs_dilate = [1, 1], rhs_dilate = [2, 3]}
   // CHECK-SAME:     feature_group_count = 1 : i64
   // CHECK-SAME:     precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
-  // CHECK-SAME:   (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32>
+  // CHECK-SAME:   (tensor<256x32x32x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x30x30x256xf32>
 
-  %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
+  %convolution.4 = f32[1,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 
-  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
-  %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
+  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<1x30x30x256xf32>) -> tensor<256x30x30x1xf32>
+  %reshape.5 = f32[256,30,30,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
 
-  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
-  ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
+  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x1xf32>) -> tuple<tensor<256x30x30x1xf32>>
+  ROOT %tuple.6 = (f32[256,30,30,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
 }
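The corrected shapes follow from the b01f_01io->f01b labels: the 2x2x1x1 kernel has one input and one output feature, so the parameter shrinks to a single channel, and with rhs-dilated windows 3 and 4 the convolution yields floor((32 + 44 + 45 - 3) / 4) + 1 = 30 and floor((32 + 60 + 60 - 4) / 5) + 1 = 30, laid out f01b as f32[1,30,30,256].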
 
 // Test for padding attribute shape in convolution