[xla::gpu] Fuse bias addition for BF16 GEMMs.

Float normalization can legalize a BF16 bias add into an F32 add wrapped in
converts, which the existing add(gemm, bias) matcher does not see. Add a
HandleConvert visitor that matches convert(add(convert(gemm), convert(bias)))
with BF16 element types, and factor the fusion logic it shares with HandleAdd
into a FuseBiasedGemm helper. The fused call is now built with
CloneWithNewOperands so the attributes of the original custom call are
preserved.
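
For illustration, the pattern exercised by the new test (the rewritten form
below is a sketch; the instruction name "gemm" is illustrative):

    dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)

becomes a single three-operand cuBLAS call, with beta = 1 in the backend
config so the bias is accumulated into the GEMM result:

    ROOT gemm = bf16[8,8]{1,0} custom-call(x, y, bias), custom_call_target="__cublas$gemm"
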
PiperOrigin-RevId: 441133520
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
index 7c1b721..9584194 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -132,29 +132,53 @@
m::AddAnyOrder(
m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
m::Op(&bias)))) {
- // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
- // supports fixed values for alpha/beta.
- if (existing_gemm->shape().element_type() == S32) {
- return Status::OK();
- }
- auto config =
- existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
- if (config.beta() == 0 && bias->user_count() == 1 &&
- existing_gemm->user_count() == 1 &&
- bias->shape() == existing_gemm->shape()) {
- config.set_beta(1.0);
- CHECK_EQ(existing_gemm->operand_count(), 2);
- std::unique_ptr<HloInstruction> gemm_call =
- HloInstruction::CreateCustomCall(
- instr->shape(),
- {existing_gemm->mutable_operand(0),
- existing_gemm->mutable_operand(1), bias},
- kGemmCallTarget);
- TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
- TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(instr, std::move(gemm_call)));
- }
+ return FuseBiasedGemm(instr, bias, existing_gemm);
+ }
+ return Status::OK();
+ }
+
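+  // HandleAdd above matches bias additions performed directly in the GEMM's
+  // element type. Float normalization may instead express a BF16 bias add in
+  // F32, as convert(add(convert(gemm), convert(bias))); match that form here.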
+ Status HandleConvert(HloInstruction *instr) override {
+ HloInstruction *bias, *existing_gemm;
+ if (Match(
+ instr,
+ m::Convert(m::AddAnyOrder(
+ m::Convert(m::Op(&existing_gemm)
+ .WithCustomCallTarget(kGemmCallTarget)
+ .WithElementType(BF16)),
+ m::Convert(m::Op(&bias).WithElementType(BF16))))
+ .WithElementType(BF16))) {
+ return FuseBiasedGemm(instr, bias, existing_gemm);
+ }
+ return Status::OK();
+ }
+
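+  // Rewrites `instr` into a three-operand GEMM custom call that adds `bias`
+  // via beta = 1, provided the bias is not shared and beta is still unset.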
+ Status FuseBiasedGemm(HloInstruction *instr, HloInstruction *bias,
+ HloInstruction *existing_gemm) {
+ // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
+ // supports fixed values for alpha/beta.
+ if (existing_gemm->shape().element_type() == S32) {
+ return Status::OK();
+ }
+ auto config =
+ existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
+ if (config.beta() == 0 && bias->user_count() == 1 &&
+ existing_gemm->user_count() == 1 &&
+ bias->shape() == existing_gemm->shape()) {
+ config.set_beta(1.0);
+ CHECK_EQ(existing_gemm->operand_count(), 2);
+ std::unique_ptr<HloInstruction> gemm_call =
+ existing_gemm->CloneWithNewOperands(
+ instr->shape(), {existing_gemm->mutable_operand(0),
+ existing_gemm->mutable_operand(1), bias});
+ TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
+ TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(instr, std::move(gemm_call)));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 5ce39cf..34f3940 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -576,6 +576,32 @@
/*print_operand_shape=*/true);
}
}
+
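+// A BF16 dot with a bias add should compile to a single cuBLAS GEMM custom
+// call that takes the bias as a third operand, with beta = 1.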
+TEST_F(GemmRewriteTest, BF16GemmWithBias) {
+ const char* hlo_text = R"(
+HloModule BF16GemmWithBias
+
+ENTRY BF16GemmWithBias {
+ x = bf16[8,8]{1,0} parameter(0)
+ y = bf16[8,8]{1,0} parameter(1)
+ dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ bias = bf16[8,8]{1,0} parameter(2)
+ ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+ )";
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+ MatchOptimizedHlo(hlo_text,
+ R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] {
+; CHECK-NEXT: [[INSTR_0:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-NEXT: [[INSTR_1:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK-NEXT: [[INSTR_2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT: ROOT [[INSTR_3:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[INSTR_0]], [[INSTR_1]], [[INSTR_2]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"64\",\"rhs_stride\":\"64\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+ )");
+}
} // namespace
} // namespace gpu
} // namespace xla