[mhlo] Verifier for mhlo.ConvOp
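
The verifier checks the sizes and uniqueness of the convolution
dimension-numbers, the feature_group_count / batch_group_count
constraints, the window attributes, and the return shape inferred from
the operand types.

As a minimal sketch of what is now rejected (IR mirroring the tests
added in this change):

  // error: expects batch_group_count and feature_group_count not to be
  // both greater than 1. Got 2 and 2 resp.
  %0 = mhlo.convolution(%lhs, %rhs)
         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
         window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
                   lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
         {batch_group_count = 2 : i64, feature_group_count = 2 : i64}
       : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>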
PiperOrigin-RevId: 448608547
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 58014bd..d9b27a6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1535,6 +1535,7 @@
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
+ let hasVerifier = 1;
code extraClassDeclaration = [{
bool hasWindowReversal() {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index e561d7d..4959759 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1792,6 +1792,311 @@
}
//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Checks:
+// P1. Same sizes for input, kernel and output spatial_dims.
+// P2. Spatial and non-spatial dimensions (for input, kernel, & output) should
+//     be unique and in range [0, num_dims), where num_dims = rank of the
+//     input (lhs/rhs) tensors.
+//
+// Note that the spatial + non-spatial dimensions may not cover all the
+// dimensions in the range [0, num_dims) because of the presence of 'unknown'
+// dimensions (ref. cl/415132294).
+LogicalResult isSpatialDimensionsValid(ConvOp op) {
+ auto input_spatial_dimensions =
+ op.dimension_numbers().getInputSpatialDimensions();
+ auto kernel_spatial_dimensions =
+ op.dimension_numbers().getKernelSpatialDimensions();
+ auto output_spatial_dimensions =
+ op.dimension_numbers().getOutputSpatialDimensions();
+
+ // P1.
+ if ((input_spatial_dimensions.size() != kernel_spatial_dimensions.size()) ||
+ (input_spatial_dimensions.size() != output_spatial_dimensions.size()))
+ return op.emitOpError() << "expects the same size for input, kernel and "
+ "output spatial-dimensions, but got "
+ << input_spatial_dimensions.size() << ", "
+ << kernel_spatial_dimensions.size() << ", and "
+ << output_spatial_dimensions.size() << " resp.";
+
+ // P2.
+ SmallVector<int64_t> input_dnums(input_spatial_dimensions.size() + 2);
+ input_dnums[0] = op.dimension_numbers().getInputBatchDimension();
+ input_dnums[1] = op.dimension_numbers().getInputFeatureDimension();
+ std::copy(input_spatial_dimensions.begin(), input_spatial_dimensions.end(),
+ input_dnums.begin() + 2);
+
+ SmallVector<int64_t> window_dnums(kernel_spatial_dimensions.size() + 2);
+ window_dnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
+ window_dnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
+ std::copy(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
+ window_dnums.begin() + 2);
+
+ SmallVector<int64_t> output_dnums(output_spatial_dimensions.size() + 2);
+ output_dnums[0] = op.dimension_numbers().getOutputBatchDimension();
+ output_dnums[1] = op.dimension_numbers().getOutputFeatureDimension();
+ std::copy(output_spatial_dimensions.begin(), output_spatial_dimensions.end(),
+ output_dnums.begin() + 2);
+
+ auto num_dims = op.lhs().getType().cast<RankedTensorType>().getRank();
+ const auto in_range = [num_dims](int64_t i) {
+ return 0 <= i && i < num_dims;
+ };
+
+ if (!llvm::all_of(input_dnums, in_range) ||
+ !llvm::all_of(window_dnums, in_range) ||
+ !llvm::all_of(output_dnums, in_range))
+ return op.emitOpError() << "expects input, kernel, and output "
+ "dimension-numbers to be in-range [0, "
+ << num_dims << ").";
+
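+  // Note: has_duplicates sorts (and, via std::unique, partially rewrites)
+  // its argument in place, so the diagnostics below print a mutated copy of
+  // the dimension numbers rather than their declaration order.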
+ const auto has_duplicates = [](SmallVector<int64_t>& dnums) {
+ std::sort(dnums.begin(), dnums.end());
+ auto last = std::unique(dnums.begin(), dnums.end());
+ return last != dnums.end();
+ };
+
+ if (has_duplicates(input_dnums))
+ return op.emitOpError()
+ << "expects input dimension-numbers to be unique, got {"
+ << input_dnums << "}.";
+
+ if (has_duplicates(window_dnums))
+ return op.emitOpError()
+ << "expects kernel dimension-numbers to be unique, got {"
+ << window_dnums << "}.";
+
+ if (has_duplicates(output_dnums))
+ return op.emitOpError()
+ << "expects output dimension-numbers to be unique, got {"
+ << output_dnums << "}.";
+
+ return success();
+}
+
+// Verifies the following properties:
+// P1. The input, kernel, and output spatial-dimensions are valid.
+// P2. Given,
+// input-dimensions: b * input-spatial-dims * f
+// kernel-dimensions: kernel-spatial-dims * i * o
+// output-dimensions: b' * out-spatial-dims * f'
+// where b = input-batch-dims
+// where f = input-feature-dims
+// where i = kernel-input-feature-dims
+// where o = kernel-output-feature-dims
+// where b' = output-batch-dims
+// where f' = output-feature-dims
+// Check the following properties w.r.t. feature_group_count (fgc) and
+// batch_group_count (bgc):
+//     fgc > 0, bgc > 0, and !(fgc > 1 && bgc > 1)
+// b % bgc == 0
+// f % fgc == 0 and i = f / fgc
+// o (or f') % bgc == 0 and o (or f') % fgc == 0
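+// For example (mirroring the tests added below): with f = 207 and fgc = 3,
+// i must be 207 / 3 = 69 and o must be a multiple of 3.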
+LogicalResult verifyConvolutionAttributes(ConvOp op) {
+ // P1.
+ if (failed(isSpatialDimensionsValid(op))) return failure();
+
+ // P2.
+ const int64_t feature_group_count = op.feature_group_count();
+ const int64_t batch_group_count = op.batch_group_count();
+
+ if (feature_group_count <= 0)
+ return op.emitOpError()
+ << "expects feature_group_count to be a positive number, got "
+ << feature_group_count << ".";
+
+ if (batch_group_count <= 0)
+ return op.emitOpError()
+ << "expects batch_group_count to be a positive number, got "
+ << batch_group_count << ".";
+
+ if (batch_group_count > 1 && feature_group_count > 1)
+ return op.emitOpError()
+ << "expects batch_group_count and feature_group_count not to be "
+ "both greater than 1. Got "
+ << batch_group_count << " and " << feature_group_count << " resp.";
+
+ auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+ const int64_t input_features =
+ lhs_type.getShape()[op.dimension_numbers().getInputFeatureDimension()];
+ const int64_t input_batch =
+ lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+
+ auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+ const int64_t kernel_input_features =
+ rhs_type
+ .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
+ const int64_t kernel_output_features =
+ rhs_type
+ .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+ if (!isDynamicDimSize(kernel_output_features)) {
+ if (kernel_output_features % batch_group_count != 0)
+ return op.emitOpError() << "expects output feature dimension size ("
+ << kernel_output_features
+ << ") to be a multiple of "
+ "batch_group_count. Got batch_group_count = "
+ << batch_group_count << ".";
+
+ if (kernel_output_features % feature_group_count != 0)
+ return op.emitOpError()
+ << "expects kernel output feature dimension ("
+ << kernel_output_features
+ << ") to be divisible by "
+                "feature_group_count. Got feature_group_count = "
+ << feature_group_count << ".";
+ }
+
+ if (!isDynamicDimSize(input_features)) {
+ if (input_features % feature_group_count != 0)
+ return op.emitOpError()
+ << "expects input feature dimension (" << input_features
+ << ") to be a multiple of "
+ "feature_group_count. Got feature_group_count = "
+ << feature_group_count << ".";
+
+ if (!isDynamicDimSize(kernel_input_features) &&
+ input_features / feature_group_count != kernel_input_features)
+ return op.emitOpError()
+ << "expects input feature dimension (" << input_features
+ << ") / "
+ "feature_group_count = kernel input feature dimension ("
+ << kernel_input_features
+ << "). Got feature_group_count = " << feature_group_count << ".";
+ }
+
+ if (!isDynamicDimSize(input_batch) && input_batch % batch_group_count != 0)
+ return op.emitOpError() << "expects input batch dimension (" << input_batch
+ << ") to be divisible by "
+ "batch_group_count. Got batch_group_count = "
+ << batch_group_count << ".";
+
+ return success();
+}
+
+// Infer the return-shape of ConvOp.
+// Precondition:
+//  1. Input args to ConvOp 'op' are RankedTensorTypes.
+// 2. rank-of(input-type) == rank-of(output-type)
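+// Example (mirroring the tests added below, NHWC/HWIO layout): for lhs
+// tensor<1x8x8x207xf32>, rhs tensor<3x3x207x16xf32>, unit strides and
+// dilations, and padding [[1, 1], [1, 1]], each output spatial dim is
+// (8 + 1 + 1 - 3) / 1 + 1 = 8, so the inferred shape is [1, 8, 8, 16].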
+SmallVector<int64_t> inferConvOpReturnShape(
+ ConvOp op, const ArrayRef<WindowDimension> window) {
+  // We keep the 'unknown' dimensions (cl/415132294) as is in the
+  // output-shape. To do that we initialize the output dimensions with the
+  // shape of the return-type and update only the spatial + non-spatial
+  // dimensions.
+ // Precondition 2 ensures that size of output-shape == size of input-shape.
+ SmallVector<int64_t> output_dimensions =
+ to_vector(op.getResult().getType().cast<ShapedType>().getShape());
+
+ // Infer the output spatial dimensions.
+ auto lhs_type = op.lhs().getType().cast<RankedTensorType>();
+ auto input_spatial_dims = op.dimension_numbers().getInputSpatialDimensions();
+ auto num_spatial_dims = input_spatial_dims.size();
+ SmallVector<int64_t> input_spatial_dim_vals(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i)
+ input_spatial_dim_vals[i] = lhs_type.getShape()[input_spatial_dims[i]];
+
+ auto window_output_shape =
+ inferWindowOutputShape(input_spatial_dim_vals, window);
+
+ for (int i = 0; i < window.size(); ++i)
+ output_dimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
+ window_output_shape[i];
+
+ // Infer the output-batch-dimension and output-feature-dimension.
+ auto rhs_type = op.rhs().getType().cast<RankedTensorType>();
+ const int64_t input_batch =
+ lhs_type.getShape()[op.dimension_numbers().getInputBatchDimension()];
+ const int64_t kernel_output_features =
+ rhs_type
+ .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
+
+ output_dimensions[op.dimension_numbers().getOutputBatchDimension()] =
+ isDynamicDimSize(input_batch) ? ShapedType::kDynamicSize
+ : input_batch / op.batch_group_count();
+ output_dimensions[op.dimension_numbers().getOutputFeatureDimension()] =
+ kernel_output_features;
+
+ return output_dimensions;
+}
+} // namespace
+
+/*
+ * We intend to verify the following properties
+ * P1. Verify the input and kernel types.
+ * P2. Verify the convolution attributes.
+ * P3. Verify and collect the window attributes.
+ * P4. Verify the return shape.
+ * TODO(b/232574102): Verify the element-type of return-value.
+ */
+LogicalResult ConvOp::verify() {
+ auto lhs_type = lhs().getType().dyn_cast<RankedTensorType>();
+ auto rhs_type = rhs().getType().dyn_cast<RankedTensorType>();
+
+ if (!lhs_type || !rhs_type) return success();
+
+ // P1.
+ int num_dims = lhs_type.getRank();
+ if (num_dims != rhs_type.getRank())
+ return emitOpError()
+ << "expects convolution arguments to have same number of "
+ "dimensions. Got: "
+ << lhs_type << " and " << rhs_type << ".";
+
+ if (num_dims < 2)
+ return emitOpError()
+ << "expects convolution arguments to have >= 2 dimensions. "
+ "Got: "
+ << lhs_type << " and " << rhs_type << ".";
+
+ // P2.
+ if (failed(verifyConvolutionAttributes(*this))) return failure();
+
+ // P3.
+ auto kernel_spatial_dimensions =
+ dimension_numbers().getKernelSpatialDimensions();
+ SmallVector<int64_t> window_dimensions(kernel_spatial_dimensions.size());
+ for (size_t i = 0; i < window_dimensions.size(); i++)
+ window_dimensions[i] = rhs_type.getShape()[kernel_spatial_dimensions[i]];
+
+ auto padding_or_err = convertNx2Attribute(this->padding(), getLoc());
+ if (failed(padding_or_err)) return failure();
+ SmallVector<std::pair<int64_t, int64_t>> padding = *padding_or_err;
+
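+  // verifyWindowAttributesAndInferWindowDimensions checks that the strides,
+  // dilation factors, and padding entries each have one value per window
+  // dimension, and that window sizes, strides, and dilation factors are
+  // positive (see the window-attribute tests below).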
+ auto window_or_err = verifyWindowAttributesAndInferWindowDimensions(
+ window_dimensions, convertDenseIntAttr(window_strides()), padding,
+ convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
+ getLoc());
+ if (failed(window_or_err)) return failure();
+
+ // P4.
+ auto actual_return_type = getResult().getType().cast<TensorType>();
+ auto actual_return_element_type = actual_return_type.getElementType();
+ if (!actual_return_type.hasRank()) return success();
+
+ auto actual_return_ranked_type = actual_return_type.cast<RankedTensorType>();
+ if (num_dims != actual_return_ranked_type.getRank())
+ return emitOpError() << "expects rank of convolution return-type to be "
+ "equal to input-ranks ("
+ << num_dims << "), but got "
+ << actual_return_ranked_type.getRank() << ".";
+
+ auto expected_return_shape = inferConvOpReturnShape(*this, *window_or_err);
+ auto expected_return_type =
+ RankedTensorType::get(expected_return_shape, actual_return_element_type);
+ if (failed(verifyCompatibleShape(expected_return_type,
+ actual_return_ranked_type)))
+ return emitOpError()
+ << "has shape mismatch between the expected return-type ("
+ << expected_return_type << ") and actual return-type ("
+ << actual_return_ranked_type << ").";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 0b25655..c2182e9 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2055,28 +2055,26 @@
op, "non-one lhs- dialation unsupported yet");
}
- if (const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
- op.dimension_numbers()) {
- // Make sure that this is 2-D convolution.
- const auto spatial_rank =
- llvm::size(dimension_numbers.getInputSpatialDimensions());
- if (spatial_rank != 2) {
- return rewriter.notifyMatchFailure(op,
- "only support 2-D cases for now");
- }
+ const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
+ op.dimension_numbers();
+ // Make sure that this is 2-D convolution.
+ const auto spatial_rank =
+ llvm::size(dimension_numbers.getInputSpatialDimensions());
+ if (spatial_rank != 2) {
+ return rewriter.notifyMatchFailure(op, "only support 2-D cases for now");
+ }
- // Make sure that this is depthwise convolution.
- int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
- int64_t input_feature_count =
- op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
- if (op.feature_group_count() != input_feature_count) {
- return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
- }
+ // Make sure that this is depthwise convolution.
+ int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
+ int64_t input_feature_count =
+ op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
+ if (op.feature_group_count() != input_feature_count) {
+ return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
+ }
- // Make sure that this convolution has a canonical form.
- if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
- return rewriter.notifyMatchFailure(op, "does not have canonical form");
- }
+ // Make sure that this convolution has a canonical form.
+ if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
+ return rewriter.notifyMatchFailure(op, "does not have canonical form");
}
DenseIntElementsAttr window_strides;
@@ -2127,10 +2125,38 @@
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
};
- if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
+ int64_t kernel_input_feature_dimension =
+ dimension_numbers.getKernelInputFeatureDimension();
+ int64_t kernel_output_feature_dimension =
+ dimension_numbers.getKernelOutputFeatureDimension();
+ if (filter_dims[kernel_input_feature_dimension] *
+ filter_dims[kernel_output_feature_dimension] !=
+ op.feature_group_count()) {
// For cases where channel multiplier != 1
+
+      // Reshape the filter from
+      //   [filter_height, filter_width, 1, kernel-output-feature]
+      // to
+      //   [filter_height, filter_width, feature_group_count,
+      //    kernel-output-feature / feature_group_count].
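+      // e.g. (matching the depthwise_conv test) a tensor<2x2x1x6xf32> filter
+      // with feature_group_count = 2 is reshaped to tensor<2x2x2x3xf32>.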
+ SmallVector<int64_t> reshaped_filter_dims;
+ reshaped_filter_dims.assign(filter_dims.begin(), filter_dims.end());
+ auto reshaped_filter = filter;
+ if (filter_dims[kernel_input_feature_dimension] == 1) {
+ reshaped_filter_dims[kernel_input_feature_dimension] =
+ op.feature_group_count();
+ reshaped_filter_dims[kernel_output_feature_dimension] /=
+ op.feature_group_count();
+ auto reshaped_filter_type = RankedTensorType::get(
+ reshaped_filter_dims,
+ op.rhs().getType().cast<RankedTensorType>().getElementType());
+
+ reshaped_filter =
+ rewriter.create<mhlo::ReshapeOp>(loc, reshaped_filter_type, filter);
+ }
+
auto output_dims = result_type.getShape();
- auto channel_multiplier = filter_dims[3];
+ auto channel_multiplier = reshaped_filter_dims[3];
SmallVector<int64_t> reshaped_output_dims;
reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
reshaped_output_dims.push_back(channel_multiplier);
@@ -2143,7 +2169,7 @@
auto reshaped_output_type = RankedTensorType::get(
reshaped_output_dims, result_type.getElementType());
auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- op.getLoc(), reshaped_output_type, ValueRange{input, filter},
+ loc, reshaped_output_type, ValueRange{input, reshaped_filter},
ValueRange{zero_tensor}, window_strides, rhs_dilation,
PruneAttributeList(op));
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
new file mode 100644
index 0000000..c4a2eaf
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir
@@ -0,0 +1,809 @@
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// -----
+
+// Valid: Generic convolution
+
+func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32> {
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// Valid: Test convolution i8xi8 -> i32.
+
+func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>,
+ %arg1 : tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> {
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xi8>, tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32>
+ func.return %result : tensor<100x28x28x1xi32>
+}
+
+// Valid: Empty spatial dimensions
+
+// CHECK: func @conv_empty_spatial_dimensions
+// CHECK: mhlo.convolution
+// CHECK-SAME: dim_numbers = [b, f]x[i, o]->[b, f]
+// CHECK-SAME: window = {stride = [], pad = [], lhs_dilate = [],
+// CHECK-SAME: rhs_dilate = [], reverse = []}
+func.func @conv_empty_spatial_dimensions(%arg0: tensor<3x2xf16>,
+ %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, f]x[i, o]->[b, f],
+ window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],
+ reverse = []}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ }
+ : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16>
+ %1 = "mhlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple<tensor<3x2xf16>>
+ func.return %1 : tuple<tensor<3x2xf16>>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<2x4x5x2xf32>,
+ %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
+ // expected-error@+1 {{expects input dimension-numbers to be unique, got {0, 0}.}}
+ %1 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ kernel_input_feature_dimension = 2,
+ kernel_output_feature_dimension = 3,
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ >,
+ feature_group_count = 2 : i64,
+ someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) ->
+ tensor<2x3x4x6xf32>
+ func.return %1 : tensor<2x3x4x6xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects convolution arguments to have same number of dimensions. Got: 'tensor<1x8x8x207xf32>' and 'tensor<3x3x207xf32>'.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>)
+ -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects convolution arguments to have >= 2 dimensions. Got: 'tensor<1xf32>' and 'tensor<3xf32>'.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1xf32>, tensor<3xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 3, 2, and 2 resp.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, 2, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 3, and 2 resp.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, 2, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects the same size for input, kernel and output spatial-dimensions, but got 2, 2, and 3 resp.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 2, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+ %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+ // expected-error@+1 {{expects input, kernel, and output dimension-numbers to be in-range [0, 4).}}
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 4],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+ %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+ // expected-error@+1 {{expects kernel dimension-numbers to be unique, got {0, 2, 3, 3}.}}
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 0],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>,
+ %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
+ // expected-error@+1 {{expects output dimension-numbers to be unique, got {0, 3, 3, 3}.}}
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [0, 3]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{op expects batch_group_count to be a positive number, got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 0 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{op expects feature_group_count to be a positive number, got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 0 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects batch_group_count and feature_group_count not to be both greater than 1. Got 2 and 2 resp.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 2 : i64,
+ feature_group_count = 2 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects output feature dimension size (16) to be a multiple of batch_group_count. Got batch_group_count = 3.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 3 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects input feature dimension (207) to be a multiple of feature_group_count. Got feature_group_count = 2.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 2 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects input feature dimension (207) / feature_group_count = kernel input feature dimension (20). Got feature_group_count = 1.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x20x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32> {
+  // expected-error@+1 {{expects kernel output feature dimension (16) to be divisible by feature_group_count. Got feature_group_count = 3.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 3 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x69x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<5x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects input batch dimension (5) to be divisible by batch_group_count. Got batch_group_count = 2.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 2 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<5x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window-strides to have same dimension-size as size of window dimensions (2), but got: 1.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects base-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window-dilation factors to have same dimension-size as size of window dimensions (2), but got: 1.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 1.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ // expected-error@+1 {{Expected array with 2 elements, got 4 elements instead}}
+ window = {stride = [1, 1], pad = [[1, 1, 1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32> {
+ // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.}}
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<6xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32> {
+ // expected-error@+1 {{expects the padding-entries to have even number of elements, but got 5 elements.}}
+ %result = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = #mhlo.conv<raw
+ input_batch_dimension = 0,
+ input_feature_dimension = 3,
+ input_spatial_dimensions = [1, 2],
+ kernel_input_feature_dimension = 3,
+ kernel_output_feature_dimension = 2,
+ kernel_spatial_dimensions = [0, 1],
+ output_batch_dimension = 0,
+ output_feature_dimension = 3,
+ output_spatial_dimensions = [1, 2]
+ >,
+ feature_group_count = 1 : i64,
+ lhs_dilation = dense<1> : tensor<2xi64>,
+ padding = dense<2> : tensor<5xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<100x26x26x32xf32>, tensor<3x3x1x32xf32>) ->
+ tensor<100x28x28x1xf32>
+ func.return %result : tensor<100x28x28x1xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window to have positive value for 0-th window dimension, but got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<0x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window to have positive stride for 1-th window dimension, but got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 0], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window to have positive base dilation factor for 0-th window dimension, but got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [0, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+ // expected-error@+1 {{expects window to have positive window dilation factor for 0-th window dimension, but got 0.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [0, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+ func.return %0 : tensor<1x8x8x16xf32>
+}
+
+// -----
+
+// Invalid rank of output-type.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> {
+ // expected-error @+1 {{expects rank of convolution return-type to be equal to input-ranks (4), but got 3.}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32>
+ func.return %0 : tensor<1x8x16xf32>
+}
+
+// -----
+
+// Invalid batch dimension in output-type. Should be equal to
+// input-batch-dimension / batch_group_count.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> {
+  // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<2x8x8x16xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32>
+ func.return %0 : tensor<2x8x8x16xf32>
+}
+
+// -----
+
+// Invalid feature dimension in output-type. Should be equal to the size of
+// the kernel output-feature dimension.
+
+func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x8x8x32xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32>
+ func.return %0 : tensor<1x8x8x32xf32>
+}
+
+// -----
+
+// The following tests check the inferred output-type of ConvOp. We
+// deliberately put an invalid output-type in these tests so that the
+// inferred type is surfaced in the error message.
+
+// Dynamic input-batch-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<?x8x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<?x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-feature-dimension: No effect on output dimensions.
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x?xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic input-spatial-dimension
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>,
+ %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x?x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x?x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-input-feature-dimension: No effect on output dimensions.
+func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-output-feature-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x?xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
+// -----
+
+// Dynamic kernel-spatial-dimension
+func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>,
+ %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> {
+ // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x?x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}}
+ %0 = mhlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+ window = {stride = [1, 1], pad = [[1, 1], [1,1]],
+ lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+ {
+ batch_group_count = 1 : i64,
+ feature_group_count = 1 : i64,
+ precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
+ } :
+ (tensor<1x8x8x207xf32>, tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32>
+ func.return %0 : tensor<1x1x1x1xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
index 08bb2f3..db142cd 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
@@ -473,10 +473,10 @@
// -----
// CHECK-LABEL: func @conv
-func.func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
- -> tensor<3x5x5x4xf32> {
+func.func @conv(%input: tensor<3x2x4x3xf32>, %filter : tensor<2x2x3x4xf32>)
+ -> tensor<2x1x2x3xf32> {
%c0 = arith.constant 0 : index
- // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
+ // CHECK: %[[OUT:.*]] = memref.alloc() : memref<2x1x2x3xf32>
// CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
%out = "mhlo.convolution"(%filter, %input) {
@@ -496,8 +496,8 @@
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>
- } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
- func.return %out : tensor<3x5x5x4xf32>
+ } : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x2x3xf32>
+ func.return %out : tensor<2x1x2x3xf32>
}
// -----
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index b150f5f..e7e353c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -2863,7 +2863,7 @@
output_spatial_dimensions = [1]
>,
feature_group_count = 1 : i64,
- padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+ padding = dense<[[0, 0]]> : tensor<1x2xi64>,
rhs_dilation = dense<1> : tensor<1xi64>,
window_strides = dense<1> : tensor<1xi64>,
someattr
@@ -2890,7 +2890,7 @@
// -----
func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3x2x?x?xf32>)
- -> tensor<?x2x3x?xf32> {
+ -> tensor<?x2x4x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<raw
@@ -2908,8 +2908,8 @@
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
- } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x3x?xf32>
- func.return %0 : tensor<?x2x3x?xf32>
+ } : (tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>) -> tensor<?x2x4x?xf32>
+ func.return %0 : tensor<?x2x4x?xf32>
}
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
@@ -2918,14 +2918,14 @@
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 4, %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]]
// CHECK: linalg.conv_2d_nhwc
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x4x5x?xf32>, tensor<3x2x?x?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x2x4x?xf32>) -> tensor<?x2x4x?xf32>
// -----
@@ -2945,7 +2945,7 @@
output_spatial_dimensions = [1, 2, 3]
>,
feature_group_count = 1 : i64,
- padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+ padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
rhs_dilation = dense<1> : tensor<3xi64>,
window_strides = dense<1> : tensor<3xi64>
} : (tensor<?x8x8x8x?xf32>, tensor<2x2x2x?x?xf32>) -> tensor<?x7x7x7x?xf32>
@@ -3028,25 +3028,25 @@
// CHECK-LABEL: func @linalg.conv_2D_padding_test2
// CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>)
func.func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>)
- -> tensor<400x1024x1024x1xf16> {
- %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1024x1024x1xf16>)
- func.return %0 : tensor<400x1024x1024x1xf16>
+ -> tensor<400x1040x1024x1xf16> {
+ %0 = mhlo.convolution(%arg1, %arg0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1040x1024x1xf16>)
+ return %0 : tensor<400x1040x1024x1xf16>
}
-// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1024, 1024, 1] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[INIT:.*]] = linalg.init_tensor [400, 1040, 1024, 1] : tensor<400x1040x1024x1xf16>
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
-// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] {
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK-NEXT: tensor.yield %[[ZERO]] : f16
// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
-// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16>
+// CHECK-NEXT: return %[[RESULT]] : tensor<400x1040x1024x1xf16>
// -----
func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
- %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> {
+ %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<raw
@@ -3064,20 +3064,24 @@
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>,
- someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+ someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32>
func.return %0 : tensor<2x3x4x6xf32>
}
// CHECK: func @depthwise_conv
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]]
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
-// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
+
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2, 3]] : tensor<2x2x1x6xf32> into tensor<24xf32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<24xf32> to tensor<24xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]] {{\[}}[0, 1, 2, 3]] : tensor<24xf32> into tensor<2x2x2x3xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
+// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME: ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME: ins(%[[IN]], %[[EXPAND]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
-// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]]
+// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]]
// CHECK-SAME: [0], [1], [2], [3, 4]
// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
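The depthwise tests now use filters whose kernel input-feature extent is 1, as the verifier requires for feature_group_count == C. A minimal C++ sketch of the assumed decomposition (the helper name depthwiseMultiplier is invented for illustration):

```cpp
#include <cassert>
#include <iostream>

// With feature_group_count == C (input channels), the kernel must carry
// i = 1 and o = C * M; the lowering reshapes [h, w, 1, C*M] to [h, w, C, M].
int64_t depthwiseMultiplier(int64_t kernelOutFeatures, int64_t featureGroups) {
  assert(kernelOutFeatures % featureGroups == 0);
  return kernelOutFeatures / featureGroups;
}

int main() {
  // depthwise_conv above: C = 2, kernel tensor<2x2x1x6xf32> -> M = 3,
  // reshaped to tensor<2x2x2x3xf32> for linalg.depthwise_conv_2d_nhwc_hwcm.
  std::cout << depthwiseMultiplier(6, 2) << "\n";  // 3
}
```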
@@ -3085,7 +3089,7 @@
func.func @depthwise_conv_with_padding(
%arg0: tensor<2x4x5x2xf32>,
- %arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32> {
+ %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<raw
@@ -3103,8 +3107,8 @@
padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>,
- someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x6x6xf32>
- func.return %0 : tensor<2x3x6x6xf32>
+ someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32>
+ func.return %0 : tensor<2x3x6x4xf32>
}
// CHECK: func @depthwise_conv_with_padding
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
@@ -3113,17 +3117,24 @@
// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0] {
// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK: tensor.yield %[[ZERO]] : f32
-// CHECK } : tensor<2x4x5x2xf32> to tensor<2x4x7x1xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 3] : tensor<2x3x6x2x3xf32>
+// CHECK: } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME: [0, 1, 2, 3]
+// CHECK-SAME: : tensor<2x2x1x4xf32> into tensor<16xf32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[COLLAPSE]] : tensor<16xf32> to tensor<16xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[CAST]]
+// CHECK-SAME: [0, 1, 2, 3]
+// CHECK-SAME: tensor<16xf32> into tensor<2x2x2x2xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 6, 2, 2] : tensor<2x3x6x2x2xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>}
-// CHECK-SAME: ins(%[[PAD]], %[[FILTER]] : tensor<2x4x7x2xf32>, tensor<2x2x2x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x3xf32>) -> tensor<2x3x6x2x3xf32>
+// CHECK-SAME: ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32>
// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]]
// CHECK-SAME: [0], [1], [2], [3, 4]
-// CHECK-SAME: : tensor<2x3x6x2x3xf32> into tensor<2x3x6x6xf32>
+// CHECK-SAME: : tensor<2x3x6x2x2xf32> into tensor<2x3x6x4xf32>
// -----
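For the padded depthwise case above, the same assumptions yield the corrected filter and result types; a compact C++ sketch (illustration only):

```cpp
#include <iostream>

int main() {
  int multiplier = 4 / 2;              // kernel o = 4, channels C = 2 -> M = 2
  int outW = (5 + 1 + 1 - 2) / 1 + 1;  // width 5, pad [1, 1], kernel 2 -> 6
  std::cout << multiplier << " " << outW << "\n";  // 2 6
}
```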
@@ -3166,7 +3177,7 @@
func.func @depthwise_conv_multiplier_1_with_padding(
%arg0: tensor<1x113x113x96xf32>,
- %arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> {
+ %arg1: tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<raw
@@ -3183,8 +3194,8 @@
feature_group_count = 96 : i64,
padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
- func.return %0 : tensor<1x56x56x96xf32>
+ window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32>
+ func.return %0 : tensor<1x57x58x96xf32>
}
// CHECK: func @depthwise_conv_multiplier_1_with_padding
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
@@ -3194,16 +3205,16 @@
// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK: tensor.yield %[[ZERO]] : f32
// CHECK } : tensor<1x113x113x96xf32> to tensor<1x115x117x96xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 57, 58, 96] : tensor<1x57x58x96xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>
// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
// CHECK-SAME: ins(%[[PAD]], %[[RESHAPED_FILTER]] : tensor<1x115x117x96xf32>, tensor<3x3x96xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32>
// -----
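The stride-2 test above previously claimed a 56x56 result; with pad [[1, 1], [2, 2]] the inferred extents differ per dimension. A quick C++ check of the arithmetic (illustration only, assuming floor division in the standard formula):

```cpp
#include <iostream>

int main() {
  int h = (113 + 1 + 1 - 3) / 2 + 1;  // pad [1, 1], kernel 3, stride 2 -> 57
  int w = (113 + 2 + 2 - 3) / 2 + 1;  // pad [2, 2], kernel 3, stride 2 -> 58
  std::cout << h << "x" << w << "\n"; // 57x58, hence tensor<1x57x58x96xf32>
}
```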
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
index 1cf1109..5a11361 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir
@@ -2541,13 +2541,13 @@
// CHECK: mhlo.convolution
// CHECK-SAME: dim_numbers = [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f]
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
-func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> {
+func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32> {
%0 = mhlo.convolution(%arg0, %arg1)
dim_numbers = [b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
- : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
- func.return %0 : tensor<3x5x5x4xf32>
+ : (tensor<2x2x3x4xf32>, tensor<3x2x4x3xf32>) -> tensor<2x1x1x3xf32>
+ func.return %0 : tensor<2x1x1x3xf32>
}
// -----
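The fixed ops.mlir test above also exercises rhs dilation. A C++ sketch of why the result is tensor<2x1x1x3xf32>, assuming effective kernel size k_eff = (k - 1) * rhs_dilation + 1 and floor division:

```cpp
#include <iostream>

int main() {
  // dim_numbers [b, 1, 0, f]: input 2x2x3x4 -> spatial0 = 3, spatial1 = 2.
  int d0 = (3 + 0 + 1 - ((3 - 1) * 1 + 1)) / 2 + 1;  // kernel 3, pad [0,1], stride 2 -> 1
  int d1 = (2 + 0 + 1 - ((2 - 1) * 2 + 1)) / 1 + 1;  // kernel 2, rhs_dilate 2 -> 1
  std::cout << d0 << "," << d1 << "\n";  // 1,1; batch 2 and kernel o = 3 give 2x1x1x3
}
```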
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 4e3b81c..3f3a949 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1707,13 +1707,13 @@
// CHECK-LABEL: func @convert_depthwise_conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<[3, 3, 207, 16]> : tensor<4xi64>
// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x8x8x16xf32>
+// CHECK: %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x8x8x3312xf32>
// CHECK: }
-func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32> {
+func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
dimension_numbers = #mhlo.conv<raw
input_batch_dimension = 0,
@@ -1726,8 +1726,8 @@
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
- (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x16xf32>
- func.return %0 : tensor<1x8x8x16xf32>
+ (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32>
+ func.return %0 : tensor<1x8x8x3312xf32>
}
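The old expected type tensor<1x8x8x16xf32> violated the grouped-convolution constraint that the result's feature extent equals the kernel's output-feature extent. A minimal C++ check (illustration only):

```cpp
#include <cassert>

int main() {
  const long featureGroups = 207, multiplier = 16;
  // Output features = kernel o-dim = feature_group_count * depth multiplier.
  assert(featureGroups * multiplier == 3312);  // hence tensor<1x8x8x3312xf32>
}
```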
// CHECK-LABEL: func @convert_conv2d_to_resize(
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
index 915e5fb..9159aa8 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir
@@ -4346,7 +4346,7 @@
func.func @conv2d_backprop_filter(
%input: tensor<100x28x28x1xf32>,
%out_backprop: tensor<100x26x26x32xf32>
- ) -> tensor<100x28x28x1xf32> {
+ ) -> tensor<3x3x1x32xf32> {
// CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
// CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4361,8 +4361,8 @@
padding = "VALID",
strides = [1, 1, 1, 1],
use_cudnn_on_gpu = true
- } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
- func.return %result : tensor<100x28x28x1xf32>
+ } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+ func.return %result : tensor<3x3x1x32xf32>
}
// -----
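A filter-gradient convolution returns the filter's shape, not the input's, which is why the result types above change to tensor<3x3x1x32xf32>. A one-line C++ sketch, assuming VALID padding and unit strides so that kernel extent = in - out + 1:

```cpp
#include <iostream>

int main() {
  int k = 28 - 26 + 1;                      // spatial extents 28 -> 26 imply k = 3
  std::cout << k << "x" << k << "x1x32\n";  // 3x3x1x32
}
```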
@@ -4391,7 +4391,7 @@
// CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
// CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
// CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4399,8 +4399,8 @@
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK: return %[[RESULT]]
%filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
- %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
- func.return %result : tensor<2x8x8x8x1xf32>
+ %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+ func.return %result : tensor<3x3x3x1x6xf32>
}
// -----
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 536125f..3bb549a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -4460,7 +4460,7 @@
func.func @conv2d_backprop_filter(
%input: tensor<100x28x28x1xf32>,
%out_backprop: tensor<100x26x26x32xf32>
- ) -> tensor<100x28x28x1xf32> {
+ ) -> tensor<3x3x1x32xf32> {
// CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
// CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
@@ -4475,8 +4475,8 @@
padding = "VALID",
strides = [1, 1, 1, 1],
use_cudnn_on_gpu = true
- } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
- func.return %result : tensor<100x28x28x1xf32>
+ } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32>
+ func.return %result : tensor<3x3x1x32xf32>
}
// -----
@@ -4505,7 +4505,7 @@
// CHECK-LABEL: @conv3d_backprop_filter
-func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
+func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> {
// CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
// CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
@@ -4513,8 +4513,8 @@
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK: return %[[RESULT]]
%filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
- %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
- func.return %result : tensor<2x8x8x8x1xf32>
+ %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32>
+ func.return %result : tensor<3x3x3x1x6xf32>
}
// -----
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 731da33..13e70ce 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -256,15 +256,15 @@
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func private @test_conv(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x1xf32>) -> tuple<tensor<256x30x30x1xf32>> {
%test_conv {
- %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
+ %arg0.1 = f32[256,32,32,1]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
- // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
- %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
+ // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+ %copy.1 = f32[256,32,32,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
- // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,6]{2,1,3,0}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
- %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
+ // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
+ %reshape.2 = f32[256,32,32,1]{2,1,3,0} reshape(%copy.1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
@@ -276,15 +276,15 @@
// CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[44, 45], [60, 60]], lhs_dilate = [1, 1], rhs_dilate = [2, 3]}
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
- // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32>
+ // CHECK-SAME: (tensor<256x32x32x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x30x30x256xf32>
- %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
+ %convolution.4 = f32[1,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
- // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
- %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
+ // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<1x30x30x256xf32>) -> tensor<256x30x30x1xf32>
+ %reshape.5 = f32[256,30,30,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
- // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
- ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
+ // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x1xf32>) -> tuple<tensor<256x30x30x1xf32>>
+ ROOT %tuple.6 = (f32[256,30,30,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}
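In the imported HLO above, dim_labels b01f_01io->f01b with kernel f32[2,2,1,1] pins the input feature extent to the kernel's i (1) and the output's leading f extent to the kernel's o (1), so the parameter and result shrink from 6/16 features to 1. The spatial extents still come out to 30x30; a C++ sketch of that arithmetic (illustration only, assuming the same floor-division formula):

```cpp
#include <iostream>

int main() {
  // window size 3x3, stride 4x5, pad 44_45x60_60, rhs_dilate 2x3.
  int d0 = (32 + 44 + 45 - ((3 - 1) * 2 + 1)) / 4 + 1;  // 30
  int d1 = (32 + 60 + 60 - ((3 - 1) * 3 + 1)) / 5 + 1;  // 30
  std::cout << d0 << "x" << d1 << "\n";  // 30x30
}
```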
// Test for padding attribute shape in convolution