Roll back: "Give cudnn/cublas custom-calls better names in HLO."
This reverts the custom naming of cudnn/cublas custom-call instructions;
they go back to their default names (e.g. "custom-call.42") instead of
e.g. "cublas-gemm.42".
PiperOrigin-RevId: 398100270
Change-Id: Ie3bb77cf72b4e22e1f487a2838bf94f1610ac25a
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 681b2b3..757825e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1957,12 +1957,6 @@
absl::optional<ConvolutionDimensionNumbers> dnums,
CustomCallSchedule schedule, CustomCallApiVersion api_version) {
HloInstructionProto instr;
- // Bit of a hack: conv-bias-activation-forward custom-calls are created
- // through this API. Give them a user-friendly name. (This has no effect on
- // correctness, it's just cosmetic.)
- if (call_target_name == "__cudnn$convBiasActivationForward") {
- instr.set_name("cudnn-conv-bias-activation");
- }
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
index 515df31..039cc33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc
@@ -242,7 +242,6 @@
new_conv->set_convolution_dimension_numbers(
conv->convolution_dimension_numbers());
new_conv->set_metadata(conv->metadata());
- new_conv->SetAndSanitizeName("cudnn-conv-bias-activation");
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
conv->backend_config<CudnnConvBackendConfig>());
config.set_activation_mode(
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
index 45e0059..6301b54 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -29,21 +29,9 @@
namespace xla {
namespace gpu {
-namespace {
namespace m = match;
-// Give this instruction a more useful name than "custom-call.42".
-Status SetName(HloModule *module, HloInstruction *gemm) {
- GemmBackendConfig config;
- TF_ASSIGN_OR_RETURN(config, gemm->backend_config<GemmBackendConfig>());
- bool is_batch_dot = config.batch_size() > 1;
-
- gemm->SetAndSanitizeName(module->instruction_name_uniquer().GetUniqueName(
- is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm"));
- return Status::OK();
-}
-
// The rewriting proceeds in a bottom-up way:
//
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
@@ -90,7 +78,6 @@
gemm_config.set_lhs_stride(lhs_stride);
gemm_config.set_rhs_stride(rhs_stride);
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_config));
- TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
}
@@ -150,7 +137,6 @@
existing_gemm->mutable_operand(1), bias},
kGemmCallTarget);
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
- TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
}
@@ -159,14 +145,12 @@
}
};
-StatusOr<bool> RunOnComputation(HloComputation *computation) {
+static StatusOr<bool> RunOnComputation(HloComputation *computation) {
GemmRewriterVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
-} // anonymous namespace
-
StatusOr<bool> GemmRewriter::Run(HloModule *module) {
bool changed = false;
for (HloComputation *computation : module->MakeNonfusionComputations()) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
index db62f86..c821a20 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
@@ -37,7 +37,7 @@
namespace {
-HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
+HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
const ConvolutionDimensionNumbers& dnums,
@@ -62,22 +62,6 @@
custom_call->set_convolution_dimension_numbers(dnums);
custom_call->set_feature_group_count(feature_group_count);
custom_call->set_metadata(metadata);
-
- // Give the customcall a user-friendly name.
- absl::optional<std::string> name;
- if (call_target == kCudnnConvForwardCallTarget) {
- name = "cudnn-conv";
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- name = "cudnn-conv-bw-input";
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- name = "cudnn-conv-bw-filter";
- } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
- name = "cudnn-conv-bias-activation";
- }
- if (name.has_value()) {
- custom_call->SetAndSanitizeName(*name);
- }
-
return custom_call;
}
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
index 2873b3d..9aad6dd 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
@@ -44,7 +44,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,2], y: f32[2,2]) -> f32[3,2,2] {
; CHECK-NEXT: %x = f32[3,2,2]{2,1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[]},\"batch_size\":\"3\",\"lhs_stride\":\"4\",\"rhs_stride\":\"0\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[]},\"batch_size\":\"3\",\"lhs_stride\":\"4\",\"rhs_stride\":\"0\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -68,7 +68,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[3,2,2]) -> f32[3,2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[3,2,2]{2,1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[\"0\"]},\"batch_size\":\"3\",\"lhs_stride\":\"0\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[\"0\"]},\"batch_size\":\"3\",\"lhs_stride\":\"0\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index be40bd4..53cdd60 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -79,7 +79,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -155,7 +155,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -179,7 +179,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT: ROOT %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%y, %x), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[2,2]{1,0} custom-call(%y, %x), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -205,7 +205,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -231,7 +231,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] {
; CHECK-NEXT: %x = c64[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = c64[2,2]{1,0} parameter(1)
-; CHECK-NEXT: ROOT %cublas-gemm.1 = c64[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":3,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call = c64[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":3,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -254,7 +254,7 @@
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
MatchOptimizedHlo(hlo_text,
R"(
-; CHECK: %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -279,7 +279,7 @@
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -308,7 +308,7 @@
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: %bias = f32[2,2]{1,0} parameter(2)
-; CHECK-NEXT: ROOT %cublas-gemm.3 = f32[2,2]{1,0} custom-call(%x, %y, %bias), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: ROOT %custom-call.1 = f32[2,2]{1,0} custom-call(%x, %y, %bias), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -338,7 +338,7 @@
; CHECK-NEXT: %bias = f32[2,2]{1,0} parameter(2)
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: %cublas-gemm.1 = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
+; CHECK-NEXT: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"lhs_stride\":\"4\",\"rhs_stride\":\"4\",\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
)");
}
@@ -411,7 +411,7 @@
} else {
MatchOptimizedHlo(hlo_text,
R"(
- ; CHECK: ROOT %cublas-batch-gemm.1 = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), custom_call_target="__cublas$gemm"
+ ; CHECK: ROOT %custom-call = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), custom_call_target="__cublas$gemm"
)",
/*print_operand_shape=*/true);
}