[XLIR:TFRT] Support fallback to legacy convolution (due to cuDNN version or x32 convs).

Fall back to the legacy cuDNN API instead of the cuDNN frontend when the cuDNN version is below 8.1 or when the convolution uses a Tx32 (kBatchDepthYX32) layout, which the frontend does not yet support.
PiperOrigin-RevId: 428632858
Change-Id: Iaa1f10a4d083c38236cb654da666a2fda0a95c65
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
index 28f2d66..a5ae524 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
@@ -15,6 +15,7 @@
// Pattern to lower lmhlo convolution ops to tfrt_gpu dialect.
#include <sys/types.h>
+#include <algorithm>
#include <functional>
#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
@@ -109,18 +110,125 @@
op.result_scale().convertToDouble());
}
-Value CreateBuildUnfusedConvOp(Value input, Value output, Value handle,
- mlir::Location loc,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- auto get_element_type = [](Value value) {
- return value.getType().cast<mlir::MemRefType>().getElementType();
- };
+mlir::Type GetMemRefElementType(Value value) {
+ return value.getType().cast<mlir::MemRefType>().getElementType();
+}
+
+// Converts (via narrowing) a Span<const int64_t> to a vector<int>, and checks
+// that the elements have not changed due to the conversion.
+std::vector<int> CheckedNarrowing(absl::Span<const int64_t> wide_span) {
+ std::vector<int> narrow_vector(wide_span.size());
+ std::transform(
+ wide_span.cbegin(), wide_span.cend(), narrow_vector.begin(),
+ [](int64_t wide) {
+ int narrow = wide;
+ assert(narrow == wide &&
+ "checked narrowing failed; values not equal post-conversion");
+ return narrow;
+ });
+ return narrow_vector;
+}
+
+// Create ops to describe tensors (e.g., input, output, or bias) when using
+// legacy cudnn.
+template <class ConvolutionOpType>
+FailureOr<Value> CreateLegacyTensorDescriptor(
+ ConvolutionOpType op, const se::dnn::BatchDescriptor& batch_descriptor,
+ cudnnDataType_t elem_type, Value chain,
+ ConversionPatternRewriter& rewriter) {
+ std::vector<int64_t> dims64, strides64;
+ switch (batch_descriptor.layout()) {
+ case se::dnn::DataLayout::kBatchYXDepth:
+ case se::dnn::DataLayout::kBatchDepthYX: {
+ // cuDNN requires the strides and dims to be ordered as BDYX.
+ dims64 = batch_descriptor.full_dims(se::dnn::DataLayout::kBatchDepthYX);
+ strides64 =
+ batch_descriptor.full_strides(se::dnn::DataLayout::kBatchDepthYX);
+ break;
+ }
+ case se::dnn::DataLayout::kBatchDepthYX4:
+ case se::dnn::DataLayout::kBatchDepthYX32: {
+ const int64_t n = batch_descriptor.count();
+ const int64_t c = batch_descriptor.feature_map_count();
+ const int64_t h = batch_descriptor.height();
+ const int64_t w = batch_descriptor.width();
+ const int64_t v =
+ batch_descriptor.layout() == se::dnn::DataLayout::kBatchDepthYX4 ? 4
+ : 32;
+ assert(c / v > 0 && "Vectorized feature map count is non-positive.");
+ dims64 = {n, c / v, h, w};
+ strides64 = {c / v * h * w, h * w, w, 1};
+ break;
+ }
+ default:
+ return rewriter.notifyMatchFailure(op, "Unsupported tensor format.");
+ }
+
+ // cuDNN requires arrays of ints.
+ std::vector<int> dims = CheckedNarrowing(dims64);
+ std::vector<int> strides = CheckedNarrowing(strides64);
+ return rewriter
+ .create<tfrt::gpu::DnnCreateTensorDescriptorOp>(
+ op.getLoc(), elem_type, rewriter.getI32ArrayAttr(dims),
+ rewriter.getI32ArrayAttr(strides), chain)
+ .getResult();
+}
+
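+// Creates the legacy cuDNN fused convolution op (convolution + bias +
+// activation). Shared by the lowerings of the fused and fused-side-input
+// forward convolutions.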
+template <class FusedConvOpType, class FusedConvOpAdaptorType>
+FailureOr<Value> CreateLegacyFusedConvOp(
+ FusedConvOpType op, FusedConvOpAdaptorType adaptor, Type mlir_scale_type,
+ Value handle, Value stream, Value input_desc, Value output_desc,
+ Value filter_desc, Value conv_desc, Value algorithm, Value side_input,
+ Value chain, const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ // Create bias descriptor.
+ se::dnn::BatchDescriptor bias_descriptor = GetBiasDescriptor(config);
+ cudnnDataType_t bias_type = MlirTypeToCudnnDataType(
+ GetMemRefElementType(op.bias()), bias_descriptor.layout());
+ FailureOr<Value> bias_desc_or = CreateLegacyTensorDescriptor(
+ op, bias_descriptor, bias_type, chain, rewriter);
+ if (failed(bias_desc_or)) {
+ return bias_desc_or;
+ }
+
+ // Create activation descriptor.
+ auto loc = op.getLoc();
+ auto coefficient =
+ rewriter.create<tfrt::compiler::ConstantF64Op>(loc, llvm::APFloat(0.0));
+  cudnnActivationMode_t activation_mode =
+      config.fusion->mode == se::dnn::kRelu ? CUDNN_ACTIVATION_RELU
+                                            : CUDNN_ACTIVATION_IDENTITY;
+  auto activation_desc =
+      rewriter.create<tfrt::gpu::DnnCreateActivationDescriptorOp>(
+          loc, coefficient, activation_mode, CUDNN_NOT_PROPAGATE_NAN, chain);
+
+ cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+ auto alpha1 = MakeScalingFactorConstant(
+ rewriter, loc, mlir_scale_type, llvm::APFloat(config.conv_result_scale),
+ llvm::APFloat(0.0));
+ auto alpha2 = MakeScalingFactorConstant(
+ rewriter, loc, mlir_scale_type,
+ llvm::APFloat(config.fusion->side_input_scale), llvm::APFloat(0.0));
+ return rewriter
+ .create<tfrt::gpu::DnnConvolutionBiasActivationForwardOp>(
+ loc, handle, stream, scale_type, alpha1, input_desc, adaptor.input(),
+ filter_desc, adaptor.filter(), conv_desc, algorithm,
+ adaptor.scratch(), alpha2, output_desc, side_input, *bias_desc_or,
+ adaptor.bias(), activation_desc, output_desc, adaptor.output(), chain)
+ .getResult();
+}
+
+// Create op to build a convolution plan, which can be used to run the
+// convolution. This is the unfused variant (not fused with activation).
+Value CreateBuildUnfusedConvPlanOp(Value input, Value output, Value handle,
+ mlir::Location loc,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
cudnnDataType_t input_type = MlirTypeToCudnnDataType(
- get_element_type(input), config.input_descriptor.layout());
+ GetMemRefElementType(input), config.input_descriptor.layout());
cudnnDataType_t output_type = MlirTypeToCudnnDataType(
- get_element_type(output), config.output_descriptor.layout());
+ GetMemRefElementType(output), config.output_descriptor.layout());
int vector_size, vector_dim;
std::tie(vector_size, vector_dim) =
@@ -179,50 +287,27 @@
rewriter.getI64ArrayAttr(tuning_knob_values));
}
-Value CreateBuildFusedConvOp(Value input, Value output, Value bias,
- Value handle, mlir::Location loc,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- se::dnn::BatchDescriptor bias_descriptor(config.output_descriptor.ndims());
- bias_descriptor.set_count(1)
- .set_height(1)
- .set_width(1)
- .set_feature_map_count(config.output_descriptor.feature_map_count())
- .set_layout([&] {
- if (config.algorithm.is_cudnn_frontend()) {
- // For the purposes of the cudnn graph, say that the bias tensor has
- // the same layout as the output tensor. It doesn't actually matter,
- // because bias is a 1D array. But we need to get the correct
- // vectorization, otherwise the cudnn graph API rejects this tensor,
- // even though vectorized float tensors aren't even a thing in cuDNN.
- return config.output_descriptor.layout();
- }
- // Normalize NCHW_VECT_C to NCHW for layout of `bias`, even though it's
- // actually the same (because `bias` only has one dimension): cudnn
- // does not accept NCHW_VECT_C for `bias`.
- se::dnn::DataLayout layout = config.output_descriptor.layout();
- switch (layout) {
- case se::dnn::DataLayout::kBatchDepthYX4:
- case se::dnn::DataLayout::kBatchDepthYX32:
- return se::dnn::DataLayout::kBatchDepthYX;
- default:
- return layout;
- }
- }());
- if (bias_descriptor.ndims() == 3) {
- bias_descriptor.set_spatial_dim(se::dnn::DimIndex::Z, 1);
- }
+// Create op to build a convolution plan, which can be used to run the
+// convolution. This is the variant with fused activation.
+Value CreateBuildFusedConvPlanOp(Value input, Value output, Value bias,
+ Value handle, mlir::Location loc,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ se::dnn::BatchDescriptor bias_descriptor = GetBiasDescriptor(config);
+ // For the purposes of the cudnn graph, say that the bias tensor has the same
+ // layout as the output tensor. It doesn't actually matter, because bias is a
+ // 1D array. But we need to get the correct vectorization, otherwise the
+ // cudnn graph API rejects this tensor, even though vectorized float tensors
+ // aren't even a thing in cuDNN.
+ bias_descriptor.set_layout(config.output_descriptor.layout());
- auto get_element_type = [](Value value) {
- return value.getType().cast<mlir::MemRefType>().getElementType();
- };
cudnnDataType_t input_type = MlirTypeToCudnnDataType(
- get_element_type(input), config.input_descriptor.layout());
+ GetMemRefElementType(input), config.input_descriptor.layout());
cudnnDataType_t output_type = MlirTypeToCudnnDataType(
- get_element_type(output), config.output_descriptor.layout());
- cudnnDataType_t bias_type =
- MlirTypeToCudnnDataType(get_element_type(bias), bias_descriptor.layout());
+ GetMemRefElementType(output), config.output_descriptor.layout());
+ cudnnDataType_t bias_type = MlirTypeToCudnnDataType(
+ GetMemRefElementType(bias), bias_descriptor.layout());
int vector_size, vector_dim;
std::tie(vector_size, vector_dim) =
@@ -305,13 +390,17 @@
descriptor.kind = xla::gpu::CudnnConvKind::kForward;
return Status::OK();
}
-Value GetResult(lmhlo_gpu::ConvForwardOp op) { return op.output(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardOp op, Value handle,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- return CreateBuildUnfusedConvOp(op.input(), op.output(), handle, op.getLoc(),
- config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvForwardOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardOp op) { return op.output(); }
+Value GetFilter(lmhlo_gpu::ConvForwardOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvForwardOp op) { return GetOutput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardOp op, Value handle,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ return CreateBuildUnfusedConvPlanOp(op.input(), op.output(), handle,
+ op.getLoc(), config, backend_type,
+ rewriter);
}
Value CreateRunConvolutionOp(lmhlo_gpu::ConvForwardOpAdaptor adaptor,
mlir::Location loc, Value handle, Value conv_plan,
@@ -321,6 +410,20 @@
loc, handle, stream, conv_plan, adaptor.input(), adaptor.output(),
adaptor.filter(), adaptor.scratch(), chain);
}
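+// Creates the legacy cuDNN op that runs the forward convolution directly
+// (no execution plan needed).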
+FailureOr<Value> CreateLegacyConvOp(
+ lmhlo_gpu::ConvForwardOp op, lmhlo_gpu::ConvForwardOpAdaptor adaptor,
+ Type mlir_scale_type, Value handle, Value stream, Value input_desc,
+ Value output_desc, Value filter_desc, Value conv_desc, Value algorithm,
+ Value chain, const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+ return rewriter
+ .create<tfrt::gpu::DnnConvolutionForwardOp>(
+ op.getLoc(), handle, stream, scale_type, input_desc, adaptor.input(),
+ filter_desc, adaptor.filter(), conv_desc, algorithm,
+ adaptor.scratch(), output_desc, adaptor.output(), chain)
+ .getResult();
+}
// Specialization for convolution backward input
Status SetConvKind(lmhlo_gpu::ConvBackwardInputOp op,
@@ -328,13 +431,17 @@
descriptor.kind = xla::gpu::CudnnConvKind::kBackwardInput;
return Status::OK();
}
-Value GetResult(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_input(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvBackwardInputOp op, Value handle,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- return CreateBuildUnfusedConvOp(op.d_input(), op.d_output(), handle,
- op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_input(); }
+Value GetOutput(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_output(); }
+Value GetFilter(lmhlo_gpu::ConvBackwardInputOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvBackwardInputOp op) { return GetInput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvBackwardInputOp op, Value handle,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ return CreateBuildUnfusedConvPlanOp(op.d_input(), op.d_output(), handle,
+ op.getLoc(), config, backend_type,
+ rewriter);
}
Value CreateRunConvolutionOp(lmhlo_gpu::ConvBackwardInputOpAdaptor adaptor,
mlir::Location loc, Value handle, Value conv_plan,
@@ -344,6 +451,21 @@
loc, handle, stream, conv_plan, adaptor.d_input(), adaptor.d_output(),
adaptor.filter(), adaptor.scratch(), chain);
}
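+// Creates the legacy cuDNN op that computes the convolution backward-data
+// (input gradient) pass.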
+FailureOr<Value> CreateLegacyConvOp(
+ lmhlo_gpu::ConvBackwardInputOp op,
+ lmhlo_gpu::ConvBackwardInputOpAdaptor adaptor, Type mlir_scale_type,
+ Value handle, Value stream, Value input_desc, Value output_desc,
+ Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+ const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+ return rewriter
+ .create<tfrt::gpu::DnnConvolutionBackwardDataOp>(
+ op.getLoc(), handle, stream, scale_type, filter_desc,
+ adaptor.filter(), output_desc, adaptor.d_output(), conv_desc,
+ algorithm, adaptor.scratch(), input_desc, adaptor.d_input(), chain)
+ .getResult();
+}
// Specialization for convolution backward filter
Status SetConvKind(lmhlo_gpu::ConvBackwardFilterOp op,
@@ -351,13 +473,17 @@
descriptor.kind = xla::gpu::CudnnConvKind::kBackwardFilter;
return Status::OK();
}
-Value GetResult(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_filter(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvBackwardFilterOp op, Value handle,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- return CreateBuildUnfusedConvOp(op.input(), op.d_output(), handle,
- op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvBackwardFilterOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_output(); }
+Value GetFilter(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_filter(); }
+Value GetResult(lmhlo_gpu::ConvBackwardFilterOp op) { return GetFilter(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvBackwardFilterOp op, Value handle,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ return CreateBuildUnfusedConvPlanOp(op.input(), op.d_output(), handle,
+ op.getLoc(), config, backend_type,
+ rewriter);
}
Value CreateRunConvolutionOp(lmhlo_gpu::ConvBackwardFilterOpAdaptor adaptor,
mlir::Location loc, Value handle, Value conv_plan,
@@ -367,6 +493,21 @@
loc, handle, stream, conv_plan, adaptor.input(), adaptor.d_output(),
adaptor.d_filter(), adaptor.scratch(), chain);
}
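+// Creates the legacy cuDNN op that computes the convolution backward-filter
+// (filter gradient) pass.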
+FailureOr<Value> CreateLegacyConvOp(
+ lmhlo_gpu::ConvBackwardFilterOp op,
+ lmhlo_gpu::ConvBackwardFilterOpAdaptor adaptor, Type mlir_scale_type,
+ Value handle, Value stream, Value input_desc, Value output_desc,
+ Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+ const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+ return rewriter
+ .create<tfrt::gpu::DnnConvolutionBackwardFilterOp>(
+ op.getLoc(), handle, stream, scale_type, input_desc, adaptor.input(),
+ output_desc, adaptor.d_output(), conv_desc, algorithm,
+ adaptor.scratch(), filter_desc, adaptor.d_filter(), chain)
+ .getResult();
+}
// Specialization for convolution forward fused
Status SetConvKind(lmhlo_gpu::ConvForwardFusedOp op,
@@ -382,13 +523,17 @@
static_cast<int64_t>(activation_mode));
return Status::OK();
}
-Value GetResult(lmhlo_gpu::ConvForwardFusedOp op) { return op.output(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardFusedOp op, Value handle,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- return CreateBuildFusedConvOp(op.input(), op.output(), op.bias(), handle,
- op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvForwardFusedOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardFusedOp op) { return op.output(); }
+Value GetFilter(lmhlo_gpu::ConvForwardFusedOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvForwardFusedOp op) { return GetOutput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardFusedOp op, Value handle,
+ const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ return CreateBuildFusedConvPlanOp(op.input(), op.output(), op.bias(), handle,
+ op.getLoc(), config, backend_type,
+ rewriter);
}
Value CreateRunConvolutionOp(lmhlo_gpu::ConvForwardFusedOpAdaptor adaptor,
mlir::Location loc, Value handle, Value conv_plan,
@@ -399,6 +544,18 @@
adaptor.filter(), adaptor.output(), adaptor.bias(), adaptor.scratch(),
chain);
}
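+// Creates the legacy cuDNN fused convolution op. There is no separate side
+// input, so the output buffer is passed as the side input.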
+FailureOr<Value> CreateLegacyConvOp(
+ lmhlo_gpu::ConvForwardFusedOp op,
+ lmhlo_gpu::ConvForwardFusedOpAdaptor adaptor, Type mlir_scale_type,
+ Value handle, Value stream, Value input_desc, Value output_desc,
+ Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+ const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ return CreateLegacyFusedConvOp(op, adaptor, mlir_scale_type, handle, stream,
+ input_desc, output_desc, filter_desc,
+ conv_desc, algorithm, adaptor.output(), chain,
+ config, rewriter);
+}
// Specialization for convolution forward fused side input
Status SetConvKind(lmhlo_gpu::ConvForwardFusedSideInputOp op,
@@ -416,15 +573,23 @@
op.side_input_scale().convertToDouble());
return Status::OK();
}
-Value GetResult(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+Value GetInput(lmhlo_gpu::ConvForwardFusedSideInputOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
return op.output();
}
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardFusedSideInputOp op, Value handle,
- const xla::gpu::GpuConvConfig& config,
- cudnnBackendDescriptorType_t backend_type,
- ConversionPatternRewriter& rewriter) {
- return CreateBuildFusedConvOp(op.input(), op.output(), op.bias(), handle,
- op.getLoc(), config, backend_type, rewriter);
+Value GetFilter(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+ return op.filter();
+}
+Value GetResult(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+ return GetOutput(op);
+}
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardFusedSideInputOp op,
+ Value handle, const xla::gpu::GpuConvConfig& config,
+ cudnnBackendDescriptorType_t backend_type,
+ ConversionPatternRewriter& rewriter) {
+ return CreateBuildFusedConvPlanOp(op.input(), op.output(), op.bias(), handle,
+ op.getLoc(), config, backend_type,
+ rewriter);
}
Value CreateRunConvolutionOp(
lmhlo_gpu::ConvForwardFusedSideInputOpAdaptor adaptor, mlir::Location loc,
@@ -435,6 +600,111 @@
adaptor.filter(), adaptor.side_input(), adaptor.bias(), adaptor.scratch(),
chain);
}
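+// Creates the legacy cuDNN fused convolution op with a side input. If the
+// side input scale is zero, the output buffer is used as the side input.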
+FailureOr<Value> CreateLegacyConvOp(
+ lmhlo_gpu::ConvForwardFusedSideInputOp op,
+ lmhlo_gpu::ConvForwardFusedSideInputOpAdaptor adaptor, Type mlir_scale_type,
+ Value handle, Value stream, Value input_desc, Value output_desc,
+ Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+ const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ Value side_input = config.fusion->side_input_scale == 0
+ ? adaptor.output()
+ : adaptor.side_input();
+ return CreateLegacyFusedConvOp(
+ op, adaptor, mlir_scale_type, handle, stream, input_desc, output_desc,
+ filter_desc, conv_desc, algorithm, side_input, chain, config, rewriter);
+}
+
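+// Lowers a convolution op through the legacy (pre-frontend) cuDNN API:
+// creates the tensor, filter, and convolution descriptors and then the
+// matching legacy convolution op.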
+template <class ConvolutionOpType, class OpAdaptor>
+FailureOr<Value> LegacyConvolutionRewritePattern(
+ ConvolutionOpType op, OpAdaptor adaptor, Value chain, Value stream,
+ const xla::gpu::GpuConvConfig& config,
+ ConversionPatternRewriter& rewriter) {
+ cudnnDataType_t input_type = MlirTypeToCudnnDataType(
+ GetMemRefElementType(GetInput(op)), config.input_descriptor.layout());
+ cudnnDataType_t output_type = MlirTypeToCudnnDataType(
+ GetMemRefElementType(GetOutput(op)), config.output_descriptor.layout());
+  cudnnDataType_t filter_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(GetFilter(op)), config.filter_descriptor.layout());
+
+ // Create input descriptor.
+ FailureOr<Value> input_desc_or = CreateLegacyTensorDescriptor(
+ op, config.input_descriptor, input_type, chain, rewriter);
+ if (failed(input_desc_or)) {
+ return input_desc_or;
+ }
+
+ // Create output descriptor.
+ FailureOr<Value> output_desc_or = CreateLegacyTensorDescriptor(
+ op, config.output_descriptor, output_type, chain, rewriter);
+ if (failed(output_desc_or)) {
+ return output_desc_or;
+ }
+
+ // Create filter descriptor.
+ cudnnTensorFormat_t tensor_format;
+ switch (config.filter_descriptor.layout()) {
+ case se::dnn::FilterLayout::kOutputInputYX:
+ tensor_format = CUDNN_TENSOR_NCHW;
+ break;
+ case se::dnn::FilterLayout::kOutputYXInput:
+ tensor_format = CUDNN_TENSOR_NHWC;
+ break;
+ case se::dnn::FilterLayout::kOutputInputYX4:
+ case se::dnn::FilterLayout::kOutputInputYX32: {
+ tensor_format = CUDNN_TENSOR_NCHW_VECT_C;
+ break;
+ }
+ default:
+ return rewriter.notifyMatchFailure(op, "Unexpected filter layout.");
+ }
+ std::vector<int> dims(2 + config.filter_descriptor.ndims());
+ dims[0] = config.filter_descriptor.output_feature_map_count();
+ dims[1] = config.filter_descriptor.input_feature_map_count();
+ absl::Span<const int64_t> spatial_dims =
+ config.filter_descriptor.input_filter_dims();
+ std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
+ auto loc = op.getLoc();
+ Value filter_desc = rewriter.create<tfrt::gpu::DnnCreateFilterDescriptorOp>(
+ loc, filter_type, tensor_format, rewriter.getI32ArrayAttr(dims), chain);
+
+ // Create convolution descriptor.
+ mlir::Type mlir_compute_type = GetMemRefElementType(GetInput(op));
+ cudnnDataType_t compute_type = MlirTypeToCudnnDataType(mlir_compute_type);
+ cudnnConvolutionMode_t conv_mode =
+ config.conv_desc.convolution_not_crosscorr() ? CUDNN_CONVOLUTION
+ : CUDNN_CROSS_CORRELATION;
+ const auto& convolution_descriptor = config.conv_desc;
+ // cuDNN requires arrays of ints.
+ std::vector<int> strides = CheckedNarrowing(convolution_descriptor.strides());
+ std::vector<int> padding = CheckedNarrowing(convolution_descriptor.padding());
+ std::vector<int> dilations =
+ CheckedNarrowing(convolution_descriptor.dilations());
+ cudnnMathType_t math_type = config.algorithm.tensor_ops_enabled()
+ ? CUDNN_TENSOR_OP_MATH
+ : CUDNN_FMA_MATH;
+ Value conv_desc =
+ rewriter.create<tfrt::gpu::DnnCreateConvolutionDescriptorOp>(
+ loc, compute_type, conv_mode, math_type,
+ rewriter.getI32ArrayAttr(padding), rewriter.getI32ArrayAttr(strides),
+ rewriter.getI32ArrayAttr(dilations), chain);
+
+ // Create convolution op.
+ mlir::Type mlir_scale_type =
+ mlir_compute_type.isF64() ? rewriter.getF64Type() : rewriter.getF32Type();
+ Value context = rewriter.create<tfrt::gpu::StreamGetContextOp>(loc, stream);
+ Value handle = rewriter.create<tfrt::gpu::DnnCreateOp>(loc, context);
+ Value algorithm = rewriter.create<tfrt::compiler::ConstantUI64Op>(
+ loc, config.algorithm.algo_id());
+ auto out_chain_or =
+ CreateLegacyConvOp(op, adaptor, mlir_scale_type, handle, stream,
+ *input_desc_or, *output_desc_or, filter_desc,
+ conv_desc, algorithm, chain, config, rewriter);
+ if (succeeded(out_chain_or)) {
+ rewriter.eraseOp(op);
+ }
+ return out_chain_or;
+}
template <class ConvolutionOpType>
struct ConvolutionRewritePattern
@@ -472,6 +742,35 @@
op, "TensorFlow padding alignment is not supported.");
}
+ bool use_legacy_conv = [&] {
+ if (!config.algorithm.is_cudnn_frontend()) return true;
+
+ auto print_reason = [&](const char* reason) {
+ LOG(ERROR)
+ << "Disabling cuDNN frontend for the following convolution:\n"
+ << " input: " << config.input_descriptor.ToString() << "\n"
+ << " filter: " << config.filter_descriptor.ToString() << "\n"
+ << " conv: " << config.conv_desc.ToString() << "\n... because "
+ << reason;
+ };
+
+ if (config.input_descriptor.layout() ==
+ se::dnn::DataLayout::kBatchDepthYX32)
+ // Current versions of the frontend API lack support for Tx32.
+ return print_reason("Tx32 convolutions are unsupported."), true;
+
+ if (CUDNN_VERSION < 8100)
+ // cuDNN frontend support became sufficiently stable to use in 8.1.
+ return print_reason("the cuDNN version does not support it."), true;
+
+ return false;
+ }();
+
+ if (use_legacy_conv) {
+ return LegacyConvolutionRewritePattern(op, adaptor, chain, stream, config,
+ rewriter);
+ }
+
cudnnBackendDescriptorType_t backend_type;
switch (descriptor.kind) {
case xla::gpu::CudnnConvKind::kForward:
@@ -507,8 +806,8 @@
rewriter.getFunctionType(handle_type, conv_plan_type));
symbol_table.insert(conv_plan_func);
rewriter.setInsertionPointToEnd(conv_plan_func.addEntryBlock());
- Value conv_plan = CreateBuildConvOp(op, conv_plan_func.getArgument(0),
- config, backend_type, rewriter);
+ Value conv_plan = CreateBuildConvPlanOp(op, conv_plan_func.getArgument(0),
+ config, backend_type, rewriter);
rewriter.create<tfrt::compiler::ReturnOp>(loc, conv_plan);
// Once-initialize the convolution plan.
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
index 7b510e2..8fcaa2e 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
@@ -65,7 +65,7 @@
}
break;
case se::dnn::DataLayout::kBatchDepthYX32:
- if (type.isSignlessInteger(/*width=*/32)) {
+ if (type.isSignlessInteger(/*width=*/8)) {
return CUDNN_DATA_INT8x32;
}
break;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 172b5a2..f780a12 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -2559,6 +2559,8 @@
"//tensorflow/compiler/xla/service/gpu/tests:kernel_launch_test",
"//tensorflow/compiler/xla/service/gpu/tests:mlir_gemm_test",
"//tensorflow/compiler/xla/tests:cholesky_test_gpu",
+ "//tensorflow/compiler/xla/tests:convolution_test_cudnn_frontend_disabled_gpu",
+ "//tensorflow/compiler/xla/tests:convolution_test_gpu",
"//tensorflow/compiler/xla/tests:dot_operation_test_gpu",
"//tensorflow/compiler/xla/tests:multioutput_fusion_test_gpu",
"//tensorflow/compiler/xla/tests:scatter_test_gpu",
@@ -2622,6 +2624,8 @@
"//tensorflow/compiler/xla/service/gpu/tests:sorting_test",
"//tensorflow/compiler/xla/service/gpu/tests:tree_reduction_rewriter_test",
"//tensorflow/compiler/xla/tests:cholesky_test_gpu",
+ "//tensorflow/compiler/xla/tests:convolution_test_cudnn_frontend_disabled_gpu",
+ "//tensorflow/compiler/xla/tests:convolution_test_gpu",
"//tensorflow/compiler/xla/tests:multioutput_fusion_test_gpu",
"//tensorflow/compiler/xla/tests:scatter_test_gpu",
"//tensorflow/compiler/xla/tests:select_and_scatter_test_gpu",
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c0c0597..ad3e420 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1273,6 +1273,24 @@
)
xla_test(
+ name = "convolution_test_cudnn_frontend_disabled",
+ timeout = "long",
+ srcs = ["convolution_test.cc"],
+ backend_args = {"gpu": ["--xla_gpu_enable_cudnn_frontend=false"]},
+ backends = ["gpu"],
+ shard_count = 50,
+ tags = [
+ "no_rocm",
+ "nozapfhahn",
+ "optonly",
+ ],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
name = "convolution_variants_test",
timeout = "long",
srcs = ["convolution_variants_test.cc"],