[xla::gpu] Fuse bias addition for bf16 GEMMs.
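
The GemmRewriter already folds add(gemm, bias) into the cuBLAS GEMM custom
call by passing the bias as a third operand and setting beta to 1. For bf16
GEMMs the addition often reaches this pass wrapped in converts: the bf16 GEMM
result and the bf16 bias are each converted to a wider type, added, and the
sum is converted back to bf16. The new HandleConvert visitor matches this
wrapped pattern and reuses the existing fusion logic, which is factored out
into FuseBiasedGemm.

A rough before/after sketch in HLO (illustrative only; the names, shapes, and
the f32 intermediate type are assumptions, not taken from the change below):

  before:
    gemm = bf16[8,8] custom-call(x, y), custom_call_target="__cublas$gemm"
    lhs  = f32[8,8] convert(gemm)
    rhs  = f32[8,8] convert(bias)
    sum  = f32[8,8] add(lhs, rhs)
    ROOT out = bf16[8,8] convert(sum)

  after:
    ROOT out = bf16[8,8] custom-call(x, y, bias),
               custom_call_target="__cublas$gemm"

with beta set to 1 in the fused call's backend config, as the new test's
CHECK line verifies.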

PiperOrigin-RevId: 441133520
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
index 7c1b721..9584194 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -132,29 +132,52 @@
               m::AddAnyOrder(
                   m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
                   m::Op(&bias)))) {
-      // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
-      // supports fixed values for alpha/beta.
-      if (existing_gemm->shape().element_type() == S32) {
-        return Status::OK();
-      }
-      auto config =
-          existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
-      if (config.beta() == 0 && bias->user_count() == 1 &&
-          existing_gemm->user_count() == 1 &&
-          bias->shape() == existing_gemm->shape()) {
-        config.set_beta(1.0);
-        CHECK_EQ(existing_gemm->operand_count(), 2);
-        std::unique_ptr<HloInstruction> gemm_call =
-            HloInstruction::CreateCustomCall(
-                instr->shape(),
-                {existing_gemm->mutable_operand(0),
-                 existing_gemm->mutable_operand(1), bias},
-                kGemmCallTarget);
-        TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
-        TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
-        TF_RETURN_IF_ERROR(
-            ReplaceWithNewInstruction(instr, std::move(gemm_call)));
-      }
+      return FuseBiasedGemm(instr, bias, existing_gemm);
+    }
+    return Status::OK();
+  }
+
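+  // Matches convert(add(convert(gemm), convert(bias))), with the GEMM, bias,
+  // and outer convert all in bf16, and fuses the bias into the GEMM call.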
+  Status HandleConvert(HloInstruction *instr) override {
+    HloInstruction *bias, *existing_gemm;
+    if (Match(
+            instr,
+            m::Convert(m::AddAnyOrder(
+                           m::Convert(m::Op(&existing_gemm)
+                                          .WithCustomCallTarget(kGemmCallTarget)
+                                          .WithElementType(BF16)),
+                           m::Convert(m::Op(&bias).WithElementType(BF16))))
+                .WithElementType(BF16))) {
+      return FuseBiasedGemm(instr, bias, existing_gemm);
+    }
+    return Status::OK();
+  }
+
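+  // Fuses `bias` into the cuBLAS GEMM custom call `existing_gemm` by adding
+  // it as a third operand and setting beta to 1, when it is safe to do so.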
+  Status FuseBiasedGemm(HloInstruction *instr, HloInstruction *bias,
+                        HloInstruction *existing_gemm) {
+    // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
+    // supports fixed values for alpha/beta.
+    if (existing_gemm->shape().element_type() == S32) {
+      return Status::OK();
+    }
+    auto config =
+        existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
+    if (config.beta() == 0 && bias->user_count() == 1 &&
+        existing_gemm->user_count() == 1 &&
+        bias->shape() == existing_gemm->shape()) {
+      config.set_beta(1.0);
+      CHECK_EQ(existing_gemm->operand_count(), 2);
+      std::unique_ptr<HloInstruction> gemm_call =
+          existing_gemm->CloneWithNewOperands(
+              instr->shape(), {existing_gemm->mutable_operand(0),
+                               existing_gemm->mutable_operand(1), bias});
+      TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
+      TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
+      TF_RETURN_IF_ERROR(
+          ReplaceWithNewInstruction(instr, std::move(gemm_call)));
     }
     return Status::OK();
   }
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 5ce39cf..34f3940 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -576,6 +576,30 @@
                       /*print_operand_shape=*/true);
   }
 }
+
+TEST_F(GemmRewriteTest, BF16GemmWithBias) {
+  const char* hlo_text = R"(
+HloModule BF16GemmWithBias
+
+ENTRY BF16GemmWithBias {
+  x = bf16[8,8]{1,0} parameter(0)
+  y = bf16[8,8]{1,0} parameter(1)
+  dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias = bf16[8,8]{1,0} parameter(2)
+  ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias)
+}
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3}));
+  MatchOptimizedHlo(hlo_text,
+                    R"(
+; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] {
+; CHECK-NEXT:    [[INSTR_0:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
+; CHECK-NEXT:    [[INSTR_1:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
+; CHECK-NEXT:    [[INSTR_2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2)
+; CHECK-NEXT:    ROOT [[INSTR_3:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[INSTR_0]], [[INSTR_1]], [[INSTR_2]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"64\",\"rhs_stride\":\"64\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+      )");
+}
 }  // namespace
 }  // namespace gpu
 }  // namespace xla