Roll forward cl/420217183 with fix.

The original change had a bug where it would incorrectly fuse relu even if the
conv had a user that *wasn't* relu.  See the new DontFuseReluIfMultipleUses
test.  This was caught by the failing TAP test.  I also added similar tests for
the other fusions in this patch, though they didn't exhibit this bug.

I also added compiler fuel to this patch.

PiperOrigin-RevId: 420515732
Change-Id: I58d3a33f912b703f2afc5f87281ee04925629f4c
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 9569653..c9707cc 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -2204,16 +2204,11 @@
     deps = [
         ":backend_configs_cc",
         ":cublas_cudnn",
-        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:hlo_creation_utils",
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/compiler/xla/service:pattern_matcher",
-        "//tensorflow/core/platform:errors",
-        "//tensorflow/core/platform:statusor",
         "//tensorflow/core/platform:stream_executor_no_cuda",
-        "//tensorflow/stream_executor:dnn_proto_cc",
     ],
 )
 
@@ -2228,20 +2223,11 @@
         "requires-gpu-sm70",
     ],
     deps = [
-        ":backend_configs_cc",
         ":cublas_cudnn",
         ":cudnn_fused_conv_rewriter",
-        ":gpu_conv_rewriter",
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:test_helpers",
-        "//tensorflow/compiler/xla/service:algebraic_simplifier",
-        "//tensorflow/compiler/xla/service:hlo_constant_folding",
         "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
-        "//tensorflow/compiler/xla/service:pattern_matcher",
-        "//tensorflow/compiler/xla/service:pattern_matcher_gmock",
-        "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
index c3d31dd..7a1e9dd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
@@ -15,806 +15,479 @@
 
 #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
 
-#include <functional>
-#include <string>
-
 #include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/stream_executor/dnn.pb.h"
 
 namespace xla {
 namespace gpu {
 namespace {
 
-namespace m = match;
+// Describes matched patterns:
+//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+//   for floating point types or
+//   max(0, alpha1 * conv<float>(int8_x, int8_w) + alpha2 *
+//   * side_input + broadcast(bias));
+//   for int8_t.
+// Where side_input has the shape of output buffer, and bias is a 1D array with
+// the dimension of number of output features.
+struct ConvWithRelu {
+  HloInstruction* maximum;
+  HloCustomCallInstruction* conv;
+  HloInstruction* bias;
+  HloInstruction* side_input;
+  HloConstantInstruction* alpha_conv;
+  HloConstantInstruction* alpha_side_input;
+};
 
-// If VLOG is on and `instr` matches `filter_pattern`, prints out why it doesn't
-// match `log_pattern`.  You can use this to explain "near-hits".
-template <typename FilterPattern, typename LogPattern>
-void VlogIfFailureToMatch(HloInstruction* instr, const LogPattern& log_pattern,
-                          absl::string_view desc,
-                          const FilterPattern& filter_pattern) {
-  if (!VLOG_IS_ON(3) || !Match(instr, filter_pattern)) {
-    return;
+// The pattern we want to match:
+//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+//   or
+//   max(0, alpha1 * conv<float>(int8_x, int8_w) + alpha2 *
+//   * side_input + broadcast(bias));
+// With its variants involving commute/reassociation of adds, multiplies, and
+// max, and omission of alpha1, side_input, alpha2, or bias.
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+  using match::Add;
+  using match::AddAnyOrder;
+  using match::AnyOf;
+  using match::Broadcast;
+  using match::ConstantScalar;
+  using match::GetTupleElement;
+  using match::Maximum;
+  using match::MultiplyAnyOrder;
+  using match::Op;
+
+  HloInstruction* relu_input;
+
+  // Match max(0, relu_input).
+  auto zero_pattern = Broadcast(ConstantScalar(0));
+  if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+      !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+    return absl::nullopt;
   }
-  std::stringstream os;
-  if (!Match(instr, log_pattern, {/*capture=*/false, /*explain_os=*/&os})) {
-    VLOG(3) << "Failed to match " << desc << ":\n" << os.str();
-  }
-}
+  HloInstruction* conv_instr = nullptr;
+  HloInstruction* alpha_conv_instr = nullptr;
+  HloInstruction* alpha_side_input_instr = nullptr;
+  HloInstruction* bias_broadcast_instr = nullptr;
+  HloInstruction* bias = nullptr;
+  HloInstruction* side_input = nullptr;
 
-bool IsConvCustomCall(const HloInstruction* instr) {
-  return instr->opcode() == HloOpcode::kCustomCall &&
-         (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
-          instr->custom_call_target() ==
-              kCudnnConvBiasActivationForwardCallTarget);
-}
+  // These nodes will not be in the returned value, but we need to check them
+  // for single use.
+  HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+                 *mul1 = nullptr, *mul2 = nullptr;
 
-// Can instr be converted to type `dst_ty` without losing any precision?  For
-// our purposes, this is true if:
-//
-//  - instr already has type dst_ty, or
-//  - instr is convert<wider type>(op_with_dst_ty), or
-//  - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
-//    get back exactly the original value, or
-//  - instr is a broadcast, reshape, or transpose of one of the above.
-bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
-                               PrimitiveType dst_ty) {
-  if (instr->shape().element_type() == dst_ty) {
-    return true;
-  }
-
-  if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
-    // Check that the convert from dst_ty to instr->element_type() doesn't lose
-    // precision.  Otherwise, this convert is not lossless.
-    return primitive_util::CastPreservesValues(dst_ty,
-                                               instr->shape().element_type());
-  }
-
-  if (instr->opcode() == HloOpcode::kConstant) {
-    if (!instr->shape().IsArray()) {
-      return false;
-    }
-    // Check if instr's literal roundtrips to ty and back to its original type
-    // without modification.
-    PrimitiveType orig_ty = instr->shape().element_type();
-
-    // The only reason Convert() should fail is if we don't support converting
-    // from x to y, which indeed means it's not losslessly-convertible.
-    StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
-    if (!converted1.ok()) {
-      return false;
-    }
-    StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
-    if (!converted2.ok()) {
-      return false;
-    }
-
-    return instr->literal() == *converted2;
-  }
-
-  if (instr->opcode() == HloOpcode::kBroadcast ||
-      instr->opcode() == HloOpcode::kReshape ||
-      instr->opcode() == HloOpcode::kTranspose) {
-    return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
-  }
-
-  return false;
-}
-
-// Helpers suitable for use in m::Op().WithPredicate(...).
-bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
-  return IsLosslesslyConvertibleTo(instr, S8);
-}
-bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
-  return IsLosslesslyConvertibleTo(instr, F16);
-}
-
-// If `conv` is a vanilla forward conv, transforms it into a
-// conv-bias-activation.  If it's already a conv-bias-activation, does nothing.
-//
-// If `conv` is anything else, returns an error.
-StatusOr<HloInstruction*> EnsureIsConvBiasActivation(HloInstruction* conv) {
-  CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
-
-  if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
-    return conv;
-  }
-
-  if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
-    HloComputation* comp = conv->parent();
-
-    const Shape& shape = conv->shape().tuple_shapes(0);
-    int64_t num_output_features = shape.dimensions(
-        conv->convolution_dimension_numbers().output_feature_dimension());
-
-    // bias for integer convs is always f32, see
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    PrimitiveType bias_ty;
-    if (primitive_util::IsIntegralType(shape.element_type())) {
-      bias_ty = F32;
-    } else {
-      bias_ty = shape.element_type();
-    }
-    auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
-
-    absl::InlinedVector<HloInstruction*, 3> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands.push_back(bias);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
-    new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
-    comp->parent()->SetAndUniquifyInstrName(new_conv,
-                                            "cudnn-conv-bias-activation");
-    return new_conv;
-  }
-
-  return FailedPrecondition("Unsupported conv: %s", conv->ToString());
-}
-
-// convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
-// gte(custom-call<float>(int8_x, int8_w))
-StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    auto pattern =
-        m::Convert(
-            m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
-                .WithElementType(S32))
-            .WithElementType(F32);
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(F32);
-    HloInstruction* new_conv =
-        comp->AddInstruction(conv->CloneWithNewShape(new_shape));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
-
-    changed = true;
-  }
-
-  return changed;
-}
-
-// alpha * gte(custom-call(...)) ->
-// gte(custom-call(..., backend_config={alpha})).
-StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    HloInstruction* gte = nullptr;
-    HloInstruction* alpha = nullptr;
-    auto pattern = m::MultiplyAnyOrder(
-        m::GetTupleElement(&gte, m::Op(&conv).WithPredicate(IsConvCustomCall),
-                           0)
-            .WithOneUse(),
-        m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-
-    // alpha is f32 except for f64 convs, where it's f64.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
-    if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
-      continue;
-    }
-
-    TF_ASSIGN_OR_RETURN(auto config,
-                        conv->backend_config<CudnnConvBackendConfig>());
-    if (config.conv_result_scale() != 1) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvAlpha: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    // StreamExecutor doesn't support the alpha parameter on non-bias-activation
-    // convs, so we have to upgrade `conv`.
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-
-    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
-    config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
-
-    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
-    TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
-
-    changed = true;
-  }
-  return changed;
-}
-
-StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
-  bool changed = false;
-  for (auto instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv = nullptr;
-    HloInstruction* gte = nullptr;
-    HloInstruction* addend = nullptr;
-    auto pattern = m::AddAnyOrder(
-        m::GetTupleElement(
-            &gte, m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse(), 0)
-            .WithOneUse(),
-        m::Op(&addend));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-
-    // If it's a vanilla forward conv, upgrade it to a bias-activation conv.  We
-    // only want to do this if the fusion will succeed, but we're guaranteed
-    // that it will, because the only reason we'll bail at this point is if
-    // !can_accept_bias && !can_accept_side_input, and our shiny new
-    // bias-activation conv will be able to accept both.
-    if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
-      TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    }
-
-    // Does `conv` already have a (nonzero) bias?  Does it already have a
-    // side_input?
-    bool can_accept_bias =
-        Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
-    bool can_accept_side_input = conv->operand_count() < 4;
-
-    // The addend can be fused as a bias if
-    //  - it is 1D broadcasted in the output feature dimension, and
-    //  - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
-    //    convs, and f16 for f16 convs)
-    PrimitiveType conv_ty = gte->shape().element_type();
-    PrimitiveType bias_ty = conv_ty == F16 ? F16 : F32;
-    bool addend_may_be_rank1_bias =
-        addend->opcode() == HloOpcode::kBroadcast &&
-        addend->dimensions().size() == 1 &&
-        addend->dimensions(0) ==
-            conv->convolution_dimension_numbers().output_feature_dimension() &&
-        IsLosslesslyConvertibleTo(addend, bias_ty);
-
-    bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
-                                    addend->dimensions().empty() &&
-                                    IsLosslesslyConvertibleTo(addend, bias_ty);
-
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    TF_ASSIGN_OR_RETURN(auto config,
-                        conv->backend_config<CudnnConvBackendConfig>());
-    if (can_accept_bias && addend_may_be_rank1_bias) {
-      new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty);
-    } else if (can_accept_bias && addend_may_be_rank0_bias) {
-      new_operands[2] = MakeBroadcastHlo(
-          MakeConvertToHlo(addend->mutable_operand(0), bias_ty),
-          /*broadcast_dimensions=*/{},
-          /*result_shape_bounds=*/
-          {gte->shape().dimensions(conv->convolution_dimension_numbers()
-                                       .output_feature_dimension())});
-    } else if (can_accept_side_input) {
-      CHECK_EQ(new_operands.size(), 3);
-      new_operands.push_back(addend);
-      config.set_side_input_scale(1);
-    } else {
-      // Can't fuse; this op already has a bias and a side-input.
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-// custom-call(..., alpha * side_input) ->
-// custom-call(..., side_input, backend_config={alpha}).
-//
-// We also have to support the more complicated case of
-//
-//   custom-call(..., reshape(side_input * alpha)) -->
-//   custom-call(..., reshape(side_input), backend_config={alpha}),
-//
-// where `reshape` can be an arbitrary chain of reshapes+transposes.  This idiom
-// is created by the ReshapeMover pass.
-StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* conv;
-    HloInstruction* side_input;
-    auto pattern = m::Op(&conv)
-                       .WithPredicate(IsConvCustomCall)
-                       .WithOperand(3, m::Op(&side_input));
-    if (!Match(instr, pattern)) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(auto config,
-                        conv->backend_config<CudnnConvBackendConfig>());
-    if (config.side_input_scale() != 1) {
-      continue;
-    }
-
-    // Given side_input, pattern match the following (working from bottom up).
+  const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+  const auto conv_pattern = [&] {
+    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
+    auto conv_pattern = GetTupleElement(
+        &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+    return AnyOf<HloInstruction>(
+        MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+  }();
+  const auto side_input_pattern = [&] {
+    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
+    // If bias is already matched, match arbitrary additional input as side
+    // input. Note this may force a cheap operation (e.g. broadcast) to be
+    // materialized into a large buffer, as large as the output buffer.
     //
-    // before_reshape = multiply(base, broadcast(alpha))
-    // side_input = chain_of_reshapes_and_transposes(before_reshape)
+    // TODO(timshen): If in practice there are significant false positives, we
+    // should fix it.
+    auto side_input_pattern = Op(&side_input);
+    return AnyOf<HloInstruction>(
+        MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+        side_input_pattern);
+  }();
+
+  {
+    // Try to match any of the following form of add, in any association:
+    //   addends[0]
+    //   addends[0] + addends[1]
+    //   addends[0] + addends[1] + addends[2]
     //
-    // where alpha is a scalar constant.
-    //
-    // alpha is f32 except for f64 convs, where it's f64.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    HloInstruction* before_reshape = side_input;
-    while (before_reshape->opcode() == HloOpcode::kReshape ||
-           before_reshape->opcode() == HloOpcode::kTranspose) {
-      before_reshape = before_reshape->mutable_operand(0);
-    }
-
-    PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
-    PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
-    HloInstruction* base;
-    HloInstruction* alpha;
-    if (!Match(
-            before_reshape,
-            m::MultiplyAnyOrder(
-                m::Op(&base),
-                m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
-                    [&](const HloInstruction* instr) {
-                      return IsLosslesslyConvertibleTo(instr, alpha_ty);
-                    }))))) {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    // Rewrite conv's operand 3 to
-    //
-    //   chain_of_reshapes_and_transposes(before_reshape).
-    //
-    // and store alpha in the conv's backend config.
-    //
-    // We're going to do something bad here: We aren't going to check that the
-    // chain of reshapes/transposes has one use, so we're potentially
-    // duplicating all these instructions (once with alpha and once without).
-    //
-    // This is justified because
-    //
-    //  - duplicating reshapes/transposes shouldn't be "that bad" -- these
-    //    instructions can usually be fused, and
-    //
-    //  - *not* fusing alpha can be catastrophic.  For s8->s8 convolutions, the
-    //    side-input must be s8.  But the product side_input * alpha is f32, so
-    //    we can only see that side-input is s8 if we fuse alpha. IOW not fusing
-    //    alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
-    //    slower than some extra transposes.
-
-    // Recursively clone chain_of_reshapes_and_transposes until we get to
-    // `before_reshape`, at which point we skip the multiply(base, alpha) and
-    // just return base.
-    std::function<HloInstruction*(const HloInstruction*)> clone =
-        [&](const HloInstruction* instr) {
-          if (instr == before_reshape) {
-            return base;
-          }
-          CHECK(instr->opcode() == HloOpcode::kReshape ||
-                instr->opcode() == HloOpcode::kTranspose)
-              << "Must be reshape or transpose: " << instr->ToString();
-          return comp->AddInstruction(instr->CloneWithNewOperands(
-              instr->shape(), {clone(instr->operand(0))}));
-        };
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands[3] = clone(side_input);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(conv->shape(), new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-
-    TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
-    config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
-    TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
-
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
-    changed = true;
-  }
-  return changed;
-}
-
-StatusOr<bool> FuseRelu(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte;
-    HloInstruction* conv;
-    if (!Match(
-            instr,
-            m::MaximumAnyOrder(
-                m::Broadcast(m::ConstantEffectiveScalar(0)),
-                m::GetTupleElement(
-                    &gte,
-                    m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse())
-                    .WithOneUse()))) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
-                        conv->backend_config<CudnnConvBackendConfig>());
-    if (config.activation_mode() != se::dnn::kNone) {
-      continue;
-    }
-
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseRelu: ", conv->ToString());
-        })) {
-      continue;
-    }
-    TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-    config.set_activation_mode(se::dnn::kRelu);
-    TF_RETURN_IF_ERROR(conv->set_backend_config(config));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
-    changed = true;
-  }
-  return changed;
-}
-
-StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte = nullptr;
-    HloInstruction* conv = nullptr;
-
-    auto f32_convertible_to_f16_pattern =
-        m::Op().WithElementType(F32).WithPredicate(
-            IsLosslesslyConvertibleToF16);
-    auto pattern =
-        m::Convert(
-            m::GetTupleElement(
-                &gte,
-                m::Op(&conv)
-                    .WithPredicate(IsConvCustomCall)
-                    .WithOperand(0, f32_convertible_to_f16_pattern)
-                    .WithOperand(1, f32_convertible_to_f16_pattern)
-                    .WithOperandIfPresent(2, f32_convertible_to_f16_pattern)
-                    .WithOperandIfPresent(3, f32_convertible_to_f16_pattern),
-                0)
-                .WithOneUse())
-            .WithElementType(F16);
-    if (!Match(instr, pattern)) {
-      VlogIfFailureToMatch(
-          instr, pattern, "fp16 conv",
-          m::Op().WithOperand(
-              0, m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))));
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertToF16: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    VLOG(2) << "Matched fp16 conv: " << conv->ToString();
-
-    // In fp16 convs, all operands, including `bias`, must be fp16.  This is
-    // different from int8 convs, where the bias is fp32.  See table of
-    // supported datatypes at
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    absl::InlinedVector<HloInstruction*, 4> new_operands;
-    for (HloInstruction* operand : conv->operands()) {
-      new_operands.push_back(MakeConvertToHlo(operand, F16));
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(new_shape, new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
-  bool changed = false;
-  for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
-    HloInstruction* gte = nullptr;
-    HloInstruction* conv = nullptr;
-
-    auto conv_pattern =
-        m::Op(&conv)
-            .WithPredicate(IsConvCustomCall)
-            .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
-            .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
-
-    // int8 -> int8 conv
-    auto s8_pattern =
-        m::Convert(
-            m::Clamp(
-                m::Broadcast(m::ConstantEffectiveScalar(-128)),
-                m::GetTupleElement(
-                    &gte,
-                    conv_pattern.WithOperandIfPresent(
-                        3, m::Op().WithPredicate(IsLosslesslyConvertibleToS8)),
-                    0)
-                    .WithOneUse(),
-                m::Broadcast(m::ConstantEffectiveScalar(127))))
-            .WithElementType(S8);
-
-    // int8 -> fp32 conv
-    auto f32_pattern = m::GetTupleElement(&gte,
-                                          conv_pattern.WithOperandIfPresent(
-                                              3, m::Op().WithElementType(F32)),
-                                          0)
-                           .WithElementType(F32);
-
-    VlogIfFailureToMatch(
-        instr, s8_pattern, "s8->s8 conv",
-        m::Convert(m::Clamp(m::Op(),  //
-                            m::GetTupleElement(
-                                m::Op().WithPredicate(IsConvCustomCall)),  //
-                            m::Op()))
-            .WithElementType(S8));
-
-    VlogIfFailureToMatch(
-        instr, f32_pattern, "s8->f32 conv",
-        m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
-            .WithElementType(F32));
-
-    PrimitiveType conv_output_ty;
-    if (Match(instr, s8_pattern)) {
-      conv_output_ty = S8;
-    } else if (Match(instr, f32_pattern)) {
-      conv_output_ty = F32;
-    } else {
-      continue;
-    }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
-          return absl::StrCat("FuseConvertToS8: ", conv->ToString());
-        })) {
-      continue;
-    }
-
-    absl::InlinedVector<HloInstruction*, 4> new_operands(
-        conv->operands().begin(), conv->operands().end());
-    new_operands[0] = MakeConvertToHlo(new_operands[0], S8);
-    new_operands[1] = MakeConvertToHlo(new_operands[1], S8);
-    // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn.  See
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-    if (new_operands.size() >= 4) {
-      // side-input always matches conv output type.  We checked in the patterns
-      // above that it's losslessly-convertible to this type.
-      new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty);
-    }
-
-    Shape new_shape = conv->shape();
-    new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
-
-    HloInstruction* new_conv = comp->AddInstruction(
-        conv->CloneWithNewOperands(new_shape, new_operands));
-    comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
-                        MakeGetTupleElementHlo(new_conv, 0));
-    TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
-    changed = true;
-  }
-  return changed;
-}
-
-Status CheckNoS32Convs(HloComputation* comp) {
-  std::vector<HloInstruction*> bad_convs;
-  for (HloInstruction* instr : comp->instructions()) {
-    if (!IsConvCustomCall(instr)) {
-      continue;
-    }
-    if (instr->shape().tuple_shapes(0).element_type() == S32 ||
-        instr->operand(0)->shape().element_type() == S32 ||
-        instr->operand(1)->shape().element_type() == S32 ||
-        (instr->operand_count() >= 4 &&
-         instr->operand(3)->shape().element_type() == S32)) {
-      bad_convs.push_back(instr);
-    }
-  }
-
-  if (bad_convs.empty()) {
-    return Status::OK();
-  }
-
-  return Unimplemented(
-      R"(
-Can't lower one or more integer convolutions to idioms supported by CuDNN.
-
-CuDNN integer convolutions must have:
-
-  - s8 input and filter,
-  - f32 bias (if present),
-  - s8 or f32 output, and
-  - s8 side_input (if present) if output is s8.
-
-For each of the unsupported convs below, we weren't able to lower one of the
-operands or the output to the appropriate type.
-
-See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
-
-https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
-https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-
-Unsupported convs:
-%s
-
-******* Full HLO module *******
-%s
-)",
-      absl::StrJoin(bad_convs, "\n",
-                    [](std::string* out, HloInstruction* instr) {
-                      absl::StrAppend(out, " - ", instr->ToString());
-                    }),
-      comp->parent()->ToString());
-}
-
-void VlogStats(HloModule* module) {
-  if (!VLOG_IS_ON(1)) {
-    return;
-  }
-
-  VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
-  absl::flat_hash_map<std::string, int> stats;
-  for (HloComputation* comp : module->MakeNonfusionComputations()) {
-    for (HloInstruction* instr : comp->instructions()) {
-      if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
-        continue;
-      }
-
-      VLOG(3) << instr->ToString();
-
-      if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
-        stats["01 non-fused forward convs"]++;
-      } else if (instr->custom_call_target() ==
-                 kCudnnConvBiasActivationForwardCallTarget) {
-        stats["02 fused forward convs"]++;
-      }
-
-      PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
-      PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
-      if (conv_in_ty == F32) {
-        stats["10 f32 convs"]++;
-      } else if (conv_in_ty == F16) {
-        stats["11 f16 convs"]++;
-      } else if (conv_in_ty == S8) {
-        if (conv_out_ty == S8) {
-          stats["12 s8->s8 convs"]++;
-        } else if (conv_out_ty == F32) {
-          stats["13 s8->f32 convs"]++;
+    // Then try to match each addend with one of the three patterns: bias, conv,
+    // or side_input. Notice that side_input matching must go last, as it
+    // also matches a conv or a bias.
+    HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+    auto add3_pattern = [&] {
+      auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+      return AnyOf<HloInstruction>(
+          AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+          Op(&addends[0]));
+    }();
+    CHECK(Match(relu_input, add3_pattern));
+    for (auto addend : addends) {
+      if (addend) {
+        if (bias == nullptr && Match(addend, bias_pattern)) {
+          CHECK(bias);
+        } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+          CHECK(conv_instr);
+        } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+          CHECK(side_input);
         } else {
-          LOG(ERROR) << "Unexpected conv: " << instr->ToString();
+          return absl::nullopt;
         }
       }
-
-      if (instr->operand_count() > 2) {
-        stats["20 convs with bias"]++;
-        if (Match(instr->operand(2),
-                  m::Broadcast(m::ConstantEffectiveScalar(0)))) {
-          stats["21 convs with 0 bias"]++;
-        }
-      }
-      if (instr->operand_count() > 3) {
-        stats["22 convs with side-input"]++;
-      }
-
-      auto config = instr->backend_config<CudnnConvBackendConfig>();
-      if (!config.ok()) {
-        LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
-        continue;
-      }
-
-      if (config->conv_result_scale() != 1) {
-        stats["30 convs with result scale"]++;
-      }
-      if (config->side_input_scale() != 0 && config->side_input_scale() != 1) {
-        stats["31 convs with side-input scale"]++;
-      }
-      stats[absl::StrCat(
-          "32 convs with activation mode ",
-          se::dnn::ActivationMode_Name(config->activation_mode()))]++;
     }
   }
 
-  std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
-                                                        stats.end());
-  absl::c_sort(stats_sorted);
-  for (const auto& kv : stats_sorted) {
-    VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
-                                  absl::string_view(kv.first).substr(3));
+  if (conv_instr == nullptr) {
+    return absl::nullopt;
   }
+
+  for (HloInstruction* instr :
+       {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+    if (instr && instr->user_count() > 1) {
+      return absl::nullopt;
+    }
+  }
+
+  auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+  auto bias_broadcast =
+      CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+  if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+    return absl::nullopt;
+  }
+
+  // In order to map to cudnnConvolutionBiasActivationForward for int8_t, the
+  // convolution output is float, i.e. conv<float>(int8_x, int8_w)
+  if (conv->operand(0)->shape().element_type() == xla::S8) {
+    if (conv->shape().tuple_shapes(0).element_type() != xla::F32) {
+      return absl::nullopt;
+    }
+  }
+
+  if (bias_broadcast) {
+    // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+    if (bias_broadcast_instr->dimensions().size() != 1) {
+      return absl::nullopt;
+    }
+    if (bias_broadcast_instr->dimensions(0) !=
+        conv->convolution_dimension_numbers().output_feature_dimension()) {
+      return absl::nullopt;
+    }
+  }
+
+  return ConvWithRelu{
+      instr,
+      conv,
+      bias,
+      side_input,
+      CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+      CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
 }
 
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+    ConvWithRelu match) {
+  auto conv = match.conv;
+
+  HloComputation* computation = conv->parent();
+
+  const auto get_alpha_value =
+      [](HloConstantInstruction* instr) -> StatusOr<double> {
+    TF_ASSIGN_OR_RETURN(
+        auto alpha,
+        Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+    return alpha.GetFirstElement<double>();
+  };
+
+  double alpha_conv = 1;
+  if (match.alpha_conv) {
+    TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+  }
+
+  double alpha_side_input;
+  if (match.side_input) {
+    if (match.alpha_side_input) {
+      TF_ASSIGN_OR_RETURN(alpha_side_input,
+                          get_alpha_value(match.alpha_side_input));
+    } else {
+      alpha_side_input = 1;
+    }
+  } else {
+    CHECK(match.alpha_side_input == nullptr);
+    alpha_side_input = 0;
+  }
+
+  auto bias = match.bias;
+  if (!bias) {
+    PrimitiveType conv_output_type =
+        conv->shape().tuple_shapes(0).element_type();
+    auto zero = computation->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::Zero(conv_output_type)));
+
+    int64_t num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+        conv->convolution_dimension_numbers().output_feature_dimension());
+    bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+        ShapeUtil::MakeShapeWithDescendingLayout(conv_output_type,
+                                                 {num_output_feature}),
+        zero, {}));
+  }
+
+  CHECK(bias);
+  std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+                                       conv->mutable_operand(1), bias};
+  if (match.side_input) {
+    args.push_back(match.side_input);
+  }
+  auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+      conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+  new_conv->set_feature_group_count(conv->feature_group_count());
+  new_conv->set_window(conv->window());
+  new_conv->set_convolution_dimension_numbers(
+      conv->convolution_dimension_numbers());
+  new_conv->set_metadata(conv->metadata());
+  computation->parent()->SetAndUniquifyInstrName(new_conv,
+                                                 "cudnn-conv-bias-activation");
+  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+                      conv->backend_config<CudnnConvBackendConfig>());
+  config.set_activation_mode(
+      static_cast<int64_t>(se::dnn::ActivationMode::kRelu));
+  config.set_conv_result_scale(alpha_conv);
+  config.set_side_input_scale(alpha_side_input);
+  TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+          << new_conv->ToString();
+  return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+                                               new_conv, 0);
+}
+
+// Fuse bias/scaling/ReLU with convolution custom call with floating point
+// output
+StatusOr<bool> RunFuseBiasSideActivation(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    std::vector<ConvWithRelu> matches;
+    int num_forward_convs = 0;
+    for (auto instr : computation->instructions()) {
+      auto match = FindConvWithRelu(instr);
+      if (match.has_value()) {
+        matches.push_back(*match);
+      }
+      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+        if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+          num_forward_convs++;
+        }
+      }
+    }
+    VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+            << " out of " << num_forward_convs << " forward convs.";
+    std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+        replacements;
+    for (const ConvWithRelu& match : matches) {
+      TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+      replacements.push_back({match.maximum, std::move(new_instr)});
+      changed = true;
+    }
+    for (auto& replacement : replacements) {
+      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+          replacement.first, std::move(replacement.second)));
+    }
+  }
+  return changed;
+}
+
+// Describes a matched pattern:
+// convert_or_clamp(get_tuple_element(custom_call(x,w, ...)));
+// where the custom_call targets CuDNN convolution (either pure convolution or
+// fused convolution).
+struct ConvWithConvertOrClamp {
+  HloInstruction* convert_or_clamp;
+  HloInstruction* gte;
+  HloCustomCallInstruction* conv;
+};
+
+// The pattern we want to match:
+//   convert<int8_t>(clamp(broadcast(-128),
+//   (get_tuple_element(custom_call(int8_x, int8_w, ...)), broadcast(127));
+absl::optional<ConvWithConvertOrClamp> FindConvWithClampAndConvertToInt8(
+    HloInstruction* instr) {
+  using match::Broadcast;
+  using match::Clamp;
+  using match::Convert;
+  using match::GetTupleElement;
+  using match::Op;
+
+  HloInstruction* gte = nullptr;
+  HloInstruction* conv_instr = nullptr;
+  auto lower_pattern = Broadcast(match::ConstantScalar(-128));
+  auto upper_pattern = Broadcast(match::ConstantScalar(127));
+  auto pattern = Convert(
+      Clamp(lower_pattern,
+            GetTupleElement(
+                &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0),
+            upper_pattern));
+
+  if (Match(instr, pattern)) {
+    if (conv_instr->operand(0)->shape().element_type() == xla::S8 &&
+        instr->shape().element_type() == xla::S8) {
+      HloCustomCallInstruction* conv =
+          CastOrNull<HloCustomCallInstruction>(conv_instr);
+      return ConvWithConvertOrClamp{instr, gte, conv};
+    }
+  }
+  return absl::nullopt;
+}
+
+// A help function to rewrite convert_or_clamp_or_other<new_type>(gte(conv()))
+// to gte<new_type>(conv<new_type>()).  It bypasses convert_or_clamp_or_other
+// and set the output data type on gte and conv.
+Status RewriteForConvertOrClampImpl(ConvWithConvertOrClamp match) {
+  auto conv = match.conv;
+  auto gte = match.gte;
+  auto convert_or_clamp = match.convert_or_clamp;
+
+  // Change type on conv and gte
+  auto convert_out_type = convert_or_clamp->shape().element_type();
+  conv->mutable_shape()->mutable_tuple_shapes(0)->set_element_type(
+      convert_out_type);
+  gte->mutable_shape()->set_element_type(convert_out_type);
+
+  // Remove clamp/convert and so on and just keep
+  // get_tuple_element(custom_call(x,w, ...))
+  TF_RETURN_IF_ERROR(convert_or_clamp->ReplaceAllUsesWithDifferentShape(gte));
+  TF_RETURN_IF_ERROR(
+      conv->parent()->RemoveInstructionAndUnusedOperands(convert_or_clamp));
+  return Status::OK();
+}
+
+Status RewriteForFinalOutput(ConvWithConvertOrClamp match) {
+  // When the matched clamp has a single user, which is convert<int8_t>, we
+  // will absorb it, if
+  // 1. the side_input matches a convert<float>(int8_side_input), or
+  // 2. there is no side input
+  const auto is_one_to_one_X_to_Y_cast = [](const HloInstruction* instr,
+                                            PrimitiveType X,
+                                            PrimitiveType Y) -> bool {
+    return (instr->opcode() == HloOpcode::kConvert &&
+            instr->shape().element_type() == Y && instr->operand_count() == 1 &&
+            instr->operand(0)->user_count() == 1 &&
+            instr->operand(0)->shape().element_type() == X);
+  };
+
+  if (match.conv->operand_count() < 4) {
+    // Conv input #3 (zero based) is side_input, after x, w, and bias.
+    // Side input doesn't exist in this case.
+    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+  } else if (is_one_to_one_X_to_Y_cast(match.conv->operand(3), S8, F32)) {
+    // If side_input has a convert_float_to_int8, absorb it as well.
+    auto side_converter = match.conv->mutable_operand(3);
+    TF_RETURN_IF_ERROR(side_converter->ReplaceAllUsesWithDifferentShape(
+        side_converter->mutable_operand(0)));
+    TF_RETURN_IF_ERROR(
+        side_converter->parent()->RemoveInstructionAndUnusedOperands(
+            side_converter));
+
+    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+  }
+  return Status::OK();
+}
+
+// Fuse the clamp/convert pattern with the int8_t convolution custom call
+// (either pure or fused) for int8_t output
+StatusOr<bool> RunFuseClamp(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    std::vector<ConvWithConvertOrClamp> matches;
+    for (auto instr : computation->instructions()) {
+      auto match = FindConvWithClampAndConvertToInt8(instr);
+      if (match.has_value()) {
+        matches.push_back(*match);
+      }
+    }
+    for (const ConvWithConvertOrClamp& match : matches) {
+      TF_RETURN_IF_ERROR(RewriteForFinalOutput(match));
+      changed = true;
+    }
+
+    // Report error for any convolution still having int32_t output.
+    // Although int32_t output convolution will trigger other sanity check
+    // errors later, we want to give specific error message here.
+    for (auto instr : computation->instructions()) {
+      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+        if ((call->custom_call_target() == kCudnnConvForwardCallTarget ||
+             call->custom_call_target() ==
+                 kCudnnConvBiasActivationForwardCallTarget) &&
+            call->shape().tuple_shapes(0).element_type() == xla::S32) {
+          return Unimplemented(
+              "Integer convolutions for CuDNN must have float or int8_t "
+              "output.  "
+              "Use convert to cast output to float or the following pattern to "
+              "int8_t: "
+              "clamp(broadcast(-128), conv(int8_x, int8_w, ...), "
+              "broadcast(127)).");
+        }
+      }
+    }
+  }
+  return changed;
+}
+
+// The pattern we want to match:
+//   convert<float>(get_tuple_element<int32_t>(custom_call()));
+absl::optional<ConvWithConvertOrClamp> FindConvWithConvertToFloat(
+    HloInstruction* instr) {
+  using match::Convert;
+  using match::GetTupleElement;
+  using match::Op;
+
+  HloInstruction* gte = nullptr;
+  HloInstruction* conv_instr = nullptr;
+  auto pattern =
+      Convert(GetTupleElement(
+                  &gte,
+                  Op(&conv_instr)
+                      .WithOpcode(HloOpcode::kCustomCall)
+                      .WithCustomCallTarget(kCudnnConvForwardCallTarget),
+                  0)
+                  .WithShape(match::Shape().WithElementType(xla::S32)))
+          .WithShape(match::Shape().WithElementType(xla::F32));
+  if (Match(instr, pattern)) {
+    HloCustomCallInstruction* conv =
+        CastOrNull<HloCustomCallInstruction>(conv_instr);
+    return ConvWithConvertOrClamp{instr, gte, conv};
+  }
+  return absl::nullopt;
+}
+
+// Transform
+// convert<float>(GetTupleElement<int32_t>(custom_call<int32_t>(int8_x,
+// int8_w))) to GetTupleElement<float>(custom_call<int32_t>(int8_x, int8_w))
+StatusOr<bool> RunFuseConvertToFloat(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    std::vector<ConvWithConvertOrClamp> matches;
+    for (auto instr : computation->instructions()) {
+      auto match = FindConvWithConvertToFloat(instr);
+      if (match.has_value()) {
+        matches.push_back(*match);
+      }
+    }
+
+    for (const ConvWithConvertOrClamp& match : matches) {
+      TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+      changed = true;
+    }
+  }
+  return changed;
+}
 }  // namespace
 
 StatusOr<bool> CudnnFusedConvRewriter::Run(HloModule* module) {
-  bool any_changed = false;
+  TF_ASSIGN_OR_RETURN(bool fused_for_convert_to_float,
+                      RunFuseConvertToFloat(module));
 
-  for (HloComputation* comp : module->MakeNonfusionComputations()) {
-    // Fuse "inside out" starting with the operations closest to the conv.
-    bool changed = false;
+  TF_ASSIGN_OR_RETURN(bool fused_for_bias, RunFuseBiasSideActivation(module));
 
-    TF_ASSIGN_OR_RETURN(changed, FuseConvertToFloat(comp));
-    any_changed |= changed;
+  TF_ASSIGN_OR_RETURN(bool fused_for_clamp, RunFuseClamp(module));
 
-    TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
-    any_changed |= changed;
-
-    // s8 convs' bias and side-input appear before conversion to s8.
-    //
-    // Run FuseBiasOrSideInput twice, so we get both the bias and the side
-    // input, if both are present.
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
-    any_changed |= changed;
-
-    // Relu might appear before or after convert-to-f16/s8, so we check in both
-    // cases.
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
-    any_changed |= changed;
-
-    // f16 convs' bias+side-input can appear before or after conversion to f16.
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
-    any_changed |= changed;
-    TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
-    any_changed |= changed;
-
-    TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
-    any_changed |= changed;
-
-    // Check that we don't have any convs outputing s32.  cudnn does not support
-    // these.  They should have been transformed to int8->int8 or int8->float
-    // above.
-    TF_RETURN_IF_ERROR(CheckNoS32Convs(comp));
-  }
-
-  VlogStats(module);
-
-  return any_changed;
+  return fused_for_convert_to_float || fused_for_bias || fused_for_clamp;
 }
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
index b03cc2b..2a43ff3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
@@ -22,76 +22,40 @@
 namespace xla {
 namespace gpu {
 
-// Rewrites custom-calls targeting cudnnConvolutionForward to
-// cudnnConvolutionBiasActivationForward by fusing operations following forward
-// convolution.  This transform must run after cudnn_conv_rewriter.
+// Rewrite the custom call targeting cudnnConvolutionForward to
+// cudnnConvolutionBiasActivationForward by fusing applicable point-wise
+// operations following forward convolution.  This transform must run after
+// cudnn_conv_rewriter.
+// It is straightforward for floating point convolutions:
+// transforming
+//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias))
+// to
+//   cudnnConvolutionBiasActivationForward(x, w, bias, alpha1, alpha2, side)
 //
-// Semantics of underlying cudnn ops:
-//
-// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-// https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-//
-// ## Floating-point convs
-//
-// A "complete" fused floating-point conv has the form
-//
-//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
-//
-// which we fuse to
-//
-//   cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
-//
-// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
-// get a fused convolution.  alpha1/2 must be broadcasts of scalar constants.
-//
-// f16 convs accumulate in f32.  We represent this in HLO as an f32 convolution
-// whose inputs can be converted to f16 without loss of precision and whose
-// output is immediately converted to f16.  A fused f16 conv must follow one of
-// the following idioms.
-//
-//   1. convert_f16(conv_f32(x_f32, w_f32)) +
-//      side_input_f16 + broadcast(bias_f16)
-//
-//   2. convert_f16(conv_f32(x_f32, w_f32) +
-//                  side_input_f32 + broadcast(bias_f32))
-//
-// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
-// which one it does, and we deem them "close enough".)
-//
-// The foo_f32 HLOs must all be losslessly-convertible to f16.  Some valid
-// examples:
-//
-//   - foo_f32 = convert_f32(foo_f16)
-//   - foo_f32 = an f32 constant whose values all fit within f16
-//   - foo_f32 = broadcast/transpose/reshape(one of the above)
-//
-// If you have a relu, it can appear before or after the convert_f16.
-//
-// Note that here `bias` must be losslessly-convertible to f16; this is
-// different than for s8 convolutions, where bias is f32.
-//
-// ## Integer convs
-//
-// In pure HLO, a "complete" integer conv is spelled as one of the following
-// `result`s.
-//
-//   base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
-//          alpha2_f32 * side_input +
-//          bias_f32
-//
-//   result_f32        = max(result_f32, 0)
-//   result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
-//   result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
-//
-// The foo_s32 HLOs must be losslessly-convertible to s8.  If the `result_s8`
-// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
-// otherwise, it should be losslessly-convertible to f32.
-//
-// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
-// skip the convert_f32 on conv.
-//
-// If you have an integer convolution that doesn't fit one of these idioms, this
-// pass returns an error -- cudnn will not be able to run it.
+// Integer convolution requires additional patterns to match CuDNN semantics:
+//   #1 from
+//   cast<int8_t>(clamp<-128, 127>(conv(int8_x, int8_w)))
+//   to
+//   cudnnConvolutionForward<int8_t>(int8_x, int8_w)
+// or #2 from
+//   cast<float>(conv(int8_x, int8_w))
+//   to
+//   cudnnConvolutionForward<float>(int8_x, int8_w)
+// or #3 from
+//   cast<int8_t>(clamp<-128, 127>(max(0, alpha1 *
+//                           cast<float>(conv(int8_x, int8_w)) +
+//                           alpha2 * cast<float>(int8_side) +
+//                           broadcast(bias)))
+//   to
+//   cudnnConvolutionBiasActivationForward<int8_t>(int8_x, int8_w, bias, alpha1,
+//   alpha2, int8_side)
+// or #4 from
+//   max(0, alpha1 * cast<float>(conv(int8_x, int8_w)) +
+//          alpha2 * float_side + broadcast(bias))
+//   to
+//   cudnnConvolutionBiasActivationForward<float>(int8_x, int8_w, bias, alpha1,
+//   alpha2, float_side)
+
 class CudnnFusedConvRewriter : public HloModulePass {
  public:
   absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
index 3a328ca..1c7c1f3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
@@ -15,26 +15,14 @@
 
 #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
 
-#include <string>
-
 #include "absl/strings/str_replace.h"
-#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
-#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/filecheck.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
@@ -43,13 +31,9 @@
 
 // TODO(b/210165681): The tests in this file are fragile to HLO op names.
 
-namespace m = match;
-
 using ::testing::HasSubstr;
 using ::testing::Not;
 
-class CudnnFusedConvRewriterHloTest : public HloTestBase {};
-
 class CudnnFusedConvRewriterTest : public GpuCodegenTest {
  protected:
   std::string GetOptimizedHlo(absl::string_view hlo_string) {
@@ -63,17 +47,16 @@
     debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
     config.set_debug_options(debug_opts);
 
-    auto result = backend().compiler()->RunHloPasses(
-        ParseAndReturnVerifiedModule(hlo_string, config).ConsumeValueOrDie(),
-        backend().default_stream_executor(), backend().memory_allocator());
-    if (!result.status().ok()) {
-      TF_EXPECT_OK(result.status())
-          << "HLO compilation failed: " << result.status();
-      return "";
-    }
     HloPrintOptions print_opts;
     print_opts.set_print_operand_shape(false);
-    return (*result)->ToString(print_opts);
+    return backend()
+        .compiler()
+        ->RunHloPasses(ParseAndReturnVerifiedModule(hlo_string, config)
+                           .ConsumeValueOrDie(),
+                       backend().default_stream_executor(),
+                       backend().memory_allocator())
+        .ConsumeValueOrDie()
+        ->ToString(print_opts);
   }
 
   void TestMatchWithAllTypes(absl::string_view hlo_string) {
@@ -82,8 +65,7 @@
           absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
       std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
       EXPECT_THAT(optimized_hlo_string,
-                  Not(HasSubstr(kCudnnConvForwardCallTarget)))
-          << optimized_hlo_string;
+                  Not(HasSubstr(kCudnnConvForwardCallTarget)));
       EXPECT_THAT(optimized_hlo_string,
                   HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
       EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
@@ -113,7 +95,6 @@
       const std::string hlo_with_new_type =
           absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
       std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
-      SCOPED_TRACE(optimized_hlo_string);
       EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
       EXPECT_THAT(optimized_hlo_string,
                   Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
@@ -331,6 +312,27 @@
     })");
 }
 
+TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) {
+  // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+  TestNotMatchWithAllTypes(R"(
+    HloModule Test
+
+    ENTRY Test {
+      zero = TYPE[] constant(0)
+      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+      input = TYPE[1,3,3,64] parameter(0)
+      filter = TYPE[3,3,64,64] parameter(1)
+      side_input1 = TYPE[1,3,3,64] parameter(2)
+      side_input2 = TYPE[1,3,3,64] parameter(3)
+
+      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+      add1 = TYPE[1,3,3,64] add(conv, side_input2)
+      add2 = TYPE[1,3,3,64] add(add1, side_input1)
+      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+    })");
+}
+
 TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
   const char* kHloString = R"(
     HloModule Test
@@ -410,7 +412,8 @@
 
       clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
 
-      ROOT convert = s8[1,32,9,9] convert(clamp)
+      convert = s8[1,32,9,9] convert(clamp)
+      ROOT relu = s8[1,32,9,9] maximum(zeros, convert)
     })",
       // post_hlo
       R"(
@@ -419,8 +422,12 @@
       )");
 }
 
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
-  const std::string module_str = R"(
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToFloat) {
+  // convert<float>(conv<int32_t>(convert<int32_t>(int8_x),
+  // convert<int32_t>(int8_w)));
+  TestClamp(
+      // pre_hlo
+      R"(
     HloModule Test
 
     ENTRY Test {
@@ -430,1004 +437,15 @@
       inputs32 = s32[1,17,9,9] convert(input)
       filters32 = s32[3,3,17,32] convert(filter)
 
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
+      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
 
       ROOT convert = f32[1,32,9,9] convert(conv)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::GetTupleElement(
-                             m::CustomCall(kCudnnConvForwardCallTarget), 0)
-                             .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_f32 = f32[1,32,9,9] convert(conv)
-      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                                             add(add(conv_f32, bias), side_input),
-                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
-                                   m::Parameter(0), m::Parameter(1),
-                                   m::Parameter(2), m::Parameter(3)),
-                     0)
-                     .WithShape(S8, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
-                                           conv,
-                                           s32[1,32,9,9] broadcast(s32[] constant(127))))
-      zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
-      ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(
-                  &conv, kCudnnConvBiasActivationForwardCallTarget,
-                  m::Parameter(0),  //
-                  m::Parameter(1),  //
-                  m::Broadcast(
-                      m::ConstantEffectiveScalar(0).WithElementType(F32))),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] parameter(3)
-
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_f32 = f32[1,32,9,9] convert(conv)
-      sum1 = add(conv_f32, bias_broadcast)
-      ROOT sum2 = add(sum1, side_input_f32)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  EXPECT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
-                                   m::Parameter(0), m::Parameter(1),
-                                   m::Parameter(2), m::Parameter(3)),
-                     0)
-                     .WithShape(F32, {1, 32, 9, 9})));
-}
-
-// The ReshapeMover pass changes
-//   reshape(side_input) * alpha -->
-//   reshape(side_input * alpha).
-// Make sure we can pattern-match this.
-TEST_F(CudnnFusedConvRewriterTest, Int8SideInputWithScaleAndReshape) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
-      side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
-
-      conv = s32[1,32,9,9] convolution(input, filter),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                                             add(add(f32[1,32,9,9] convert(conv), bias), side_input),
-                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  HloPassFix<HloPassPipeline> simplify("simplify");
-  simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
-  simplify.AddPass<ReshapeMover>();
-  TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(
-                  &conv, kCudnnConvBiasActivationForwardCallTarget,
-                  m::Parameter(0),  //
-                  m::Parameter(1),  //
-                  m::Parameter(2),  //
-                  m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.side_input_scale(), 0.25);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      input = s8[1,17,9,9] parameter(0)
-      filter = s8[3,3,17,32] parameter(1)
-      inputs32 = s32[1,17,9,9] convert(input)
-      filters32 = s32[3,3,17,32] convert(filter)
-      alpha = f32[] constant(42)
-      alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
-
-      conv = s32[1,32,9,9] convolution(inputs32, filters32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      convert = f32[1,32,9,9] convert(conv)
-      ROOT root = multiply(convert, alpha_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv = nullptr;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.conv_result_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      zero = f32[] constant(0)
-      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias_broadcast)
-      ROOT relu = maximum(sum, zeros)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, bias)
-      relu = maximum(sum, zeros)
-      not_relu = minimum(sum, zeros)
-      ROOT root = tuple(relu, not_relu)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::MaximumAnyOrder(
-              m::Broadcast(m::ConstantEffectiveScalar(0)),
-              m::GetTupleElement(
-                  m::CustomCall(
-                      &conv, kCudnnConvBiasActivationForwardCallTarget,
-                      m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-                  0)
-                  .WithShape(F32, {1, 32, 9, 9})),
-          m::Minimum())));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(multiply(alpha, conv), bias)
-      ROOT root = tuple(conv, sum)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
-                         m::MultiplyAnyOrder(
-                             m::Broadcast(m::Parameter(3)),
-                             m::GetTupleElement(m::CustomCall(&conv2), 0))))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv1->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, add(conv, bias))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
-                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv1->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, add(conv, side_input))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::AddAnyOrder(m::Parameter(2),
-                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv1->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.conv_result_scale(), 1);
-  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
-      filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(
-                  m::GetTupleElement(m::CustomCall(&conv1), 0),
-                  m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
-  EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
-      filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      conv_s8 = s8[1,32,9,9] convert(clamp(
-                  f32[1,32,9,9] broadcast(f32[] constant(-128)),
-                  conv,
-                  f32[1,32,9,9] broadcast(f32[] constant(127))))
-      ROOT root = tuple(conv, conv_s8)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv1;
-  const HloInstruction* conv2;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::Tuple(
-          m::GetTupleElement(m::CustomCall(&conv1), 0),
-          m::Convert(m::Clamp(m::Op(),  //
-                              m::GetTupleElement(m::CustomCall(&conv2), 0),
-                              m::Op())))));
-  EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, bias_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, side_input)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32}),
-                            m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      side_input = f32[1,32,9,9] parameter(2)
-      side_input_scale = f32[] constant(42)
-      side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
-      side_input_product = multiply(side_input, side_input_scale_broadcast)
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, side_input_product)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32}),
-                            m::Parameter(2)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      side_input = f32[1,32,9,9] parameter(3)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input)
-      ROOT sum2 = add(sum, bias_broadcast)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
-                            m::Parameter(3)),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] parameter(0)
-      filters = f32[3,3,17,32] parameter(1)
-      bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
-      conv = f32[1,32,9,9] convolution(inputs, filters),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      ROOT root = add(conv, bias)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1),
-                            m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
-              0)
-              .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,17,9,9] parameter(0)
-      filters = f16[3,3,17,32] parameter(1)
-      bias = f16[32] parameter(2)
-      side_input = f16[1,32,9,9] parameter(3)
-
-      inputs_f32 = f32[1,17,9,9] convert(inputs)
-      filters_f32 = f32[3,3,17,32] convert(filters)
-      bias_f32 = f32[32] convert(bias)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] convert(side_input)
-      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
-                            m::Parameter(3)),
-              0)
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-// We should be able to lower this to an f16 convolution even though the
-// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
-TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
-      filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
-      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
-      side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
-
-      conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
-                 window={size=3x3 pad=1_1x1_1},
-                 dim_labels=bf01_01io->bf01
-      conv_f16 = f16[1,32,9,9] convert(conv_f32)
-      ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(
-                         &conv, kCudnnConvBiasActivationForwardCallTarget,
-                         m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
-                             .WithElementType(F16),
-                         m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
-                             .WithElementType(F16),
-                         m::Parameter(2), m::Reshape(m::Parameter(3))),
-                     0)
-                     .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,17,9,9] parameter(0)
-      filters = f16[3,3,17,32] parameter(1)
-      bias = f32[32] parameter(2)
-      side_input = f16[1,32,9,9] parameter(3)
-
-      inputs_f32 = f32[1,17,9,9] convert(inputs)
-      filters_f32 = f32[3,3,17,32] convert(filters)
-      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
-      side_input_f32 = f32[1,32,9,9] convert(side_input)
-      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
-               window={size=3x3 pad=1_1x1_1},
-               dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  // fp16 convs only support fp16 biases.  Because bias is fp32, it doesn't get
-  // fused in, and we get an fp32 conv.
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Convert(m::GetTupleElement(
-                         m::CustomCall(
-                             &conv, kCudnnConvBiasActivationForwardCallTarget,
-                             m::Convert(m::Parameter(0)).WithElementType(F32),
-                             m::Convert(m::Parameter(1)).WithElementType(F32),
-                             m::Parameter(2),
-                             m::Convert(m::Parameter(3)).WithElementType(F32)),
-                         0))
-              .WithShape(F16, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,2,2,2] parameter(0)
-      filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
-      bias = f16[2] parameter(1)
-      bias_f32 = f32[2] convert(bias)
-      side_input_f32 = f32[1,2,2,2] constant({{
-        {{0.5, 0.25}, {0.125, 0.0625}},
-        {{0.5, 0.25}, {0.125, 0.0625}}
-      }})
-
-      inputs_f32 = f32[1,2,2,2] convert(inputs)
-      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
-      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
-               window={size=1x1}, dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph, and fold
-  // convert back into constants.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-  HloConstantFolding constant_folding;
-  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(m::GetTupleElement(
-                     m::CustomCall(
-                         &conv, kCudnnConvBiasActivationForwardCallTarget,
-                         m::Parameter(0), m::Constant().WithElementType(F16),
-                         m::Parameter(1), m::Constant().WithElementType(F16)),
-                     0)
-                     .WithShape(F16, {1, 2, 2, 2})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
-  const std::string module_str = R"(
-    HloModule Test
-
-    ENTRY Test {
-      inputs = f16[1,2,2,2] parameter(0)
-      filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
-      bias = f16[2] parameter(1)
-      bias_f32 = f32[2] convert(bias)
-      side_input_f32 = f32[1,2,2,2] constant({{
-        {{0.1, 0.2}, {0.3, 0.4}},
-        {{0.5, 0.6}, {0.7, 0.8}}
-      }})
-
-      inputs_f32 = f32[1,2,2,2] convert(inputs)
-      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
-      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
-               window={size=1x1}, dim_labels=bf01_01io->bf01
-      sum = add(conv, side_input_f32)
-      sum2 = add(sum, bias_broadcast)
-      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
-    })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph, and fold
-  // convert back into constants.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-  HloConstantFolding constant_folding;
-  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  // This doesn't get transformed into an f16 conv because the filters param is
-  // not losslessly expressible as f16.
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::Convert(m::GetTupleElement(
-                         m::CustomCall(
-                             &conv, kCudnnConvBiasActivationForwardCallTarget,
-                             m::Convert(m::Parameter(0)).WithElementType(F32),
-                             m::Constant().WithElementType(F32),
-                             m::Convert(m::Parameter(1)).WithElementType(F32),
-                             m::Constant().WithElementType(F32)),
-                         0)
-                         .WithShape(F32, {1, 2, 2, 2}))
-              .WithElementType(F16)));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterTest, FuseReluBeforeConvert) {
-  const std::string module_str = R"(
-  HloModule Test
-
-  ENTRY Test {
-    input = s8[1,17,9,9] parameter(0)
-    filter = s8[3,3,17,32] parameter(1)
-    inputs32 = s32[1,17,9,9] convert(input)
-    filters32 = s32[3,3,17,32] convert(filter)
-
-    conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
-    zero = s32[] constant(0)
-    zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
-    relu = maximum(conv, zeros)
-
-    lower = s32[] constant(-128)
-    lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
-    upper = s32[] constant(127)
-    uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
-    clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
-
-    ROOT convert = s8[1,32,9,9] convert(clamp)
-  })";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
-  GpuConvRewriter rewriter;
-  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
-  CudnnFusedConvRewriter fuser;
-  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
-  // Simplify new `convert`'s that may be added to the graph.
-  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
-  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
-  SCOPED_TRACE(m->ToString());
-  const HloInstruction* conv;
-  ASSERT_THAT(
-      m->entry_computation()->root_instruction(),
-      GmockMatch(
-          m::GetTupleElement(
-              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
-                            m::Parameter(0),  //
-                            m::Parameter(1),  //
-                            m::Broadcast(m::ConstantEffectiveScalar(0))
-                                .WithShape(F32, {32})),
-              0)
-              .WithShape(S8, {1, 32, 9, 9})));
-  TF_ASSERT_OK_AND_ASSIGN(auto config,
-                          conv->backend_config<CudnnConvBackendConfig>());
-  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+    })",
+      // post_hlo
+      R"(
+      ; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> f32[1,32,9,9] {
+      ; CHECK:  %cudnn-conv{{(\.[0-9])?}} = (f32[1,32,9,9]{1,3,2,0}, u8[{{[0-9]+}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config=
+      )");
 }
 
 TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
@@ -1462,7 +480,7 @@
 
       clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
 
-      ROOT convert = s8[1,3,3,64] convert(clamp)
+      ROOT convert = s8[1,3,3,64] convert(clamp)      
     })",
       // post_hlo
       R"(
@@ -1497,7 +515,7 @@
       convfloat = f32[1,3,3,64] convert(conv)
       broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
       add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
-      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
+      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)     
     })",
       // post_hlo
       R"(
@@ -1550,18 +568,13 @@
 
       clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
 
-      ROOT convert = s8[1,3,3,64] convert(clamp)
+      ROOT convert = s8[1,3,3,64] convert(clamp) 
     })",
       // post_hlo
       R"(
       ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
-      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9])?}} =
-      ; CHECK-SAME: (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0})
-      ; CHECK-SAME: custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input),
-      ; CHECK-SAME: window={size=3x3 pad=1_1x1_1},
-      ; CHECK-SAME: dim_labels=b01f_01io->b01f,
-      ; CHECK-SAME: custom_call_target="__cudnn$convBiasActivationForward",
-      ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
+      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
+      ; CHECK-NEXT:  ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
       )");
 }
 
@@ -1609,7 +622,7 @@
 
       clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
 
-      ROOT convert = s8[1,3,3,64] convert(clamp)
+      ROOT convert = s8[1,3,3,64] convert(clamp) 
     })",
       //  post_hlo
       R"(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
index 8e7939f..44c9923 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
@@ -686,6 +686,45 @@
 
   // If all else fails, try a forward convolution.
   if (conv_matchers::CanImplementAsGpuForwardConv(conv)) {
+    if (primitive_util::IsIntegralType(
+            conv->operand(0)->shape().element_type())) {
+      // In addition to replacing a convolution instruction with
+      // a custom call, integer convolutions must have this pattern to match
+      // CuDNN semantics:
+      // conv<InputT=int32_t, ResultT=int32_t>(
+      //   convert<int32_t>(int8_x), convert<int32_t>(int8_y))
+      // We transform it to:
+      // custom_call<int32_t>(int8_x, int8_y, target=cudnnConvolutionForward)
+      //
+      // We will error out, if the pattern is not found for integer
+      // convolution.
+      const auto is_int8_to_int32_cast =
+          [](const HloInstruction* instr) -> bool {
+        return (instr->opcode() == HloOpcode::kConvert &&
+                instr->operand(0)->shape().element_type() == S8 &&
+                instr->shape().element_type() == S32);
+      };
+      HloInstruction* input_convert = conv->mutable_operand(0);
+      HloInstruction* kernel_convert = conv->mutable_operand(1);
+      if (conv->shape().element_type() != S32 ||
+          !is_int8_to_int32_cast(input_convert) ||
+          !is_int8_to_int32_cast(kernel_convert)) {
+        return Unimplemented(
+            "Integer convolutions for CuDNN must have this pattern: "
+            "conv<InputT=int32_t, ResultT=int32_t>(convert<int32_t>(int8_x), "
+            "convert<int32_t>(int8_y))");
+      }
+      // Bypass the convert<int32_t> for both inputs.
+      TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
+          0, input_convert->mutable_operand(0)));
+      TF_RETURN_IF_ERROR(
+          conv->parent()->RemoveInstructionAndUnusedOperands(input_convert));
+      TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
+          1, kernel_convert->mutable_operand(0)));
+      TF_RETURN_IF_ERROR(
+          conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert));
+    }
+
     if (conv->batch_group_count() > 1) {
       conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
     }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
index 4025e6e..08b037c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
@@ -641,6 +641,23 @@
                           0));
 }
 
+// Check that a forward convolution instruction with int8_t inputs is not
+// allowed
+TEST_F(GpuConvRewriterTest, TestForwardInt8Convolution) {
+  const std::string module_str = absl::StrFormat(R"(
+    HloModule Test
+
+    ENTRY Test {
+      input = s8[1,2,3,3] parameter(0)
+      filter = s8[3,3,2,5] parameter(1)
+
+      ROOT conv = s8[1,5,3,3] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+    })");
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  ASSERT_FALSE(GpuConvRewriter().Run(m.get()).ok());
+}
+
 TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) {
   const std::string module_str = absl::StrFormat(R"(
     HloModule Test