[mhlo] ConvOp -> ConvolutionOp
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459808502
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td
index ea57644..889f3c9 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td
@@ -851,7 +851,7 @@
);
}
-def LHLO_ConvOp : LHLO_Op<"convolution", []> {
+def LHLO_ConvolutionOp : LHLO_Op<"convolution", []> {
let summary = "Convolution operator";
let description = [{
Computes a convolution of the kind used in neural networks.
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h
index 4268e59..035a43c 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h
@@ -51,7 +51,7 @@
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConcatenateOp);
-MAP_HLO_TO_LHLO(ConvOp);
+MAP_HLO_TO_LHLO(ConvolutionOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h
index 98f3555..117b827 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h
@@ -49,7 +49,7 @@
MAP_LHLO_TO_HLO(CompareOp);
MAP_LHLO_TO_HLO(ComplexOp);
MAP_LHLO_TO_HLO(ConcatenateOp);
-MAP_LHLO_TO_HLO(ConvOp);
+MAP_LHLO_TO_HLO(ConvolutionOp);
MAP_LHLO_TO_HLO(ConvertOp);
MAP_LHLO_TO_HLO(CopyOp);
MAP_LHLO_TO_HLO(CosOp);
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 666c9e8..9bd4628 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1561,7 +1561,7 @@
let hasVerifier = 1;
}
-def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> {
+def HLO_ConvolutionOp : HLO_Op<"convolution", [NoSideEffect]> {
let summary = "Convolution operator";
let description = [{
Computes a convolution of the kind used in neural networks.
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index f6d18f0..20f8a84 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1861,7 +1861,7 @@
}
//===----------------------------------------------------------------------===//
-// ConvOp
+// ConvolutionOp
//===----------------------------------------------------------------------===//
namespace {
@@ -1874,7 +1874,7 @@
// Note that the spatial + non-spatial dimensions may not cover all the
// dimensions in the range [0,num) because of the presence of 'unknown'
// dimensions (ref. cl/415132294).
-LogicalResult isSpatialDimensionsValid(ConvOp op) {
+LogicalResult isSpatialDimensionsValid(ConvolutionOp op) {
auto inputSpatialDimensions =
op.dimension_numbers().getInputSpatialDimensions();
auto kernelSpatialDimensions =
@@ -1962,7 +1962,7 @@
// b % bgc == 0
// f % fgc == 0 and i = f / fgc
// o (or f') % bgc == 0 and o (or f') % fgc == 0
-LogicalResult verifyConvolutionAttributes(ConvOp op) {
+LogicalResult verifyConvolutionAttributes(ConvolutionOp op) {
// P1.
if (failed(isSpatialDimensionsValid(op))) return failure();
@@ -2044,12 +2044,12 @@
return success();
}
-// Infer the return-shape of ConvOp.
+// Infer the return-shape of ConvolutionOp.
// Precondition:
-// 1. Input args to ConvOp 'op' are RankedTypes.
+// 1. Input args to ConvolutionOp 'op' are RankedTypes.
// 2. rank-of(input-type) == rank-of(output-type)
-SmallVector<int64_t> inferConvOpReturnShape(
- ConvOp op, const ArrayRef<WindowDimension> window) {
+SmallVector<int64_t> inferConvolutionOpReturnShape(
+ ConvolutionOp op, const ArrayRef<WindowDimension> window) {
// We keep the 'unknown' dimensions (cl/415132294) as it is in the
// output-shape. To do that we initilize the output dimensions with the shape
// of the return-type and updates only the spatial + non-spatial dimensions.
@@ -2097,7 +2097,7 @@
* P4. Verify the return shape.
* TODO(b/232574102): Verify the element-type of return-value.
*/
-LogicalResult ConvOp::verify() {
+LogicalResult ConvolutionOp::verify() {
auto lhsType = lhs().getType().dyn_cast<RankedTensorType>();
auto rhsType = rhs().getType().dyn_cast<RankedTensorType>();
@@ -2149,7 +2149,7 @@
<< numDims << "), but got "
<< actualReturnRankedType.getRank() << ".";
- auto expectedReturnShape = inferConvOpReturnShape(*this, *windowOrErr);
+ auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr);
auto expectedReturnType =
RankedTensorType::get(expectedReturnShape, actualReturnElementType);
if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType)))
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index cd4786e..9fb2ed3 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -524,7 +524,7 @@
HloToLhloOpConverter<mhlo::ComplexOp>,
HloToLhloOpConverter<mhlo::ConcatenateOp>,
HloToLhloOpConverter<mhlo::ConstantOp>,
- HloToLhloOpConverter<mhlo::ConvOp>,
+ HloToLhloOpConverter<mhlo::ConvolutionOp>,
HloToLhloOpConverter<mhlo::ConvertOp>,
HloToLhloOpConverter<mhlo::CopyOp>,
HloToLhloOpConverter<mhlo::CosOp>,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 1a7cef8..020cfab 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2097,11 +2097,12 @@
/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
-struct NormalConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
- using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+struct NormalConvolutionOpConversion
+ : public OpConversionPattern<mhlo::ConvolutionOp> {
+ using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConvOp op, OpAdaptor adaptor,
+ mhlo::ConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
if (op.feature_group_count() != 1u) return failure();
@@ -2174,11 +2175,12 @@
/// Converts mhlo.convolution operation to
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
/// depthwise_conv_2d_input_nhwc_filter_hwc op.
-struct DepthwiseConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
- using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+struct DepthwiseConvolutionOpConversion
+ : public OpConversionPattern<mhlo::ConvolutionOp> {
+ using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConvOp op, OpAdaptor adaptor,
+ mhlo::ConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (op.batch_group_count() != 1) return failure();
// Fall into the normal convolution cases.
@@ -3338,8 +3340,8 @@
DynamicSliceConverter,
DynamicUpdateSliceConverter,
TransposeConverter<mhlo::TransposeOp>,
- NormalConvOpConversion,
- DepthwiseConvOpConversion,
+ NormalConvolutionOpConversion,
+ DepthwiseConvolutionOpConversion,
GatherConversion,
PadOpConversion,
PadOpNegativePaddingConversion,
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir
index c4a2eaf..1b5d1dc 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir
@@ -691,7 +691,7 @@
// -----
-// The following tests checks the inferred output-type of ConvOp. We
+// The following tests checks the inferred output-type of ConvolutionOp. We
// deliberately put an invalid output-type in these tests so that the
// inffered-type can be highlighted in the error message.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index ea40206..2937c6b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -160,7 +160,7 @@
// Common functionality for ConvertConvOp classes.
template <int SupportedSpatialDims>
struct ConvertNdConvOp {
- bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
+ bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) const {
if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
!conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
!conv_op.getType().cast<ShapedType>().hasStaticShape())
@@ -192,13 +192,13 @@
// Convert a 1-D convolution into a 2-D convolution (which TF supports) so that
// it can be rewritten by the pattern `Convert2DConvOp`.
-class Convert1DConvOp : public OpConversionPattern<mhlo::ConvOp>,
+class Convert1DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
ConvertNdConvOp<1> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConvOp conv_op, OpAdaptor adaptor,
+ mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
//
// Check that input is a supported 1d convolution.
@@ -360,7 +360,7 @@
rewriter)
.shape;
- auto conv2d_op = rewriter.create<mhlo::ConvOp>(
+ auto conv2d_op = rewriter.create<mhlo::ConvolutionOp>(
conv_op.getLoc(), transposed_output_2d_shape,
transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(),
window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d,
@@ -389,13 +389,13 @@
}
};
-class Convert2DConvOp : public OpConversionPattern<mhlo::ConvOp>,
+class Convert2DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,
ConvertNdConvOp<2> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConvOp conv_op, OpAdaptor adaptor,
+ mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (!IsSupportedConvOp(conv_op)) {
return failure();
@@ -469,7 +469,7 @@
};
private:
- bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
+ bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
ArrayRef<int64_t> padding_array) const {
mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
@@ -602,7 +602,7 @@
start_attr, size_attr);
}
- void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
+ void CreateConvOp(mhlo::ConvolutionOp conv_op, ArrayRef<int64_t> strides,
StringRef padding, ArrayRef<int64_t> explicit_padding,
ArrayRef<int64_t> dilation, bool is_depthwise_conv,
int input_channels, int num_spatial_dims,
@@ -698,12 +698,13 @@
}
};
-class ConvertNonTrivialConvOp : public OpConversionPattern<mhlo::ConvOp> {
+class ConvertNonTrivialConvOp
+ : public OpConversionPattern<mhlo::ConvolutionOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConvOp conv_op, OpAdaptor adaptor,
+ mhlo::ConvolutionOp conv_op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (IsSupportedConvOp(conv_op, rewriter).failed()) {
return rewriter.notifyMatchFailure(
@@ -788,7 +789,7 @@
};
private:
- bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
+ bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims,
ArrayRef<int64_t> strides) const {
for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
int dim = i + 1;
@@ -804,7 +805,7 @@
return true;
}
- LogicalResult IsSupportedConvOp(mhlo::ConvOp conv_op,
+ LogicalResult IsSupportedConvOp(mhlo::ConvolutionOp conv_op,
ConversionPatternRewriter &rewriter) const {
if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
!conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
@@ -888,7 +889,7 @@
return success();
}
- void CreateResizeBilinearOp(mhlo::ConvOp conv_op,
+ void CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op,
llvm::ArrayRef<int32_t> output_sizes,
bool align_corners,
ConversionPatternRewriter &rewriter) const {
@@ -908,7 +909,7 @@
rewriter.replaceOp(conv_op, {output});
}
- LogicalResult MatchResizeOp(mhlo::ConvOp conv_op, bool &align_corners,
+ LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool &align_corners,
llvm::SmallVector<int, 2> &output_sizes,
ConversionPatternRewriter &rewriter) const {
mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers();
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 4591731..83704d7 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -1274,7 +1274,8 @@
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
return func_builder
- ->create<mlir::mhlo::ConvOp>(loc, result_type, operands, attributes)
+ ->create<mlir::mhlo::ConvolutionOp>(loc, result_type, operands,
+ attributes)
.getOperation();
}
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc
index 17521a0..e10942e 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc
@@ -351,7 +351,7 @@
return xla::HloOpcode::kClamp;
} else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
return xla::HloOpcode::kConcatenate;
- } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
+ } else if (isa<mlir::mhlo::ConvolutionOp, mlir::lmhlo::ConvolutionOp>(op)) {
return xla::HloOpcode::kConvolution;
} else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
return xla::HloOpcode::kSort;
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index c08462b..568ea7f 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -108,7 +108,7 @@
mlir::ArrayAttr config_attr;
if (precision_config)
config_attr = ConvertPrecisionConfig(precision_config, &builder_);
- auto op = builder_.create<mlir::mhlo::ConvOp>(
+ auto op = builder_.create<mlir::mhlo::ConvolutionOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
GetI64ElementsAttr(window_strides, &builder_),
ConvertPadding(padding, &builder_),
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index d4ab3d1..10ed5d8 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -1031,7 +1031,7 @@
return failure();
}
-LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
+LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp lhs, rhs;
if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 008d976..83537c4 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1399,8 +1399,8 @@
NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
dimension_numbers_attr, feature_group_count_attr,
batch_group_count_attr, paddings_attr};
- rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
- llvm::makeArrayRef(attrs));
+ rewriter.replaceOpWithNewOp<ConvolutionOp>(op, op.getType(), operands,
+ llvm::makeArrayRef(attrs));
return success();
}
};
@@ -5041,7 +5041,7 @@
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
- Value result = rewriter.create<ConvOp>(
+ Value result = rewriter.create<ConvolutionOp>(
op.getLoc(), op.getType(), op.out_backprop(), filter,
/*window_strides=*/
GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
@@ -5247,7 +5247,7 @@
const int batch_dim =
tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
- Value result = rewriter.create<ConvOp>(
+ Value result = rewriter.create<ConvolutionOp>(
op.getLoc(), op.getType(), op.input(), op.out_backprop(),
/*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
/*padding=*/paddings_attr, /*lhs_dilation=*/
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index c20d2a0..b92b08c 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -46,7 +46,7 @@
_version = 77
# Version number for MLIR:Python components.
-mlir_api_version = 23
+mlir_api_version = 24
xla_platform_names = {
'cpu': 'Host',