/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace gpu {
namespace {
namespace m = match;
// Give this instruction a more useful name than "custom-call.42".
Status SetName(HloModule *module, HloInstruction *gemm) {
TF_ASSIGN_OR_RETURN(GemmBackendConfig config,
gemm->backend_config<GemmBackendConfig>());
bool is_batch_dot = config.batch_size() > 1;
module->SetAndUniquifyInstrName(
gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
return Status::OK();
}
// Rewriting proceeds bottom-up:
//
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
//
// (kMultiply (kCustomCall:gemm A B) C) folds C (provided it is a scalar
// constant) into the alpha parameter of the custom call.
//
// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
// where the "beta" parameter is set to 1 (provided it was zero before,
// and provided C has no other users).
// We then guide the buffer assignment to alias the buffer of the custom call
// and C.
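//
// For example (illustrative HLO; shapes and instruction names are arbitrary):
//
//   d = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
//   s = f32[2,2] add(d, c)
//
// becomes a single custom call computing alpha * (A x B) + beta * C, with
// alpha = beta = 1:
//
//   gemm = f32[2,2] custom-call(a, b, c), custom_call_target=kGemmCallTarget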
class GemmRewriterVisitor : public DfsHloRewriteVisitor {
public:
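// Rewrites a (possibly batched) dot into a GEMM custom call, recording
// default alpha/beta values and the operand strides in the backend config.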
Status HandleDot(HloInstruction *instr) override {
if (IsMatrixMultiplication(*instr)) {
CHECK(!instr->IsRank2Transpose());
HloInstruction *lhs = instr->mutable_operand(0);
HloInstruction *rhs = instr->mutable_operand(1);
CHECK(!lhs->IsRank2Transpose());
CHECK(!rhs->IsRank2Transpose());
const Shape &output_shape = instr->shape();
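// All output dimensions before the final two are batch dimensions; the
// batch size is their product (1 for a plain, non-batched GEMM).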
int64_t batch_size = std::accumulate(output_shape.dimensions().begin(),
output_shape.dimensions().end() - 2,
1, std::multiplies<int64_t>());
std::unique_ptr<HloInstruction> gemm_call =
HloInstruction::CreateCustomCall(output_shape, {lhs, rhs},
kGemmCallTarget);
GemmBackendConfig gemm_config;
gemm_config.set_alpha_real(1.0);
gemm_config.set_alpha_imag(0.0);
gemm_config.set_beta(0.0);
*gemm_config.mutable_dot_dimension_numbers() =
instr->dot_dimension_numbers();
gemm_config.set_batch_size(batch_size);
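// Each operand's stride is the element count of one GEMM's worth of data:
// the product of its two matrix (non-batch) dimensions.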
int64_t lhs_batch_dims_size =
instr->dot_dimension_numbers().lhs_batch_dimensions_size();
int64_t lhs_stride = lhs->shape().dimensions(lhs_batch_dims_size) *
lhs->shape().dimensions(lhs_batch_dims_size + 1);
int64_t rhs_stride = rhs->shape().dimensions(lhs_batch_dims_size) *
rhs->shape().dimensions(lhs_batch_dims_size + 1);
gemm_config.set_lhs_stride(lhs_stride);
gemm_config.set_rhs_stride(rhs_stride);
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_config));
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
}
return Status::OK();
}
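// Folds a broadcast scalar-constant multiplier into the alpha parameter of
// an existing GEMM custom call.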
Status HandleMultiply(HloInstruction *instr) override {
HloInstruction *alpha, *existing_gemm;
if (Match(instr,
m::MultiplyAnyOrder(
m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
m::Broadcast(m::ConstantScalar(&alpha))))) {
TF_ASSIGN_OR_RETURN(auto config,
existing_gemm->backend_config<GemmBackendConfig>());
// Do not fuse alpha into an S32 GEMM, as cuBLAS only supports fixed
// values for alpha/beta for this datatype.
if (existing_gemm->shape().element_type() == S32) {
return Status::OK();
}
if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
complex128 new_alpha =
*alpha->literal().GetAsComplex128({}) * prev_alpha;
config.set_alpha_real(new_alpha.real());
config.set_alpha_imag(new_alpha.imag());
TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
}
}
return Status::OK();
}
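// Fuses the addend of (kAdd gemm C) into the GEMM custom call as a bias
// operand, setting beta to 1.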
Status HandleAdd(HloInstruction *instr) override {
HloInstruction *bias, *existing_gemm;
if (Match(instr,
m::AddAnyOrder(
m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
m::Op(&bias)))) {
return FuseBiasedGemm(instr, bias, existing_gemm);
}
return Status::OK();
}
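// Matches the BF16 form of the bias pattern, where the GEMM result and the
// bias are first converted to a wider type (e.g. F32) for the add and the
// sum is converted back to BF16.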
Status HandleConvert(HloInstruction *instr) override {
HloInstruction *bias, *existing_gemm;
if (Match(
instr,
m::Convert(m::AddAnyOrder(
m::Convert(m::Op(&existing_gemm)
.WithCustomCallTarget(kGemmCallTarget)
.WithElementType(BF16)),
m::Convert(m::Op(&bias).WithElementType(BF16))))
.WithElementType(BF16))) {
return FuseBiasedGemm(instr, bias, existing_gemm);
}
return Status::OK();
}
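// Rewrites the add of a bias into a three-operand GEMM custom call with
// beta set to 1. Buffer assignment later aliases the bias buffer with the
// custom call's output, which is why the bias must have no other users.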
Status FuseBiasedGemm(HloInstruction *instr, HloInstruction *bias,
HloInstruction *existing_gemm) {
// Do not fuse bias into an S32 GEMM, as cuBLAS only supports fixed
// values for alpha/beta for this datatype.
if (existing_gemm->shape().element_type() == S32) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(auto config,
existing_gemm->backend_config<GemmBackendConfig>());
if (config.beta() == 0 && bias->user_count() == 1 &&
existing_gemm->user_count() == 1 &&
bias->shape() == existing_gemm->shape()) {
config.set_beta(1.0);
CHECK_EQ(existing_gemm->operand_count(), 2);
std::unique_ptr<HloInstruction> gemm_call =
existing_gemm->CloneWithNewOperands(
instr->shape(), {existing_gemm->mutable_operand(0),
existing_gemm->mutable_operand(1), bias});
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
}
return Status::OK();
}
};
StatusOr<bool> RunOnComputation(HloComputation *computation) {
GemmRewriterVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
} // anonymous namespace
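// Runs the rewrite over every non-fusion computation in the module,
// returning true iff any computation changed.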
StatusOr<bool> GemmRewriter::Run(HloModule *module) {
bool changed = false;
for (HloComputation *computation : module->MakeNonfusionComputations()) {
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
changed |= result;
}
return changed;
}
} // namespace gpu
} // namespace xla