Roll forward cl/420217183 with fix.
The original change had a bug where it would incorrectly fuse relu even if the
conv had a user that *wasn't* relu. See the new DontFuseReluIfMultipleUses
test. This was caught by the failing TAP test. I also added similar tests for
the other fusions in this patch, though they didn't exhibit this bug.
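For reference, the shape of the bad case is roughly this (an illustrative HLO
sketch, not the literal test):

  conv  = f32[1,3,3,64] convolution(input, filter), ...
  relu  = f32[1,3,3,64] maximum(zeros, conv)
  other = f32[1,3,3,64] add(conv, conv)
  ROOT t = tuple(relu, other)

Because "other" consumes the pre-relu value, rewriting conv into a fused
conv+relu custom call would change what that second user sees, so the rewriter
must bail out here.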
I also added compiler fuel to this patch.
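(Fuel should be controllable in the usual way, e.g.
--xla_fuel=cudnn-fused-convolution-rewriter=N to stop rewriting after N
fusions; handy when bisecting a miscompile.)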
PiperOrigin-RevId: 420515732
Change-Id: I58d3a33f912b703f2afc5f87281ee04925629f4c
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 9569653..c9707cc 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -2204,16 +2204,11 @@
deps = [
":backend_configs_cc",
":cublas_cudnn",
- "//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:pattern_matcher",
- "//tensorflow/core/platform:errors",
- "//tensorflow/core/platform:statusor",
"//tensorflow/core/platform:stream_executor_no_cuda",
- "//tensorflow/stream_executor:dnn_proto_cc",
],
)
@@ -2228,20 +2223,11 @@
"requires-gpu-sm70",
],
deps = [
- ":backend_configs_cc",
":cublas_cudnn",
":cudnn_fused_conv_rewriter",
- ":gpu_conv_rewriter",
":ir_emission_utils",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/service:algebraic_simplifier",
- "//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
- "//tensorflow/compiler/xla/service:pattern_matcher",
- "//tensorflow/compiler/xla/service:pattern_matcher_gmock",
- "//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
index c3d31dd..7a1e9dd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
@@ -15,806 +15,479 @@
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include <functional>
-#include <string>
-
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/stream_executor/dnn.pb.h"
namespace xla {
namespace gpu {
namespace {
-namespace m = match;
+// Describes matched patterns:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// for floating point types, or
+// max(0, alpha1 * conv<float>(int8_x, int8_w) +
+// alpha2 * side_input + broadcast(bias));
+// for int8_t.
+// Here side_input has the shape of the output buffer, and bias is a 1D array
+// whose length is the number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
-// If VLOG is on and `instr` matches `filter_pattern`, prints out why it doesn't
-// match `log_pattern`. You can use this to explain "near-hits".
-template <typename FilterPattern, typename LogPattern>
-void VlogIfFailureToMatch(HloInstruction* instr, const LogPattern& log_pattern,
- absl::string_view desc,
- const FilterPattern& filter_pattern) {
- if (!VLOG_IS_ON(3) || !Match(instr, filter_pattern)) {
- return;
+// The pattern we want to match:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// or
+// max(0, alpha1 * conv<float>(int8_x, int8_w) +
+// alpha2 * side_input + broadcast(bias));
+// With its variants involving commute/reassociation of adds, multiplies, and
+// max, and omission of alpha1, side_input, alpha2, or bias.
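+//
+// As an illustrative sketch (not taken from a real module), a fully-populated
+// match looks like:
+//
+//   gte    = f32[1,3,3,64] get-tuple-element(conv_custom_call), index=0
+//   scaled = f32[1,3,3,64] multiply(gte, broadcast(alpha1))
+//   side   = f32[1,3,3,64] multiply(side_input, broadcast(alpha2))
+//   sum    = f32[1,3,3,64] add(add(scaled, side), broadcast(bias))
+//   relu   = f32[1,3,3,64] maximum(broadcast(zero), sum)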
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::ConstantScalar;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
}
- std::stringstream os;
- if (!Match(instr, log_pattern, {/*capture=*/false, /*explain_os=*/&os})) {
- VLOG(3) << "Failed to match " << desc << ":\n" << os.str();
- }
-}
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
-bool IsConvCustomCall(const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kCustomCall &&
- (instr->custom_call_target() == kCudnnConvForwardCallTarget ||
- instr->custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget);
-}
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
-// Can instr be converted to type `dst_ty` without losing any precision? For
-// our purposes, this is true if:
-//
-// - instr already has type dst_ty, or
-// - instr is convert<wider type>(op_with_dst_ty), or
-// - instr is a constant which we can convert orig_ty -> dst_ty -> orig_ty and
-// get back exactly the original value, or
-// - instr is a broadcast, reshape, or transpose of one of the above.
-bool IsLosslesslyConvertibleTo(const HloInstruction* instr,
- PrimitiveType dst_ty) {
- if (instr->shape().element_type() == dst_ty) {
- return true;
- }
-
- if (Match(instr, m::Convert(m::Op().WithElementType(dst_ty)))) {
- // Check that the convert from dst_ty to instr->element_type() doesn't lose
- // precision. Otherwise, this convert is not lossless.
- return primitive_util::CastPreservesValues(dst_ty,
- instr->shape().element_type());
- }
-
- if (instr->opcode() == HloOpcode::kConstant) {
- if (!instr->shape().IsArray()) {
- return false;
- }
- // Check if instr's literal roundtrips to ty and back to its original type
- // without modification.
- PrimitiveType orig_ty = instr->shape().element_type();
-
- // The only reason Convert() should fail is if we don't support converting
- // from x to y, which indeed means it's not losslessly-convertible.
- StatusOr<Literal> converted1 = instr->literal().Convert(dst_ty);
- if (!converted1.ok()) {
- return false;
- }
- StatusOr<Literal> converted2 = converted1->Convert(orig_ty);
- if (!converted2.ok()) {
- return false;
- }
-
- return instr->literal() == *converted2;
- }
-
- if (instr->opcode() == HloOpcode::kBroadcast ||
- instr->opcode() == HloOpcode::kReshape ||
- instr->opcode() == HloOpcode::kTranspose) {
- return IsLosslesslyConvertibleTo(instr->operand(0), dst_ty);
- }
-
- return false;
-}
-
-// Helpers suitable for use in m::Op().WithPredicate(...).
-bool IsLosslesslyConvertibleToS8(const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, S8);
-}
-bool IsLosslesslyConvertibleToF16(const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, F16);
-}
-
-// If `conv` is a vanilla forward conv, transforms it into a
-// conv-bias-activation. If it's already a conv-bias-activation, does nothing.
-//
-// If `conv` is anything else, returns an error.
-StatusOr<HloInstruction*> EnsureIsConvBiasActivation(HloInstruction* conv) {
- CHECK_EQ(conv->opcode(), HloOpcode::kCustomCall);
-
- if (conv->custom_call_target() == kCudnnConvBiasActivationForwardCallTarget) {
- return conv;
- }
-
- if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
- HloComputation* comp = conv->parent();
-
- const Shape& shape = conv->shape().tuple_shapes(0);
- int64_t num_output_features = shape.dimensions(
- conv->convolution_dimension_numbers().output_feature_dimension());
-
- // bias for integer convs is always f32, see
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- PrimitiveType bias_ty;
- if (primitive_util::IsIntegralType(shape.element_type())) {
- bias_ty = F32;
- } else {
- bias_ty = shape.element_type();
- }
- auto bias = BroadcastZeros(comp, bias_ty, {num_output_features});
-
- absl::InlinedVector<HloInstruction*, 3> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands.push_back(bias);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
- new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget);
- comp->parent()->SetAndUniquifyInstrName(new_conv,
- "cudnn-conv-bias-activation");
- return new_conv;
- }
-
- return FailedPrecondition("Unsupported conv: %s", conv->ToString());
-}
-
-// convert<float>(gte(custom-call<int32>(int8_x, int8_w))) ->
-// gte(custom-call<float>(int8_x, int8_w))
-StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- auto pattern =
- m::Convert(
- m::GetTupleElement(m::Op(&conv).WithPredicate(IsConvCustomCall), 0)
- .WithElementType(S32))
- .WithElementType(F32);
- if (!Match(instr, pattern)) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
- })) {
- continue;
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(F32);
- HloInstruction* new_conv =
- comp->AddInstruction(conv->CloneWithNewShape(new_shape));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_gte,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte));
-
- changed = true;
- }
-
- return changed;
-}
-
-// alpha * gte(custom-call(...)) ->
-// gte(custom-call(..., backend_config={alpha})).
-StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- HloInstruction* gte = nullptr;
- HloInstruction* alpha = nullptr;
- auto pattern = m::MultiplyAnyOrder(
- m::GetTupleElement(&gte, m::Op(&conv).WithPredicate(IsConvCustomCall),
- 0)
- .WithOneUse(),
- m::Broadcast(m::ConstantEffectiveScalar(&alpha)));
- if (!Match(instr, pattern)) {
- continue;
- }
-
- // alpha is f32 except for f64 convs, where it's f64. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- PrimitiveType alpha_ty = gte->shape().element_type() == F64 ? F64 : F32;
- if (!IsLosslesslyConvertibleTo(alpha, alpha_ty)) {
- continue;
- }
-
- TF_ASSIGN_OR_RETURN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- if (config.conv_result_scale() != 1) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvAlpha: ", conv->ToString());
- })) {
- continue;
- }
-
- // StreamExecutor doesn't support the alpha parameter on non-bias-activation
- // convs, so we have to upgrade `conv`.
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
-
- TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
- config.set_conv_result_scale(alpha_f64.GetFirstElement<double>());
-
- TF_RETURN_IF_ERROR(conv->set_backend_config(config));
- TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte));
-
- changed = true;
- }
- return changed;
-}
-
-StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
- bool changed = false;
- for (auto instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv = nullptr;
- HloInstruction* gte = nullptr;
- HloInstruction* addend = nullptr;
- auto pattern = m::AddAnyOrder(
- m::GetTupleElement(
- &gte, m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse(), 0)
- .WithOneUse(),
- m::Op(&addend));
- if (!Match(instr, pattern)) {
- continue;
- }
-
- // If it's a vanilla forward conv, upgrade it to a bias-activation conv. We
- // only want to do this if the fusion will succeed, but we're guaranteed
- // that it will, because the only reason we'll bail at this point is if
- // !can_accept_bias && !can_accept_side_input, and our shiny new
- // bias-activation conv will be able to accept both.
- if (conv->custom_call_target() == kCudnnConvForwardCallTarget) {
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- }
-
- // Does `conv` already have a (nonzero) bias? Does it already have a
- // side_input?
- bool can_accept_bias =
- Match(conv->operand(2), m::Broadcast(m::ConstantEffectiveScalar(0)));
- bool can_accept_side_input = conv->operand_count() < 4;
-
- // The addend can be fused as a bias if
- // - it is 1D broadcasted in the output feature dimension, and
- // - it is losslessly-convertible to the correct type (f32 for s8/f32/u32
- // convs, and f16 for f16 convs)
- PrimitiveType conv_ty = gte->shape().element_type();
- PrimitiveType bias_ty = conv_ty == F16 ? F16 : F32;
- bool addend_may_be_rank1_bias =
- addend->opcode() == HloOpcode::kBroadcast &&
- addend->dimensions().size() == 1 &&
- addend->dimensions(0) ==
- conv->convolution_dimension_numbers().output_feature_dimension() &&
- IsLosslesslyConvertibleTo(addend, bias_ty);
-
- bool addend_may_be_rank0_bias = addend->opcode() == HloOpcode::kBroadcast &&
- addend->dimensions().empty() &&
- IsLosslesslyConvertibleTo(addend, bias_ty);
-
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- TF_ASSIGN_OR_RETURN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- if (can_accept_bias && addend_may_be_rank1_bias) {
- new_operands[2] = MakeConvertToHlo(addend->mutable_operand(0), bias_ty);
- } else if (can_accept_bias && addend_may_be_rank0_bias) {
- new_operands[2] = MakeBroadcastHlo(
- MakeConvertToHlo(addend->mutable_operand(0), bias_ty),
- /*broadcast_dimensions=*/{},
- /*result_shape_bounds=*/
- {gte->shape().dimensions(conv->convolution_dimension_numbers()
- .output_feature_dimension())});
- } else if (can_accept_side_input) {
- CHECK_EQ(new_operands.size(), 3);
- new_operands.push_back(addend);
- config.set_side_input_scale(1);
- } else {
- // Can't fuse; this op already has a bias and a side-input.
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
- })) {
- continue;
- }
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-// custom-call(..., alpha * side_input) ->
-// custom-call(..., side_input, backend_config={alpha}).
-//
-// We also have to support the more complicated case of
-//
-// custom-call(..., reshape(side_input * alpha)) -->
-// custom-call(..., reshape(side_input), backend_config={alpha}),
-//
-// where `reshape` can be an arbitrary chain of reshapes+transposes. This idiom
-// is created by the ReshapeMover pass.
-StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* conv;
- HloInstruction* side_input;
- auto pattern = m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(3, m::Op(&side_input));
- if (!Match(instr, pattern)) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- if (config.side_input_scale() != 1) {
- continue;
- }
-
- // Given side_input, pattern match the following (working from bottom up).
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
+ // If bias is already matched, match an arbitrary additional input as the
+ // side input. Note this may force a cheap operation (e.g. a broadcast) to
+ // be materialized into a buffer as large as the output buffer.
//
- // before_reshape = multiply(base, broadcast(alpha))
- // side_input = chain_of_reshapes_and_transposes(before_reshape)
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+ // Try to match any of the following forms of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
//
- // where alpha is a scalar constant.
- //
- // alpha is f32 except for f64 convs, where it's f64. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- HloInstruction* before_reshape = side_input;
- while (before_reshape->opcode() == HloOpcode::kReshape ||
- before_reshape->opcode() == HloOpcode::kTranspose) {
- before_reshape = before_reshape->mutable_operand(0);
- }
-
- PrimitiveType conv_ty = conv->shape().tuple_shapes(0).element_type();
- PrimitiveType alpha_ty = conv_ty == F64 ? F64 : F32;
- HloInstruction* base;
- HloInstruction* alpha;
- if (!Match(
- before_reshape,
- m::MultiplyAnyOrder(
- m::Op(&base),
- m::Broadcast(m::ConstantEffectiveScalar(&alpha).WithPredicate(
- [&](const HloInstruction* instr) {
- return IsLosslesslyConvertibleTo(instr, alpha_ty);
- }))))) {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
- })) {
- continue;
- }
-
- // Rewrite conv's operand 3 to
- //
- // chain_of_reshapes_and_transposes(before_reshape).
- //
- // and store alpha in the conv's backend config.
- //
- // We're going to do something bad here: We aren't going to check that the
- // chain of reshapes/transposes has one use, so we're potentially
- // duplicating all these instructions (once with alpha and once without).
- //
- // This is justified because
- //
- // - duplicating reshapes/transposes shouldn't be "that bad" -- these
- // instructions can usually be fused, and
- //
- // - *not* fusing alpha can be catastrophic. For s8->s8 convolutions, the
- // side-input must be s8. But the product side_input * alpha is f32, so
- // we can only see that side-input is s8 if we fuse alpha. IOW not fusing
- // alpha means we'll run this s8->s8 conv as s8->f32, which is *much*
- // slower than some extra transposes.
-
- // Recursively clone chain_of_reshapes_and_transposes until we get to
- // `before_reshape`, at which point we skip the multiply(base, alpha) and
- // just return base.
- std::function<HloInstruction*(const HloInstruction*)> clone =
- [&](const HloInstruction* instr) {
- if (instr == before_reshape) {
- return base;
- }
- CHECK(instr->opcode() == HloOpcode::kReshape ||
- instr->opcode() == HloOpcode::kTranspose)
- << "Must be reshape or transpose: " << instr->ToString();
- return comp->AddInstruction(instr->CloneWithNewOperands(
- instr->shape(), {clone(instr->operand(0))}));
- };
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands[3] = clone(side_input);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
-
- TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64));
- config.set_side_input_scale(alpha_f64.GetFirstElement<double>());
- TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
-
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv));
- changed = true;
- }
- return changed;
-}
-
-StatusOr<bool> FuseRelu(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte;
- HloInstruction* conv;
- if (!Match(
- instr,
- m::MaximumAnyOrder(
- m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- &gte,
- m::Op(&conv).WithPredicate(IsConvCustomCall).WithOneUse())
- .WithOneUse()))) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
- conv->backend_config<CudnnConvBackendConfig>());
- if (config.activation_mode() != se::dnn::kNone) {
- continue;
- }
-
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseRelu: ", conv->ToString());
- })) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv));
- config.set_activation_mode(se::dnn::kRelu);
- TF_RETURN_IF_ERROR(conv->set_backend_config(config));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
- changed = true;
- }
- return changed;
-}
-
-StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte = nullptr;
- HloInstruction* conv = nullptr;
-
- auto f32_convertible_to_f16_pattern =
- m::Op().WithElementType(F32).WithPredicate(
- IsLosslesslyConvertibleToF16);
- auto pattern =
- m::Convert(
- m::GetTupleElement(
- &gte,
- m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(0, f32_convertible_to_f16_pattern)
- .WithOperand(1, f32_convertible_to_f16_pattern)
- .WithOperandIfPresent(2, f32_convertible_to_f16_pattern)
- .WithOperandIfPresent(3, f32_convertible_to_f16_pattern),
- 0)
- .WithOneUse())
- .WithElementType(F16);
- if (!Match(instr, pattern)) {
- VlogIfFailureToMatch(
- instr, pattern, "fp16 conv",
- m::Op().WithOperand(
- 0, m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))));
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertToF16: ", conv->ToString());
- })) {
- continue;
- }
-
- VLOG(2) << "Matched fp16 conv: " << conv->ToString();
-
- // In fp16 convs, all operands, including `bias`, must be fp16. This is
- // different from int8 convs, where the bias is fp32. See table of
- // supported datatypes at
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- absl::InlinedVector<HloInstruction*, 4> new_operands;
- for (HloInstruction* operand : conv->operands()) {
- new_operands.push_back(MakeConvertToHlo(operand, F16));
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(F16);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(new_shape, new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
- bool changed = false;
- for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
- HloInstruction* gte = nullptr;
- HloInstruction* conv = nullptr;
-
- auto conv_pattern =
- m::Op(&conv)
- .WithPredicate(IsConvCustomCall)
- .WithOperand(0, m::Op().WithPredicate(IsLosslesslyConvertibleToS8))
- .WithOperand(1, m::Op().WithPredicate(IsLosslesslyConvertibleToS8));
-
- // int8 -> int8 conv
- auto s8_pattern =
- m::Convert(
- m::Clamp(
- m::Broadcast(m::ConstantEffectiveScalar(-128)),
- m::GetTupleElement(
- &gte,
- conv_pattern.WithOperandIfPresent(
- 3, m::Op().WithPredicate(IsLosslesslyConvertibleToS8)),
- 0)
- .WithOneUse(),
- m::Broadcast(m::ConstantEffectiveScalar(127))))
- .WithElementType(S8);
-
- // int8 -> fp32 conv
- auto f32_pattern = m::GetTupleElement(&gte,
- conv_pattern.WithOperandIfPresent(
- 3, m::Op().WithElementType(F32)),
- 0)
- .WithElementType(F32);
-
- VlogIfFailureToMatch(
- instr, s8_pattern, "s8->s8 conv",
- m::Convert(m::Clamp(m::Op(), //
- m::GetTupleElement(
- m::Op().WithPredicate(IsConvCustomCall)), //
- m::Op()))
- .WithElementType(S8));
-
- VlogIfFailureToMatch(
- instr, f32_pattern, "s8->f32 conv",
- m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))
- .WithElementType(F32));
-
- PrimitiveType conv_output_ty;
- if (Match(instr, s8_pattern)) {
- conv_output_ty = S8;
- } else if (Match(instr, f32_pattern)) {
- conv_output_ty = F32;
- } else {
- continue;
- }
- if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
- return absl::StrCat("FuseConvertToS8: ", conv->ToString());
- })) {
- continue;
- }
-
- absl::InlinedVector<HloInstruction*, 4> new_operands(
- conv->operands().begin(), conv->operands().end());
- new_operands[0] = MakeConvertToHlo(new_operands[0], S8);
- new_operands[1] = MakeConvertToHlo(new_operands[1], S8);
- // Don't convert bias (operand 2); it's always f32 for s8 ops in cudnn. See
- // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
- if (new_operands.size() >= 4) {
- // side-input always matches conv output type. We checked in the patterns
- // above that it's losslessly-convertible to this type.
- new_operands[3] = MakeConvertToHlo(new_operands[3], conv_output_ty);
- }
-
- Shape new_shape = conv->shape();
- new_shape.mutable_tuple_shapes(0)->set_element_type(conv_output_ty);
-
- HloInstruction* new_conv = comp->AddInstruction(
- conv->CloneWithNewOperands(new_shape, new_operands));
- comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name());
- TF_ASSIGN_OR_RETURN(HloInstruction * new_instr,
- MakeGetTupleElementHlo(new_conv, 0));
- TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
- changed = true;
- }
- return changed;
-}
-
-Status CheckNoS32Convs(HloComputation* comp) {
- std::vector<HloInstruction*> bad_convs;
- for (HloInstruction* instr : comp->instructions()) {
- if (!IsConvCustomCall(instr)) {
- continue;
- }
- if (instr->shape().tuple_shapes(0).element_type() == S32 ||
- instr->operand(0)->shape().element_type() == S32 ||
- instr->operand(1)->shape().element_type() == S32 ||
- (instr->operand_count() >= 4 &&
- instr->operand(3)->shape().element_type() == S32)) {
- bad_convs.push_back(instr);
- }
- }
-
- if (bad_convs.empty()) {
- return Status::OK();
- }
-
- return Unimplemented(
- R"(
-Can't lower one or more integer convolutions to idioms supported by CuDNN.
-
-CuDNN integer convolutions must have:
-
- - s8 input and filter,
- - f32 bias (if present),
- - s8 or f32 output, and
- - s8 side_input (if present) if output is s8.
-
-For each of the unsupported convs below, we weren't able to lower one of the
-operands or the output to the appropriate type.
-
-See specific HLO idioms in cudnn_fused_conv_rewriter.h, and see cudnn semantics:
-
-https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward and
-https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-
-Unsupported convs:
-%s
-
-******* Full HLO module *******
-%s
-)",
- absl::StrJoin(bad_convs, "\n",
- [](std::string* out, HloInstruction* instr) {
- absl::StrAppend(out, " - ", instr->ToString());
- }),
- comp->parent()->ToString());
-}
-
-void VlogStats(HloModule* module) {
- if (!VLOG_IS_ON(1)) {
- return;
- }
-
- VLOG(1) << "Results of CudnnFusedConvRewriter for " << module->name();
- absl::flat_hash_map<std::string, int> stats;
- for (HloComputation* comp : module->MakeNonfusionComputations()) {
- for (HloInstruction* instr : comp->instructions()) {
- if (!Match(instr, m::Op().WithPredicate(IsConvCustomCall))) {
- continue;
- }
-
- VLOG(3) << instr->ToString();
-
- if (instr->custom_call_target() == kCudnnConvForwardCallTarget) {
- stats["01 non-fused forward convs"]++;
- } else if (instr->custom_call_target() ==
- kCudnnConvBiasActivationForwardCallTarget) {
- stats["02 fused forward convs"]++;
- }
-
- PrimitiveType conv_in_ty = instr->operand(0)->shape().element_type();
- PrimitiveType conv_out_ty = instr->shape().tuple_shapes(0).element_type();
- if (conv_in_ty == F32) {
- stats["10 f32 convs"]++;
- } else if (conv_in_ty == F16) {
- stats["11 f16 convs"]++;
- } else if (conv_in_ty == S8) {
- if (conv_out_ty == S8) {
- stats["12 s8->s8 convs"]++;
- } else if (conv_out_ty == F32) {
- stats["13 s8->f32 convs"]++;
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
} else {
- LOG(ERROR) << "Unexpected conv: " << instr->ToString();
+ return absl::nullopt;
}
}
-
- if (instr->operand_count() > 2) {
- stats["20 convs with bias"]++;
- if (Match(instr->operand(2),
- m::Broadcast(m::ConstantEffectiveScalar(0)))) {
- stats["21 convs with 0 bias"]++;
- }
- }
- if (instr->operand_count() > 3) {
- stats["22 convs with side-input"]++;
- }
-
- auto config = instr->backend_config<CudnnConvBackendConfig>();
- if (!config.ok()) {
- LOG(ERROR) << "Couldn't parse backend config for " << instr->ToString();
- continue;
- }
-
- if (config->conv_result_scale() != 1) {
- stats["30 convs with result scale"]++;
- }
- if (config->side_input_scale() != 0 && config->side_input_scale() != 1) {
- stats["31 convs with side-input scale"]++;
- }
- stats[absl::StrCat(
- "32 convs with activation mode ",
- se::dnn::ActivationMode_Name(config->activation_mode()))]++;
}
}
- std::vector<std::pair<std::string, int>> stats_sorted(stats.begin(),
- stats.end());
- absl::c_sort(stats_sorted);
- for (const auto& kv : stats_sorted) {
- VLOG(1) << absl::StreamFormat("%4d %s", kv.second,
- absl::string_view(kv.first).substr(3));
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
}
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ // In order to map to cudnnConvolutionBiasActivationForward for int8_t, the
+ // convolution output must be float, i.e. conv<float>(int8_x, int8_w).
+ if (conv->operand(0)->shape().element_type() == xla::S8) {
+ if (conv->shape().tuple_shapes(0).element_type() != xla::F32) {
+ return absl::nullopt;
+ }
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
}
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ PrimitiveType conv_output_type =
+ conv->shape().tuple_shapes(0).element_type();
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(conv_output_type)));
+
+ int64_t num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(conv_output_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_feature_group_count(conv->feature_group_count());
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ new_conv->set_metadata(conv->metadata());
+ computation->parent()->SetAndUniquifyInstrName(new_conv,
+ "cudnn-conv-bias-activation");
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64_t>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+ << new_conv->ToString();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+// Fuse bias/scaling/ReLU into a convolution custom call with floating point
+// output.
+StatusOr<bool> RunFuseBiasSideActivation(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+// Describes a matched pattern:
+// convert_or_clamp(get_tuple_element(custom_call(x,w, ...)));
+// where the custom_call targets CuDNN convolution (either pure convolution or
+// fused convolution).
+struct ConvWithConvertOrClamp {
+ HloInstruction* convert_or_clamp;
+ HloInstruction* gte;
+ HloCustomCallInstruction* conv;
+};
+
+// The pattern we want to match:
+// convert<int8_t>(clamp(broadcast(-128),
+// get_tuple_element(custom_call(int8_x, int8_w, ...)),
+// broadcast(127)));
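+//
+// As an illustrative sketch (for an s8 conv whose result was computed in
+// f32):
+//
+//   gte   = f32[1,32,9,9] get-tuple-element(conv_custom_call), index=0
+//   clamp = f32[1,32,9,9] clamp(broadcast(-128), gte, broadcast(127))
+//   cvt   = s8[1,32,9,9] convert(clamp)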
+absl::optional<ConvWithConvertOrClamp> FindConvWithClampAndConvertToInt8(
+ HloInstruction* instr) {
+ using match::Broadcast;
+ using match::Clamp;
+ using match::Convert;
+ using match::GetTupleElement;
+ using match::Op;
+
+ HloInstruction* gte = nullptr;
+ HloInstruction* conv_instr = nullptr;
+ auto lower_pattern = Broadcast(match::ConstantScalar(-128));
+ auto upper_pattern = Broadcast(match::ConstantScalar(127));
+ auto pattern = Convert(
+ Clamp(lower_pattern,
+ GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0),
+ upper_pattern));
+
+ if (Match(instr, pattern)) {
+ if (conv_instr->operand(0)->shape().element_type() == xla::S8 &&
+ instr->shape().element_type() == xla::S8) {
+ HloCustomCallInstruction* conv =
+ CastOrNull<HloCustomCallInstruction>(conv_instr);
+ return ConvWithConvertOrClamp{instr, gte, conv};
+ }
+ }
+ return absl::nullopt;
+}
+
+// A helper function to rewrite convert_or_clamp_or_other<new_type>(gte(conv()))
+// to gte<new_type>(conv<new_type>()). It bypasses convert_or_clamp_or_other
+// and sets the output data type on gte and conv.
+Status RewriteForConvertOrClampImpl(ConvWithConvertOrClamp match) {
+ auto conv = match.conv;
+ auto gte = match.gte;
+ auto convert_or_clamp = match.convert_or_clamp;
+
+ // Change type on conv and gte
+ auto convert_out_type = convert_or_clamp->shape().element_type();
+ conv->mutable_shape()->mutable_tuple_shapes(0)->set_element_type(
+ convert_out_type);
+ gte->mutable_shape()->set_element_type(convert_out_type);
+
+ // Remove the clamp/convert (and the like) and keep only
+ // get_tuple_element(custom_call(x, w, ...)).
+ TF_RETURN_IF_ERROR(convert_or_clamp->ReplaceAllUsesWithDifferentShape(gte));
+ TF_RETURN_IF_ERROR(
+ conv->parent()->RemoveInstructionAndUnusedOperands(convert_or_clamp));
+ return Status::OK();
+}
+
+Status RewriteForFinalOutput(ConvWithConvertOrClamp match) {
+ // When the matched clamp has a single user, which is convert<int8_t>, we
+ // absorb it if
+ // 1. the side_input matches a convert<float>(int8_side_input), or
+ // 2. there is no side input.
+ const auto is_one_to_one_X_to_Y_cast = [](const HloInstruction* instr,
+ PrimitiveType X,
+ PrimitiveType Y) -> bool {
+ return (instr->opcode() == HloOpcode::kConvert &&
+ instr->shape().element_type() == Y && instr->operand_count() == 1 &&
+ instr->operand(0)->user_count() == 1 &&
+ instr->operand(0)->shape().element_type() == X);
+ };
+
+ if (match.conv->operand_count() < 4) {
+ // Conv input #3 (zero based) is side_input, after x, w, and bias.
+ // Side input doesn't exist in this case.
+ TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+ } else if (is_one_to_one_X_to_Y_cast(match.conv->operand(3), S8, F32)) {
+ // If side_input is a convert<float>(int8), absorb the convert as well.
+ auto side_converter = match.conv->mutable_operand(3);
+ TF_RETURN_IF_ERROR(side_converter->ReplaceAllUsesWithDifferentShape(
+ side_converter->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(
+ side_converter->parent()->RemoveInstructionAndUnusedOperands(
+ side_converter));
+
+ TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+ }
+ return Status::OK();
+}
+
+// Fuse the clamp/convert pattern with the int8_t convolution custom call
+// (either pure or fused) for int8_t output
+StatusOr<bool> RunFuseClamp(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithConvertOrClamp> matches;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithClampAndConvertToInt8(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ }
+ for (const ConvWithConvertOrClamp& match : matches) {
+ TF_RETURN_IF_ERROR(RewriteForFinalOutput(match));
+ changed = true;
+ }
+
+ // Report an error for any convolution that still has int32_t output.
+ // Although an int32_t-output convolution will trigger other sanity-check
+ // errors later, we want to give a specific error message here.
+ for (auto instr : computation->instructions()) {
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if ((call->custom_call_target() == kCudnnConvForwardCallTarget ||
+ call->custom_call_target() ==
+ kCudnnConvBiasActivationForwardCallTarget) &&
+ call->shape().tuple_shapes(0).element_type() == xla::S32) {
+ return Unimplemented(
+ "Integer convolutions for CuDNN must have float or int8_t "
+ "output. "
+ "Use convert to cast output to float or the following pattern to "
+ "int8_t: "
+ "clamp(broadcast(-128), conv(int8_x, int8_w, ...), "
+ "broadcast(127)).");
+ }
+ }
+ }
+ }
+ return changed;
+}
+
+// The pattern we want to match:
+// convert<float>(get_tuple_element<int32_t>(custom_call()));
+absl::optional<ConvWithConvertOrClamp> FindConvWithConvertToFloat(
+ HloInstruction* instr) {
+ using match::Convert;
+ using match::GetTupleElement;
+ using match::Op;
+
+ HloInstruction* gte = nullptr;
+ HloInstruction* conv_instr = nullptr;
+ auto pattern =
+ Convert(GetTupleElement(
+ &gte,
+ Op(&conv_instr)
+ .WithOpcode(HloOpcode::kCustomCall)
+ .WithCustomCallTarget(kCudnnConvForwardCallTarget),
+ 0)
+ .WithShape(match::Shape().WithElementType(xla::S32)))
+ .WithShape(match::Shape().WithElementType(xla::F32));
+ if (Match(instr, pattern)) {
+ HloCustomCallInstruction* conv =
+ CastOrNull<HloCustomCallInstruction>(conv_instr);
+ return ConvWithConvertOrClamp{instr, gte, conv};
+ }
+ return absl::nullopt;
+}
+
+// Transform
+// convert<float>(GetTupleElement<int32_t>(custom_call<int32_t>(int8_x,
+// int8_w))) to GetTupleElement<float>(custom_call<int32_t>(int8_x, int8_w))
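+//
+// As an illustrative sketch:
+//
+//   gte = s32[1,32,9,9] get-tuple-element(conv_custom_call), index=0
+//   cvt = f32[1,32,9,9] convert(gte)
+//
+// becomes the same custom call with tuple element 0 retyped to f32, the gte
+// retyped to match, and the convert removed.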
+StatusOr<bool> RunFuseConvertToFloat(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithConvertOrClamp> matches;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithConvertToFloat(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ }
+
+ for (const ConvWithConvertOrClamp& match : matches) {
+ TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
+ changed = true;
+ }
+ }
+ return changed;
+}
} // namespace
StatusOr<bool> CudnnFusedConvRewriter::Run(HloModule* module) {
- bool any_changed = false;
+ TF_ASSIGN_OR_RETURN(bool fused_for_convert_to_float,
+ RunFuseConvertToFloat(module));
- for (HloComputation* comp : module->MakeNonfusionComputations()) {
- // Fuse "inside out" starting with the operations closest to the conv.
- bool changed = false;
+ TF_ASSIGN_OR_RETURN(bool fused_for_bias, RunFuseBiasSideActivation(module));
- TF_ASSIGN_OR_RETURN(changed, FuseConvertToFloat(comp));
- any_changed |= changed;
+ TF_ASSIGN_OR_RETURN(bool fused_for_clamp, RunFuseClamp(module));
- TF_ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp));
- any_changed |= changed;
-
- // s8 convs' bias and side-input appear before conversion to s8.
- //
- // Run FuseBiasOrSideInput twice, so we get both the bias and the side
- // input, if both are present.
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
- any_changed |= changed;
-
- // Relu might appear before or after convert-to-f16/s8, so we check in both
- // cases.
- TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp));
- any_changed |= changed;
-
- // f16 convs' bias+side-input can appear before or after conversion to f16.
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseBiasOrSideInput(comp));
- any_changed |= changed;
- TF_ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp));
- any_changed |= changed;
-
- TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp));
- any_changed |= changed;
-
- // Check that we don't have any convs outputing s32. cudnn does not support
- // these. They should have been transformed to int8->int8 or int8->float
- // above.
- TF_RETURN_IF_ERROR(CheckNoS32Convs(comp));
- }
-
- VlogStats(module);
-
- return any_changed;
+ return fused_for_convert_to_float || fused_for_bias || fused_for_clamp;
}
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
index b03cc2b..2a43ff3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h
@@ -22,76 +22,40 @@
namespace xla {
namespace gpu {
-// Rewrites custom-calls targeting cudnnConvolutionForward to
-// cudnnConvolutionBiasActivationForward by fusing operations following forward
-// convolution. This transform must run after cudnn_conv_rewriter.
+// Rewrite the custom call targeting cudnnConvolutionForward to
+// cudnnConvolutionBiasActivationForward by fusing applicable point-wise
+// operations following forward convolution. This transform must run after
+// cudnn_conv_rewriter.
+// It is straightforward for floating point convolutions:
+// transforming
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias))
+// to
+// cudnnConvolutionBiasActivationForward(x, w, bias, alpha1, alpha2, side)
//
-// Semantics of underlying cudnn ops:
-//
-// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBiasActivationForward
-// https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters
-//
-// ## Floating-point convs
-//
-// A "complete" fused floating-point conv has the form
-//
-// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)),
-//
-// which we fuse to
-//
-// cudnnConvolutionBiasActivationForward(x, w, bias, side_input).
-//
-// You can leave out side_input, bias, alpha1, alpha2, and max(x, 0) and still
-// get a fused convolution. alpha1/2 must be broadcasts of scalar constants.
-//
-// f16 convs accumulate in f32. We represent this in HLO as an f32 convolution
-// whose inputs can be converted to f16 without loss of precision and whose
-// output is immediately converted to f16. A fused f16 conv must follow one of
-// the following idioms.
-//
-// 1. convert_f16(conv_f32(x_f32, w_f32)) +
-// side_input_f16 + broadcast(bias_f16)
-//
-// 2. convert_f16(conv_f32(x_f32, w_f32) +
-// side_input_f32 + broadcast(bias_f32))
-//
-// (These are not strictly mathematically equivalent, but cudnn doesn't tell us
-// which one it does, and we deem them "close enough".)
-//
-// The foo_f32 HLOs must all be losslessly-convertible to f16. Some valid
-// examples:
-//
-// - foo_f32 = convert_f32(foo_f16)
-// - foo_f32 = an f32 constant whose values all fit within f16
-// - foo_f32 = broadcast/transpose/reshape(one of the above)
-//
-// If you have a relu, it can appear before or after the convert_f16.
-//
-// Note that here `bias` must be losslessly-convertible to f16; this is
-// different than for s8 convolutions, where bias is f32.
-//
-// ## Integer convs
-//
-// In pure HLO, a "complete" integer conv is spelled as one of the following
-// `result`s.
-//
-// base = alpha1_f32 * convert_f32(conv_s32(input_s32, filter_s32)) +
-// alpha2_f32 * side_input +
-// bias_f32
-//
-// result_f32 = max(result_f32, 0)
-// result_s8_option1 = max(convert_s8(clamp(-128, base, 127)), 0)
-// result_s8_option2 = convert_s8(clamp(-128, max(base, 0), 127))
-//
-// The foo_s32 HLOs must be losslessly-convertible to s8. If the `result_s8`
-// case, side_input should be an f32 HLO that's losslessly-convertible to s8;
-// otherwise, it should be losslessly-convertible to f32.
-//
-// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
-// skip the convert_f32 on conv.
-//
-// If you have an integer convolution that doesn't fit one of these idioms, this
-// pass returns an error -- cudnn will not be able to run it.
+// Integer convolution requires additional patterns to match CuDNN semantics:
+// #1 from
+// cast<int8_t>(clamp<-128, 127>(conv(int8_x, int8_w)))
+// to
+// cudnnConvolutionForward<int8_t>(int8_x, int8_w)
+// or #2 from
+// cast<float>(conv(int8_x, int8_w))
+// to
+// cudnnConvolutionForward<float>(int8_x, int8_w)
+// or #3 from
+// cast<int8_t>(clamp<-128, 127>(max(0, alpha1 *
+// cast<float>(conv(int8_x, int8_w)) +
+// alpha2 * cast<float>(int8_side) +
+// broadcast(bias)))
+// to
+// cudnnConvolutionBiasActivationForward<int8_t>(int8_x, int8_w, bias, alpha1,
+// alpha2, int8_side)
+// or #4 from
+// max(0, alpha1 * cast<float>(conv(int8_x, int8_w)) +
+// alpha2 * float_side + broadcast(bias))
+// to
+// cudnnConvolutionBiasActivationForward<float>(int8_x, int8_w, bias, alpha1,
+// alpha2, float_side)
+
class CudnnFusedConvRewriter : public HloModulePass {
public:
absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
index 3a328ca..1c7c1f3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc
@@ -15,26 +15,14 @@
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
-#include <string>
-
#include "absl/strings/str_replace.h"
-#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
-#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
-#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -43,13 +31,9 @@
// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-namespace m = match;
-
using ::testing::HasSubstr;
using ::testing::Not;
-class CudnnFusedConvRewriterHloTest : public HloTestBase {};
-
class CudnnFusedConvRewriterTest : public GpuCodegenTest {
protected:
std::string GetOptimizedHlo(absl::string_view hlo_string) {
@@ -63,17 +47,16 @@
debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
config.set_debug_options(debug_opts);
- auto result = backend().compiler()->RunHloPasses(
- ParseAndReturnVerifiedModule(hlo_string, config).ConsumeValueOrDie(),
- backend().default_stream_executor(), backend().memory_allocator());
- if (!result.status().ok()) {
- TF_EXPECT_OK(result.status())
- << "HLO compilation failed: " << result.status();
- return "";
- }
HloPrintOptions print_opts;
print_opts.set_print_operand_shape(false);
- return (*result)->ToString(print_opts);
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseAndReturnVerifiedModule(hlo_string, config)
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString(print_opts);
}
void TestMatchWithAllTypes(absl::string_view hlo_string) {
@@ -82,8 +65,7 @@
absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
EXPECT_THAT(optimized_hlo_string,
- Not(HasSubstr(kCudnnConvForwardCallTarget)))
- << optimized_hlo_string;
+ Not(HasSubstr(kCudnnConvForwardCallTarget)));
EXPECT_THAT(optimized_hlo_string,
HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
@@ -113,7 +95,6 @@
const std::string hlo_with_new_type =
absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
- SCOPED_TRACE(optimized_hlo_string);
EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
EXPECT_THAT(optimized_hlo_string,
Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
@@ -331,6 +312,27 @@
})");
}
+TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
const char* kHloString = R"(
HloModule Test
@@ -410,7 +412,8 @@
clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)
- ROOT convert = s8[1,32,9,9] convert(clamp)
+ convert = s8[1,32,9,9] convert(clamp)
+ ROOT relu = s8[1,32,9,9] maximum(zeros, convert)
})",
// post_hlo
R"(
@@ -419,8 +422,12 @@
)");
}
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
- const std::string module_str = R"(
+TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToFloat) {
+ // convert<float>(conv<int32_t>(convert<int32_t>(int8_x),
+ // convert<int32_t>(int8_w)));
+ TestClamp(
+ // pre_hlo
+ R"(
HloModule Test
ENTRY Test {
@@ -430,1004 +437,15 @@
inputs32 = s32[1,17,9,9] convert(input)
filters32 = s32[3,3,17,32] convert(filter)
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
+ conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
ROOT convert = f32[1,32,9,9] convert(conv)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(kCudnnConvForwardCallTarget), 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f32 = f32[1,32,9,9] convert(conv)
- ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
- add(add(conv_f32, bias), side_input),
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1),
- m::Parameter(2), m::Parameter(3)),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
- conv,
- s32[1,32,9,9] broadcast(s32[] constant(127))))
- zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
- ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), //
- m::Parameter(1), //
- m::Broadcast(
- m::ConstantEffectiveScalar(0).WithElementType(F32))),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- side_input_f32 = f32[1,32,9,9] parameter(3)
-
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f32 = f32[1,32,9,9] convert(conv)
- sum1 = add(conv_f32, bias_broadcast)
- ROOT sum2 = add(sum1, side_input_f32)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- EXPECT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1),
- m::Parameter(2), m::Parameter(3)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-// The ReshapeMover pass changes
-// reshape(side_input) * alpha -->
-// reshape(side_input * alpha).
-// Make sure we can pattern-match this.
-TEST_F(CudnnFusedConvRewriterTest, Int8SideInputWithScaleAndReshape) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
- side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))
-
- conv = s32[1,32,9,9] convolution(input, filter),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
- add(add(f32[1,32,9,9] convert(conv), bias), side_input),
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- HloPassFix<HloPassPipeline> simplify("simplify");
- simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
- simplify.AddPass<ReshapeMover>();
- TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), //
- m::Parameter(1), //
- m::Parameter(2), //
- m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.side_input_scale(), 0.25);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
- alpha = f32[] constant(42)
- alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- convert = f32[1,32,9,9] convert(conv)
- ROOT root = multiply(convert, alpha_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv = nullptr;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.conv_result_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- zero = f32[] constant(0)
- zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias_broadcast)
- ROOT relu = maximum(sum, zeros)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, bias)
- relu = maximum(sum, zeros)
- not_relu = minimum(sum, zeros)
- ROOT root = tuple(relu, not_relu)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::MaximumAnyOrder(
- m::Broadcast(m::ConstantEffectiveScalar(0)),
- m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})),
- m::Minimum())));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(multiply(alpha, conv), bias)
- ROOT root = tuple(conv, sum)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
- m::MultiplyAnyOrder(
- m::Broadcast(m::Parameter(3)),
- m::GetTupleElement(m::CustomCall(&conv2), 0))))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv1->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, add(conv, bias))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
- m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv1->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, add(conv, side_input))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::AddAnyOrder(m::Parameter(2),
- m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv1->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.conv_result_scale(), 1);
- EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
- filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
- EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
- filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_s8 = s8[1,32,9,9] convert(clamp(
- f32[1,32,9,9] broadcast(f32[] constant(-128)),
- conv,
- f32[1,32,9,9] broadcast(f32[] constant(127))))
- ROOT root = tuple(conv, conv_s8)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv1;
- const HloInstruction* conv2;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::GetTupleElement(m::CustomCall(&conv1), 0),
- m::Convert(m::Clamp(m::Op(), //
- m::GetTupleElement(m::CustomCall(&conv2), 0),
- m::Op())))));
- EXPECT_EQ(conv1, conv2);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, bias_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1), m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, side_input)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32}),
- m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- side_input = f32[1,32,9,9] parameter(2)
- side_input_scale = f32[] constant(42)
- side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
- side_input_product = multiply(side_input, side_input_scale_broadcast)
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, side_input_product)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32}),
- m::Parameter(2)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 42);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- side_input = f32[1,32,9,9] parameter(3)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input)
- ROOT sum2 = add(sum, bias_broadcast)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1), m::Parameter(2),
- m::Parameter(3)),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] parameter(0)
- filters = f32[3,3,17,32] parameter(1)
- bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
- conv = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- ROOT root = add(conv, bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1),
- m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
- 0)
- .WithShape(F32, {1, 32, 9, 9})));
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,17,9,9] parameter(0)
- filters = f16[3,3,17,32] parameter(1)
- bias = f16[32] parameter(2)
- side_input = f16[1,32,9,9] parameter(3)
-
- inputs_f32 = f32[1,17,9,9] convert(inputs)
- filters_f32 = f32[3,3,17,32] convert(filters)
- bias_f32 = f32[32] convert(bias)
- bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
- side_input_f32 = f32[1,32,9,9] convert(side_input)
- conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Parameter(1), m::Parameter(2),
- m::Parameter(3)),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-// We should be able to lower this to an f16 convolution even though the
-// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
-TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
- filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
- bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
- side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))
-
- conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- conv_f16 = f16[1,32,9,9] convert(conv_f32)
- ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
- .WithElementType(F16),
- m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
- .WithElementType(F16),
- m::Parameter(2), m::Reshape(m::Parameter(3))),
- 0)
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,17,9,9] parameter(0)
- filters = f16[3,3,17,32] parameter(1)
- bias = f32[32] parameter(2)
- side_input = f16[1,32,9,9] parameter(3)
-
- inputs_f32 = f32[1,17,9,9] convert(inputs)
- filters_f32 = f32[3,3,17,32] convert(filters)
- bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
- side_input_f32 = f32[1,32,9,9] convert(side_input)
- conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
- window={size=3x3 pad=1_1x1_1},
- dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- // fp16 convs only support fp16 biases. Because bias is fp32, it doesn't get
- // fused in, and we get an fp32 conv.
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Convert(m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Convert(m::Parameter(0)).WithElementType(F32),
- m::Convert(m::Parameter(1)).WithElementType(F32),
- m::Parameter(2),
- m::Convert(m::Parameter(3)).WithElementType(F32)),
- 0))
- .WithShape(F16, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,2,2,2] parameter(0)
- filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
- bias = f16[2] parameter(1)
- bias_f32 = f32[2] convert(bias)
- side_input_f32 = f32[1,2,2,2] constant({{
- {{0.5, 0.25}, {0.125, 0.0625}},
- {{0.5, 0.25}, {0.125, 0.0625}}
- }})
-
- inputs_f32 = f32[1,2,2,2] convert(inputs)
- bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
- conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
- window={size=1x1}, dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph, and fold
- // convert back into constants.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
- HloConstantFolding constant_folding;
- TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), m::Constant().WithElementType(F16),
- m::Parameter(1), m::Constant().WithElementType(F16)),
- 0)
- .WithShape(F16, {1, 2, 2, 2})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- inputs = f16[1,2,2,2] parameter(0)
- filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
- bias = f16[2] parameter(1)
- bias_f32 = f32[2] convert(bias)
- side_input_f32 = f32[1,2,2,2] constant({{
- {{0.1, 0.2}, {0.3, 0.4}},
- {{0.5, 0.6}, {0.7, 0.8}}
- }})
-
- inputs_f32 = f32[1,2,2,2] convert(inputs)
- bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
- conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
- window={size=1x1}, dim_labels=bf01_01io->bf01
- sum = add(conv, side_input_f32)
- sum2 = add(sum, bias_broadcast)
- ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph, and fold
- // convert back into constants.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
- HloConstantFolding constant_folding;
- TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- // This doesn't get transformed into an f16 conv because the filters param is
- // not losslessly expressible as f16.
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::Convert(m::GetTupleElement(
- m::CustomCall(
- &conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Convert(m::Parameter(0)).WithElementType(F32),
- m::Constant().WithElementType(F32),
- m::Convert(m::Parameter(1)).WithElementType(F32),
- m::Constant().WithElementType(F32)),
- 0)
- .WithShape(F32, {1, 2, 2, 2}))
- .WithElementType(F16)));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.side_input_scale(), 1);
-}
-
-TEST_F(CudnnFusedConvRewriterTest, FuseReluBeforeConvert) {
- const std::string module_str = R"(
- HloModule Test
-
- ENTRY Test {
- input = s8[1,17,9,9] parameter(0)
- filter = s8[3,3,17,32] parameter(1)
- inputs32 = s32[1,17,9,9] convert(input)
- filters32 = s32[3,3,17,32] convert(filter)
-
- conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
-
- zero = s32[] constant(0)
- zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
- relu = maximum(conv, zeros)
-
- lower = s32[] constant(-128)
- lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
- upper = s32[] constant(127)
- uppers = s32[1,32,9,9] broadcast(upper), dimensions={}
-
- clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)
-
- ROOT convert = s8[1,32,9,9] convert(clamp)
- })";
- TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
-
- GpuConvRewriter rewriter;
- TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
- CudnnFusedConvRewriter fuser;
- TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
-
- // Simplify new `convert`'s that may be added to the graph.
- AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
- TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
-
- SCOPED_TRACE(m->ToString());
- const HloInstruction* conv;
- ASSERT_THAT(
- m->entry_computation()->root_instruction(),
- GmockMatch(
- m::GetTupleElement(
- m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
- m::Parameter(0), //
- m::Parameter(1), //
- m::Broadcast(m::ConstantEffectiveScalar(0))
- .WithShape(F32, {32})),
- 0)
- .WithShape(S8, {1, 32, 9, 9})));
- TF_ASSERT_OK_AND_ASSIGN(auto config,
- conv->backend_config<CudnnConvBackendConfig>());
- EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
+ })",
+ // post_hlo
+ R"(
+ ; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> f32[1,32,9,9] {
+ ; CHECK: %cudnn-conv{{(\.[0-9])?}} = (f32[1,32,9,9]{1,3,2,0}, u8[{{[0-9]+}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config=
+ )");
}
TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
@@ -1462,7 +480,7 @@
clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
- ROOT convert = s8[1,3,3,64] convert(clamp)
+ ROOT convert = s8[1,3,3,64] convert(clamp)
})",
// post_hlo
R"(
@@ -1497,7 +515,7 @@
convfloat = f32[1,3,3,64] convert(conv)
broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
- ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
+ ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
})",
// post_hlo
R"(
@@ -1550,18 +568,13 @@
clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
- ROOT convert = s8[1,3,3,64] convert(clamp)
+ ROOT convert = s8[1,3,3,64] convert(clamp)
})",
// post_hlo
R"(
; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
- ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9])?}} =
- ; CHECK-SAME: (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0})
- ; CHECK-SAME: custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input),
- ; CHECK-SAME: window={size=3x3 pad=1_1x1_1},
- ; CHECK-SAME: dim_labels=b01f_01io->b01f,
- ; CHECK-SAME: custom_call_target="__cudnn$convBiasActivationForward",
- ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
+ ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
+ ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
)");
}
@@ -1609,7 +622,7 @@
clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
- ROOT convert = s8[1,3,3,64] convert(clamp)
+ ROOT convert = s8[1,3,3,64] convert(clamp)
})",
// post_hlo
R"(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
index 8e7939f..44c9923 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
@@ -686,6 +686,45 @@
// If all else fails, try a forward convolution.
if (conv_matchers::CanImplementAsGpuForwardConv(conv)) {
+ if (primitive_util::IsIntegralType(
+ conv->operand(0)->shape().element_type())) {
+ // In addition to replacing a convolution instruction with
+ // a custom call, integer convolutions must have this pattern to match
+ // CuDNN semantics:
+ // conv<InputT=int32_t, ResultT=int32_t>(
+ // convert<int32_t>(int8_x), convert<int32_t>(int8_y))
+ // We transform it to:
+ // custom_call<int32_t>(int8_x, int8_y, target=cudnnConvolutionForward)
+ //
+      // We will error out if the pattern is not found for an integer
+      // convolution.
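+      //
+      // For example (a sketch; shapes are illustrative):
+      //   %lhs  = s32[1,17,9,9] convert(s8[1,17,9,9] %x)
+      //   %rhs  = s32[3,3,17,32] convert(s8[3,3,17,32] %w)
+      //   %conv = s32[1,32,9,9] convolution(%lhs, %rhs), ...
+      // becomes
+      //   %cc = (s32[1,32,9,9], u8[...]) custom-call(%x, %w),
+      //       custom_call_target="__cudnn$convForward"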
+ const auto is_int8_to_int32_cast =
+ [](const HloInstruction* instr) -> bool {
+ return (instr->opcode() == HloOpcode::kConvert &&
+ instr->operand(0)->shape().element_type() == S8 &&
+ instr->shape().element_type() == S32);
+ };
+ HloInstruction* input_convert = conv->mutable_operand(0);
+ HloInstruction* kernel_convert = conv->mutable_operand(1);
+ if (conv->shape().element_type() != S32 ||
+ !is_int8_to_int32_cast(input_convert) ||
+ !is_int8_to_int32_cast(kernel_convert)) {
+ return Unimplemented(
+ "Integer convolutions for CuDNN must have this pattern: "
+ "conv<InputT=int32_t, ResultT=int32_t>(convert<int32_t>(int8_x), "
+ "convert<int32_t>(int8_y))");
+ }
+ // Bypass the convert<int32_t> for both inputs.
+ TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
+ 0, input_convert->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(
+ conv->parent()->RemoveInstructionAndUnusedOperands(input_convert));
+ TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
+ 1, kernel_convert->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(
+ conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert));
+ }
+
if (conv->batch_group_count() > 1) {
conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
index 4025e6e..08b037c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter_test.cc
@@ -641,6 +641,23 @@
0));
}
+// Check that a forward convolution instruction with int8_t inputs is not
+// allowed: the rewriter requires both operands to be convert<int32_t>(int8).
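+// (A minimal sketch of the accepted form, with illustrative shapes:
+//    inputs32  = s32[1,2,3,3] convert(s8[1,2,3,3] parameter(0))
+//    filters32 = s32[3,3,2,5] convert(s8[3,3,2,5] parameter(1))
+//    ROOT conv = s32[1,5,3,3] convolution(inputs32, filters32), ...)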
+TEST_F(GpuConvRewriterTest, TestForwardInt8Convolution) {
+ const std::string module_str = absl::StrFormat(R"(
+ HloModule Test
+
+ ENTRY Test {
+ input = s8[1,2,3,3] parameter(0)
+ filter = s8[3,3,2,5] parameter(1)
+
+ ROOT conv = s8[1,5,3,3] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ })");
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+ ASSERT_FALSE(GpuConvRewriter().Run(m.get()).ok());
+}
+
TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) {
const std::string module_str = absl::StrFormat(R"(
HloModule Test