[XLIR:TFRT] Support fallback to legacy convolution (due to version or x32 convs).
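
The cuDNN frontend (graph API) lowering cannot handle every convolution: current
frontend versions lack support for Tx32 (kBatchDepthYX32) layouts, and cuDNN
versions before 8.1 are not stable enough to rely on. In those cases the
lowering now falls back to the legacy cuDNN convolution ops instead of building
a frontend execution plan. Roughly, the decision looks like this (a simplified
sketch of the predicate added to convolution_pattern.cc; the real code also
logs why the frontend was disabled):

    bool use_legacy_conv =
        !config.algorithm.is_cudnn_frontend() ||
        config.input_descriptor.layout() ==
            se::dnn::DataLayout::kBatchDepthYX32 ||
        CUDNN_VERSION < 8100;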

PiperOrigin-RevId: 428632858
Change-Id: Iaa1f10a4d083c38236cb654da666a2fda0a95c65
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
index 28f2d66..a5ae524 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/convolution_pattern.cc
@@ -15,6 +15,7 @@
 // Pattern to lower lmhlo convolution ops to tfrt_gpu dialect.
 #include <sys/types.h>
 
+#include <algorithm>
 #include <functional>
 
 #include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
@@ -109,18 +110,125 @@
       op.result_scale().convertToDouble());
 }
 
-Value CreateBuildUnfusedConvOp(Value input, Value output, Value handle,
-                               mlir::Location loc,
-                               const xla::gpu::GpuConvConfig& config,
-                               cudnnBackendDescriptorType_t backend_type,
-                               ConversionPatternRewriter& rewriter) {
-  auto get_element_type = [](Value value) {
-    return value.getType().cast<mlir::MemRefType>().getElementType();
-  };
+mlir::Type GetMemRefElementType(Value value) {
+  return value.getType().cast<mlir::MemRefType>().getElementType();
+}
+
+// Converts (via narrowing) a Span<const int64_t> to a vector<int>, and checks
+// that the elements have not changed due to the conversion.
+std::vector<int> CheckedNarrowing(absl::Span<const int64_t> wide_span) {
+  std::vector<int> narrow_vector(wide_span.size());
+  std::transform(
+      wide_span.cbegin(), wide_span.cend(), narrow_vector.begin(),
+      [](int64_t wide) {
+        int narrow = wide;
+        assert(narrow == wide &&
+               "checked narrowing failed; values not equal post-conversion");
+        return narrow;
+      });
+  return narrow_vector;
+}
+
+// Create ops to describe tensors (e.g., input, output, or bias) when using
+// legacy cudnn.
+template <class ConvolutionOpType>
+FailureOr<Value> CreateLegacyTensorDescriptor(
+    ConvolutionOpType op, const se::dnn::BatchDescriptor& batch_descriptor,
+    cudnnDataType_t elem_type, Value chain,
+    ConversionPatternRewriter& rewriter) {
+  std::vector<int64_t> dims64, strides64;
+  switch (batch_descriptor.layout()) {
+    case se::dnn::DataLayout::kBatchYXDepth:
+    case se::dnn::DataLayout::kBatchDepthYX: {
+      // cuDNN requires the strides and dims to be ordered as BDYX.
+      dims64 = batch_descriptor.full_dims(se::dnn::DataLayout::kBatchDepthYX);
+      strides64 =
+          batch_descriptor.full_strides(se::dnn::DataLayout::kBatchDepthYX);
+      break;
+    }
+    case se::dnn::DataLayout::kBatchDepthYX4:
+    case se::dnn::DataLayout::kBatchDepthYX32: {
+      const int64_t n = batch_descriptor.count();
+      const int64_t c = batch_descriptor.feature_map_count();
+      const int64_t h = batch_descriptor.height();
+      const int64_t w = batch_descriptor.width();
+      const int64_t v =
+          batch_descriptor.layout() == se::dnn::DataLayout::kBatchDepthYX4 ? 4
+                                                                           : 32;
+      assert(c / v > 0 && "Vectorized feature map count is non-positive.");
+      dims64 = {n, c / v, h, w};
+      strides64 = {c / v * h * w, h * w, w, 1};
+      break;
+    }
+    default:
+      return rewriter.notifyMatchFailure(op, "Unsupported tensor format.");
+  }
+
+  // cuDNN requires arrays of ints.
+  std::vector<int> dims = CheckedNarrowing(dims64);
+  std::vector<int> strides = CheckedNarrowing(strides64);
+  return rewriter
+      .create<tfrt::gpu::DnnCreateTensorDescriptorOp>(
+          op.getLoc(), elem_type, rewriter.getI32ArrayAttr(dims),
+          rewriter.getI32ArrayAttr(strides), chain)
+      .getResult();
+}
+
+template <class FusedConvOpType, class FusedConvOpAdaptorType>
+FailureOr<Value> CreateLegacyFusedConvOp(
+    FusedConvOpType op, FusedConvOpAdaptorType adaptor, Type mlir_scale_type,
+    Value handle, Value stream, Value input_desc, Value output_desc,
+    Value filter_desc, Value conv_desc, Value algorithm, Value side_input,
+    Value chain, const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  // Create bias descriptor.
+  se::dnn::BatchDescriptor bias_descriptor = GetBiasDescriptor(config);
+  cudnnDataType_t bias_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(op.bias()), bias_descriptor.layout());
+  FailureOr<Value> bias_desc_or = CreateLegacyTensorDescriptor(
+      op, bias_descriptor, bias_type, chain, rewriter);
+  if (failed(bias_desc_or)) {
+    return bias_desc_or;
+  }
+
+  // Create activation descriptor.
+  auto loc = op.getLoc();
+  auto coefficient =
+      rewriter.create<tfrt::compiler::ConstantF64Op>(loc, llvm::APFloat(0.0));
+  cudnnActivationMode_t activation_mode = config.fusion->mode == se::dnn::kRelu
+                                             ? CUDNN_ACTIVATION_RELU
+                                             : CUDNN_ACTIVATION_IDENTITY;
+  auto activation_desc =
+      rewriter.create<tfrt::gpu::DnnCreateActivationDescriptorOp>(
+          loc, coefficient, activation_mode, CUDNN_NOT_PROPAGATE_NAN, chain);
+
+  cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+  auto alpha1 = MakeScalingFactorConstant(
+      rewriter, loc, mlir_scale_type, llvm::APFloat(config.conv_result_scale),
+      llvm::APFloat(0.0));
+  auto alpha2 = MakeScalingFactorConstant(
+      rewriter, loc, mlir_scale_type,
+      llvm::APFloat(config.fusion->side_input_scale), llvm::APFloat(0.0));
+  return rewriter
+      .create<tfrt::gpu::DnnConvolutionBiasActivationForwardOp>(
+          loc, handle, stream, scale_type, alpha1, input_desc, adaptor.input(),
+          filter_desc, adaptor.filter(), conv_desc, algorithm,
+          adaptor.scratch(), alpha2, output_desc, side_input, *bias_desc_or,
+          adaptor.bias(), activation_desc, output_desc, adaptor.output(), chain)
+      .getResult();
+}
+
+// Create op to build a convolution plan, which can be used to run the
+// convolution. This is the unfused variant (not fused with activation).
+Value CreateBuildUnfusedConvPlanOp(Value input, Value output, Value handle,
+                                   mlir::Location loc,
+                                   const xla::gpu::GpuConvConfig& config,
+                                   cudnnBackendDescriptorType_t backend_type,
+                                   ConversionPatternRewriter& rewriter) {
   cudnnDataType_t input_type = MlirTypeToCudnnDataType(
-      get_element_type(input), config.input_descriptor.layout());
+      GetMemRefElementType(input), config.input_descriptor.layout());
   cudnnDataType_t output_type = MlirTypeToCudnnDataType(
-      get_element_type(output), config.output_descriptor.layout());
+      GetMemRefElementType(output), config.output_descriptor.layout());
 
   int vector_size, vector_dim;
   std::tie(vector_size, vector_dim) =
@@ -179,50 +287,27 @@
       rewriter.getI64ArrayAttr(tuning_knob_values));
 }
 
-Value CreateBuildFusedConvOp(Value input, Value output, Value bias,
-                             Value handle, mlir::Location loc,
-                             const xla::gpu::GpuConvConfig& config,
-                             cudnnBackendDescriptorType_t backend_type,
-                             ConversionPatternRewriter& rewriter) {
-  se::dnn::BatchDescriptor bias_descriptor(config.output_descriptor.ndims());
-  bias_descriptor.set_count(1)
-      .set_height(1)
-      .set_width(1)
-      .set_feature_map_count(config.output_descriptor.feature_map_count())
-      .set_layout([&] {
-        if (config.algorithm.is_cudnn_frontend()) {
-          // For the purposes of the cudnn graph, say that the bias tensor has
-          // the same layout as the output tensor.  It doesn't actually matter,
-          // because bias is a 1D array.  But we need to get the correct
-          // vectorization, otherwise the cudnn graph API rejects this tensor,
-          // even though vectorized float tensors aren't even a thing in cuDNN.
-          return config.output_descriptor.layout();
-        }
-        // Normalize NCHW_VECT_C to NCHW for layout of `bias`, even though it's
-        // actually the same (because `bias` only has one dimension):  cudnn
-        // does not accept NCHW_VECT_C for `bias`.
-        se::dnn::DataLayout layout = config.output_descriptor.layout();
-        switch (layout) {
-          case se::dnn::DataLayout::kBatchDepthYX4:
-          case se::dnn::DataLayout::kBatchDepthYX32:
-            return se::dnn::DataLayout::kBatchDepthYX;
-          default:
-            return layout;
-        }
-      }());
-  if (bias_descriptor.ndims() == 3) {
-    bias_descriptor.set_spatial_dim(se::dnn::DimIndex::Z, 1);
-  }
+// Create op to build a convolution plan, which can be used to run the
+// convolution. This is the variant with fused activation.
+Value CreateBuildFusedConvPlanOp(Value input, Value output, Value bias,
+                                 Value handle, mlir::Location loc,
+                                 const xla::gpu::GpuConvConfig& config,
+                                 cudnnBackendDescriptorType_t backend_type,
+                                 ConversionPatternRewriter& rewriter) {
+  se::dnn::BatchDescriptor bias_descriptor = GetBiasDescriptor(config);
+  // For the purposes of the cudnn graph, say that the bias tensor has the same
+  // layout as the output tensor.  It doesn't actually matter, because bias is a
+  // 1D array.  But we need to get the correct vectorization, otherwise the
+  // cudnn graph API rejects this tensor, even though vectorized float tensors
+  // aren't even a thing in cuDNN.
+  bias_descriptor.set_layout(config.output_descriptor.layout());
 
-  auto get_element_type = [](Value value) {
-    return value.getType().cast<mlir::MemRefType>().getElementType();
-  };
   cudnnDataType_t input_type = MlirTypeToCudnnDataType(
-      get_element_type(input), config.input_descriptor.layout());
+      GetMemRefElementType(input), config.input_descriptor.layout());
   cudnnDataType_t output_type = MlirTypeToCudnnDataType(
-      get_element_type(output), config.output_descriptor.layout());
-  cudnnDataType_t bias_type =
-      MlirTypeToCudnnDataType(get_element_type(bias), bias_descriptor.layout());
+      GetMemRefElementType(output), config.output_descriptor.layout());
+  cudnnDataType_t bias_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(bias), bias_descriptor.layout());
 
   int vector_size, vector_dim;
   std::tie(vector_size, vector_dim) =
@@ -305,13 +390,17 @@
   descriptor.kind = xla::gpu::CudnnConvKind::kForward;
   return Status::OK();
 }
-Value GetResult(lmhlo_gpu::ConvForwardOp op) { return op.output(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardOp op, Value handle,
-                        const xla::gpu::GpuConvConfig& config,
-                        cudnnBackendDescriptorType_t backend_type,
-                        ConversionPatternRewriter& rewriter) {
-  return CreateBuildUnfusedConvOp(op.input(), op.output(), handle, op.getLoc(),
-                                  config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvForwardOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardOp op) { return op.output(); }
+Value GetFilter(lmhlo_gpu::ConvForwardOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvForwardOp op) { return GetOutput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardOp op, Value handle,
+                            const xla::gpu::GpuConvConfig& config,
+                            cudnnBackendDescriptorType_t backend_type,
+                            ConversionPatternRewriter& rewriter) {
+  return CreateBuildUnfusedConvPlanOp(op.input(), op.output(), handle,
+                                      op.getLoc(), config, backend_type,
+                                      rewriter);
 }
 Value CreateRunConvolutionOp(lmhlo_gpu::ConvForwardOpAdaptor adaptor,
                              mlir::Location loc, Value handle, Value conv_plan,
@@ -321,6 +410,20 @@
       loc, handle, stream, conv_plan, adaptor.input(), adaptor.output(),
       adaptor.filter(), adaptor.scratch(), chain);
 }
+FailureOr<Value> CreateLegacyConvOp(
+    lmhlo_gpu::ConvForwardOp op, lmhlo_gpu::ConvForwardOpAdaptor adaptor,
+    Type mlir_scale_type, Value handle, Value stream, Value input_desc,
+    Value output_desc, Value filter_desc, Value conv_desc, Value algorithm,
+    Value chain, const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+  return rewriter
+      .create<tfrt::gpu::DnnConvolutionForwardOp>(
+          op.getLoc(), handle, stream, scale_type, input_desc, adaptor.input(),
+          filter_desc, adaptor.filter(), conv_desc, algorithm,
+          adaptor.scratch(), output_desc, adaptor.output(), chain)
+      .getResult();
+}
 
 // Specialization for convolution backward input
 Status SetConvKind(lmhlo_gpu::ConvBackwardInputOp op,
@@ -328,13 +431,17 @@
   descriptor.kind = xla::gpu::CudnnConvKind::kBackwardInput;
   return Status::OK();
 }
-Value GetResult(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_input(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvBackwardInputOp op, Value handle,
-                        const xla::gpu::GpuConvConfig& config,
-                        cudnnBackendDescriptorType_t backend_type,
-                        ConversionPatternRewriter& rewriter) {
-  return CreateBuildUnfusedConvOp(op.d_input(), op.d_output(), handle,
-                                  op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_input(); }
+Value GetOutput(lmhlo_gpu::ConvBackwardInputOp op) { return op.d_output(); }
+Value GetFilter(lmhlo_gpu::ConvBackwardInputOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvBackwardInputOp op) { return GetInput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvBackwardInputOp op, Value handle,
+                            const xla::gpu::GpuConvConfig& config,
+                            cudnnBackendDescriptorType_t backend_type,
+                            ConversionPatternRewriter& rewriter) {
+  return CreateBuildUnfusedConvPlanOp(op.d_input(), op.d_output(), handle,
+                                      op.getLoc(), config, backend_type,
+                                      rewriter);
 }
 Value CreateRunConvolutionOp(lmhlo_gpu::ConvBackwardInputOpAdaptor adaptor,
                              mlir::Location loc, Value handle, Value conv_plan,
@@ -344,6 +451,21 @@
       loc, handle, stream, conv_plan, adaptor.d_input(), adaptor.d_output(),
       adaptor.filter(), adaptor.scratch(), chain);
 }
+FailureOr<Value> CreateLegacyConvOp(
+    lmhlo_gpu::ConvBackwardInputOp op,
+    lmhlo_gpu::ConvBackwardInputOpAdaptor adaptor, Type mlir_scale_type,
+    Value handle, Value stream, Value input_desc, Value output_desc,
+    Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+    const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+  return rewriter
+      .create<tfrt::gpu::DnnConvolutionBackwardDataOp>(
+          op.getLoc(), handle, stream, scale_type, filter_desc,
+          adaptor.filter(), output_desc, adaptor.d_output(), conv_desc,
+          algorithm, adaptor.scratch(), input_desc, adaptor.d_input(), chain)
+      .getResult();
+}
 
 // Specialization for convolution backward filter
 Status SetConvKind(lmhlo_gpu::ConvBackwardFilterOp op,
@@ -351,13 +473,17 @@
   descriptor.kind = xla::gpu::CudnnConvKind::kBackwardFilter;
   return Status::OK();
 }
-Value GetResult(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_filter(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvBackwardFilterOp op, Value handle,
-                        const xla::gpu::GpuConvConfig& config,
-                        cudnnBackendDescriptorType_t backend_type,
-                        ConversionPatternRewriter& rewriter) {
-  return CreateBuildUnfusedConvOp(op.input(), op.d_output(), handle,
-                                  op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvBackwardFilterOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_output(); }
+Value GetFilter(lmhlo_gpu::ConvBackwardFilterOp op) { return op.d_filter(); }
+Value GetResult(lmhlo_gpu::ConvBackwardFilterOp op) { return GetFilter(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvBackwardFilterOp op, Value handle,
+                            const xla::gpu::GpuConvConfig& config,
+                            cudnnBackendDescriptorType_t backend_type,
+                            ConversionPatternRewriter& rewriter) {
+  return CreateBuildUnfusedConvPlanOp(op.input(), op.d_output(), handle,
+                                      op.getLoc(), config, backend_type,
+                                      rewriter);
 }
 Value CreateRunConvolutionOp(lmhlo_gpu::ConvBackwardFilterOpAdaptor adaptor,
                              mlir::Location loc, Value handle, Value conv_plan,
@@ -367,6 +493,21 @@
       loc, handle, stream, conv_plan, adaptor.input(), adaptor.d_output(),
       adaptor.d_filter(), adaptor.scratch(), chain);
 }
+FailureOr<Value> CreateLegacyConvOp(
+    lmhlo_gpu::ConvBackwardFilterOp op,
+    lmhlo_gpu::ConvBackwardFilterOpAdaptor adaptor, Type mlir_scale_type,
+    Value handle, Value stream, Value input_desc, Value output_desc,
+    Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+    const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  cudnnDataType_t scale_type = MlirTypeToCudnnDataType(mlir_scale_type);
+  return rewriter
+      .create<tfrt::gpu::DnnConvolutionBackwardFilterOp>(
+          op.getLoc(), handle, stream, scale_type, input_desc, adaptor.input(),
+          output_desc, adaptor.d_output(), conv_desc, algorithm,
+          adaptor.scratch(), filter_desc, adaptor.d_filter(), chain)
+      .getResult();
+}
 
 // Specialization for convolution forward fused
 Status SetConvKind(lmhlo_gpu::ConvForwardFusedOp op,
@@ -382,13 +523,17 @@
       static_cast<int64_t>(activation_mode));
   return Status::OK();
 }
-Value GetResult(lmhlo_gpu::ConvForwardFusedOp op) { return op.output(); }
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardFusedOp op, Value handle,
-                        const xla::gpu::GpuConvConfig& config,
-                        cudnnBackendDescriptorType_t backend_type,
-                        ConversionPatternRewriter& rewriter) {
-  return CreateBuildFusedConvOp(op.input(), op.output(), op.bias(), handle,
-                                op.getLoc(), config, backend_type, rewriter);
+Value GetInput(lmhlo_gpu::ConvForwardFusedOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardFusedOp op) { return op.output(); }
+Value GetFilter(lmhlo_gpu::ConvForwardFusedOp op) { return op.filter(); }
+Value GetResult(lmhlo_gpu::ConvForwardFusedOp op) { return GetOutput(op); }
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardFusedOp op, Value handle,
+                            const xla::gpu::GpuConvConfig& config,
+                            cudnnBackendDescriptorType_t backend_type,
+                            ConversionPatternRewriter& rewriter) {
+  return CreateBuildFusedConvPlanOp(op.input(), op.output(), op.bias(), handle,
+                                    op.getLoc(), config, backend_type,
+                                    rewriter);
 }
 Value CreateRunConvolutionOp(lmhlo_gpu::ConvForwardFusedOpAdaptor adaptor,
                              mlir::Location loc, Value handle, Value conv_plan,
@@ -399,6 +544,18 @@
       adaptor.filter(), adaptor.output(), adaptor.bias(), adaptor.scratch(),
       chain);
 }
+FailureOr<Value> CreateLegacyConvOp(
+    lmhlo_gpu::ConvForwardFusedOp op,
+    lmhlo_gpu::ConvForwardFusedOpAdaptor adaptor, Type mlir_scale_type,
+    Value handle, Value stream, Value input_desc, Value output_desc,
+    Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+    const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  return CreateLegacyFusedConvOp(op, adaptor, mlir_scale_type, handle, stream,
+                                 input_desc, output_desc, filter_desc,
+                                 conv_desc, algorithm, adaptor.output(), chain,
+                                 config, rewriter);
+}
 
 // Specialization for convolution forward fused side input
 Status SetConvKind(lmhlo_gpu::ConvForwardFusedSideInputOp op,
@@ -416,15 +573,23 @@
       op.side_input_scale().convertToDouble());
   return Status::OK();
 }
-Value GetResult(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+Value GetInput(lmhlo_gpu::ConvForwardFusedSideInputOp op) { return op.input(); }
+Value GetOutput(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
   return op.output();
 }
-Value CreateBuildConvOp(lmhlo_gpu::ConvForwardFusedSideInputOp op, Value handle,
-                        const xla::gpu::GpuConvConfig& config,
-                        cudnnBackendDescriptorType_t backend_type,
-                        ConversionPatternRewriter& rewriter) {
-  return CreateBuildFusedConvOp(op.input(), op.output(), op.bias(), handle,
-                                op.getLoc(), config, backend_type, rewriter);
+Value GetFilter(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+  return op.filter();
+}
+Value GetResult(lmhlo_gpu::ConvForwardFusedSideInputOp op) {
+  return GetOutput(op);
+}
+Value CreateBuildConvPlanOp(lmhlo_gpu::ConvForwardFusedSideInputOp op,
+                            Value handle, const xla::gpu::GpuConvConfig& config,
+                            cudnnBackendDescriptorType_t backend_type,
+                            ConversionPatternRewriter& rewriter) {
+  return CreateBuildFusedConvPlanOp(op.input(), op.output(), op.bias(), handle,
+                                    op.getLoc(), config, backend_type,
+                                    rewriter);
 }
 Value CreateRunConvolutionOp(
     lmhlo_gpu::ConvForwardFusedSideInputOpAdaptor adaptor, mlir::Location loc,
@@ -435,6 +600,111 @@
       adaptor.filter(), adaptor.side_input(), adaptor.bias(), adaptor.scratch(),
       chain);
 }
+FailureOr<Value> CreateLegacyConvOp(
+    lmhlo_gpu::ConvForwardFusedSideInputOp op,
+    lmhlo_gpu::ConvForwardFusedSideInputOpAdaptor adaptor, Type mlir_scale_type,
+    Value handle, Value stream, Value input_desc, Value output_desc,
+    Value filter_desc, Value conv_desc, Value algorithm, Value chain,
+    const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
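+  // When the side input scale is 0 the side input values are never read, but
+  // cuDNN still expects a valid buffer, so the output buffer is reused as the
+  // side input.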
+  Value side_input = config.fusion->side_input_scale == 0
+                         ? adaptor.output()
+                         : adaptor.side_input();
+  return CreateLegacyFusedConvOp(
+      op, adaptor, mlir_scale_type, handle, stream, input_desc, output_desc,
+      filter_desc, conv_desc, algorithm, side_input, chain, config, rewriter);
+}
+
+template <class ConvolutionOpType, class OpAdaptor>
+FailureOr<Value> LegacyConvolutionRewritePattern(
+    ConvolutionOpType op, OpAdaptor adaptor, Value chain, Value stream,
+    const xla::gpu::GpuConvConfig& config,
+    ConversionPatternRewriter& rewriter) {
+  cudnnDataType_t input_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(GetInput(op)), config.input_descriptor.layout());
+  cudnnDataType_t output_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(GetOutput(op)), config.output_descriptor.layout());
+  cudnnDataType_t filter_type = MlirTypeToCudnnDataType(
+      GetMemRefElementType(GetFilter(op)), config.filter_descriptor.layout());
+
+  // Create input descriptor.
+  FailureOr<Value> input_desc_or = CreateLegacyTensorDescriptor(
+      op, config.input_descriptor, input_type, chain, rewriter);
+  if (failed(input_desc_or)) {
+    return input_desc_or;
+  }
+
+  // Create output descriptor.
+  FailureOr<Value> output_desc_or = CreateLegacyTensorDescriptor(
+      op, config.output_descriptor, output_type, chain, rewriter);
+  if (failed(output_desc_or)) {
+    return output_desc_or;
+  }
+
+  // Create filter descriptor.
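+  // The legacy filter descriptor takes a cudnnTensorFormat_t plus dimensions
+  // (no explicit strides), so map the StreamExecutor filter layout to the
+  // corresponding cuDNN tensor format.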
+  cudnnTensorFormat_t tensor_format;
+  switch (config.filter_descriptor.layout()) {
+    case se::dnn::FilterLayout::kOutputInputYX:
+      tensor_format = CUDNN_TENSOR_NCHW;
+      break;
+    case se::dnn::FilterLayout::kOutputYXInput:
+      tensor_format = CUDNN_TENSOR_NHWC;
+      break;
+    case se::dnn::FilterLayout::kOutputInputYX4:
+    case se::dnn::FilterLayout::kOutputInputYX32: {
+      tensor_format = CUDNN_TENSOR_NCHW_VECT_C;
+      break;
+    }
+    default:
+      return rewriter.notifyMatchFailure(op, "Unexpected filter layout.");
+  }
+  std::vector<int> dims(2 + config.filter_descriptor.ndims());
+  dims[0] = config.filter_descriptor.output_feature_map_count();
+  dims[1] = config.filter_descriptor.input_feature_map_count();
+  absl::Span<const int64_t> spatial_dims =
+      config.filter_descriptor.input_filter_dims();
+  std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
+  auto loc = op.getLoc();
+  Value filter_desc = rewriter.create<tfrt::gpu::DnnCreateFilterDescriptorOp>(
+      loc, filter_type, tensor_format, rewriter.getI32ArrayAttr(dims), chain);
+
+  // Create convolution descriptor.
+  mlir::Type mlir_compute_type = GetMemRefElementType(GetInput(op));
+  cudnnDataType_t compute_type = MlirTypeToCudnnDataType(mlir_compute_type);
+  cudnnConvolutionMode_t conv_mode =
+      config.conv_desc.convolution_not_crosscorr() ? CUDNN_CONVOLUTION
+                                                   : CUDNN_CROSS_CORRELATION;
+  const auto& convolution_descriptor = config.conv_desc;
+  // cuDNN requires arrays of ints.
+  std::vector<int> strides = CheckedNarrowing(convolution_descriptor.strides());
+  std::vector<int> padding = CheckedNarrowing(convolution_descriptor.padding());
+  std::vector<int> dilations =
+      CheckedNarrowing(convolution_descriptor.dilations());
+  cudnnMathType_t math_type = config.algorithm.tensor_ops_enabled()
+                                  ? CUDNN_TENSOR_OP_MATH
+                                  : CUDNN_FMA_MATH;
+  Value conv_desc =
+      rewriter.create<tfrt::gpu::DnnCreateConvolutionDescriptorOp>(
+          loc, compute_type, conv_mode, math_type,
+          rewriter.getI32ArrayAttr(padding), rewriter.getI32ArrayAttr(strides),
+          rewriter.getI32ArrayAttr(dilations), chain);
+
+  // Create convolution op.
+  mlir::Type mlir_scale_type =
+      mlir_compute_type.isF64() ? rewriter.getF64Type() : rewriter.getF32Type();
+  Value context = rewriter.create<tfrt::gpu::StreamGetContextOp>(loc, stream);
+  Value handle = rewriter.create<tfrt::gpu::DnnCreateOp>(loc, context);
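+  // The legacy path runs the algorithm that XLA already picked (stored in the
+  // GpuConvConfig) rather than building a cuDNN frontend execution plan.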
+  Value algorithm = rewriter.create<tfrt::compiler::ConstantUI64Op>(
+      loc, config.algorithm.algo_id());
+  auto out_chain_or =
+      CreateLegacyConvOp(op, adaptor, mlir_scale_type, handle, stream,
+                         *input_desc_or, *output_desc_or, filter_desc,
+                         conv_desc, algorithm, chain, config, rewriter);
+  if (succeeded(out_chain_or)) {
+    rewriter.eraseOp(op);
+  }
+  return out_chain_or;
+}
 
 template <class ConvolutionOpType>
 struct ConvolutionRewritePattern
@@ -472,6 +742,35 @@
           op, "TensorFlow padding alignment is not supported.");
     }
 
+    bool use_legacy_conv = [&] {
+      if (!config.algorithm.is_cudnn_frontend()) return true;
+
+      auto print_reason = [&](const char* reason) {
+        LOG(ERROR)
+            << "Disabling cuDNN frontend for the following convolution:\n"
+            << "  input: " << config.input_descriptor.ToString() << "\n"
+            << "  filter: " << config.filter_descriptor.ToString() << "\n"
+            << "  conv: " << config.conv_desc.ToString() << "\n... because "
+            << reason;
+      };
+
+      if (config.input_descriptor.layout() ==
+          se::dnn::DataLayout::kBatchDepthYX32)
+        // Current versions of the frontend API lack support for Tx32.
+        return print_reason("Tx32 convolutions are unsupported."), true;
+
+      if (CUDNN_VERSION < 8100)
+        // cuDNN frontend support became sufficiently stable to use in 8.1.
+        return print_reason("the cuDNN version does not support it."), true;
+
+      return false;
+    }();
+
+    if (use_legacy_conv) {
+      return LegacyConvolutionRewritePattern(op, adaptor, chain, stream, config,
+                                             rewriter);
+    }
+
     cudnnBackendDescriptorType_t backend_type;
     switch (descriptor.kind) {
       case xla::gpu::CudnnConvKind::kForward:
@@ -507,8 +806,8 @@
         rewriter.getFunctionType(handle_type, conv_plan_type));
     symbol_table.insert(conv_plan_func);
     rewriter.setInsertionPointToEnd(conv_plan_func.addEntryBlock());
-    Value conv_plan = CreateBuildConvOp(op, conv_plan_func.getArgument(0),
-                                        config, backend_type, rewriter);
+    Value conv_plan = CreateBuildConvPlanOp(op, conv_plan_func.getArgument(0),
+                                            config, backend_type, rewriter);
     rewriter.create<tfrt::compiler::ReturnOp>(loc, conv_plan);
 
     // Once-initialize the convolution plan.
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
index 7b510e2..8fcaa2e 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.cc
@@ -65,7 +65,7 @@
       }
       break;
     case se::dnn::DataLayout::kBatchDepthYX32:
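+      // CUDNN_DATA_INT8x32 packs 32 int8 values per element, so the scalar
+      // element type is 8 bits wide (the x32 refers to the vectorization).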
-      if (type.isSignlessInteger(/*width=*/32)) {
+      if (type.isSignlessInteger(/*width=*/8)) {
         return CUDNN_DATA_INT8x32;
       }
       break;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 172b5a2..f780a12 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -2559,6 +2559,8 @@
         "//tensorflow/compiler/xla/service/gpu/tests:kernel_launch_test",
         "//tensorflow/compiler/xla/service/gpu/tests:mlir_gemm_test",
         "//tensorflow/compiler/xla/tests:cholesky_test_gpu",
+        "//tensorflow/compiler/xla/tests:convolution_test_cudnn_frontend_disabled_gpu",
+        "//tensorflow/compiler/xla/tests:convolution_test_gpu",
         "//tensorflow/compiler/xla/tests:dot_operation_test_gpu",
         "//tensorflow/compiler/xla/tests:multioutput_fusion_test_gpu",
         "//tensorflow/compiler/xla/tests:scatter_test_gpu",
@@ -2622,6 +2624,8 @@
         "//tensorflow/compiler/xla/service/gpu/tests:sorting_test",
         "//tensorflow/compiler/xla/service/gpu/tests:tree_reduction_rewriter_test",
         "//tensorflow/compiler/xla/tests:cholesky_test_gpu",
+        "//tensorflow/compiler/xla/tests:convolution_test_cudnn_frontend_disabled_gpu",
+        "//tensorflow/compiler/xla/tests:convolution_test_gpu",
         "//tensorflow/compiler/xla/tests:multioutput_fusion_test_gpu",
         "//tensorflow/compiler/xla/tests:scatter_test_gpu",
         "//tensorflow/compiler/xla/tests:select_and_scatter_test_gpu",
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c0c0597..ad3e420 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1273,6 +1273,24 @@
 )
 
 xla_test(
+    name = "convolution_test_cudnn_frontend_disabled",
+    timeout = "long",
+    srcs = ["convolution_test.cc"],
+    backend_args = {"gpu": ["--xla_gpu_enable_cudnn_frontend=false"]},
+    backends = ["gpu"],
+    shard_count = 50,
+    tags = [
+        "no_rocm",
+        "nozapfhahn",
+        "optonly",
+    ],
+    deps = CONVOLUTION_TEST_DEPS + [
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_test(
     name = "convolution_variants_test",
     timeout = "long",
     srcs = ["convolution_variants_test.cc"],