Use `AlgorithmProto` directly in `CudnnConvBackendConfig`.

This avoids having to duplicate every new field of `AlgorithmProto` in the
backend config over the next few CLs.
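
Concretely, the old pair of fields in CudnnConvBackendConfig

  int64 algorithm = 1;
  bool tensor_ops_enabled = 2;

is replaced by a single

  stream_executor.dnn.AlgorithmProto algorithm = 6;

so a serialized backend_config changes from

  {"algorithm":"2","tensor_ops_enabled":false, ...}

to

  {"algorithm":{"algo_id":"2","math_type":"DEFAULT_MATH"}, ...}

with tensor_ops_enabled=true corresponding to math_type=TENSOR_OP_MATH and
tensor_ops_enabled=false to math_type=DEFAULT_MATH (this is exactly the
mapping applied in ir_emitter_unnested.cc below).
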
PiperOrigin-RevId: 407402953
Change-Id: I07f930b51e36f455c791174e6a3464d720136faf
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index bc29def..867e85c 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -359,7 +359,7 @@
ROOT %custom-call.1 = (f32[4,256,2,2]{3,2,1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter),
window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01,
custom_call_target="__cudnn$convForward",
- backend_config="{\"algorithm\":\"2\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
+ backend_config="{\"algorithm\":{\"algo_id\":\"2\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
}
// -----
@@ -386,7 +386,7 @@
%input = f16[1,17,9,9]{1,3,2,0} parameter(0)
%filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
%bias = f16[32]{0} parameter(2)
- ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
+ ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
}
// -----
@@ -415,7 +415,7 @@
%filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
%bias = f16[32]{0} parameter(2)
%side = f16[32]{0} parameter(3)
- ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
+ ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
}
// -----
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 6cdbc26..c133ad9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -902,9 +902,14 @@
&custom_call->precision_config(), &builder_));
op.result_scaleAttr(
builder_.getF64FloatAttr(backend_config.conv_result_scale()));
+
+ const auto& algorithm = backend_config.algorithm();
+
auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
- builder_.getI64IntegerAttr(backend_config.algorithm()),
- builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
+ builder_.getI64IntegerAttr(algorithm.algo_id()),
+ builder_.getBoolAttr(
+ algorithm.math_type() ==
+ stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH),
get_layout_attribute(custom_call->operand(0)->shape().layout()),
get_layout_attribute(custom_call->operand(1)->shape().layout()),
get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 0d49461..811b107 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -110,7 +110,10 @@
name = "backend_configs",
srcs = ["backend_configs.proto"],
cc_api_version = 2,
- protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
+ protodeps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/stream_executor:dnn_proto",
+ ],
)
cc_library(
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 5840faa..9f17803 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -3,6 +3,7 @@
package xla.gpu;
import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/stream_executor/dnn.proto";
// Backend configs for XLA:GPU.
//
@@ -19,13 +20,10 @@
// Backend config for a convolution that runs through cudnn.
message CudnnConvBackendConfig {
- // Opaque algorithm number of cudnn algorithm chosen for this conv.
- int64 algorithm = 1;
+ reserved 1, 2;
- // Whether we may use tensor cores when running this conv. Even if this is
- // true, cudnn may choose not to use tensor cores, e.g. because the GPU or
- // selected algorithm doesn't support it.
- bool tensor_ops_enabled = 2;
+ // Opaque algorithm number and tuning knobs chosen for this conv.
+ stream_executor.dnn.AlgorithmProto algorithm = 6;
// The scaling factor multiplied with the convolution result.
double conv_result_scale = 4;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index b758821..8695395 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -414,8 +414,7 @@
profile_results.emplace_back();
AutotuneResult& result = profile_results.back();
- result.mutable_conv()->set_algorithm(alg.algo_id());
- result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
+ *result.mutable_algorithm() = alg.ToProto();
if (absl::c_linear_search(disabled_algos, alg)) {
LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
@@ -650,9 +649,7 @@
auto profile_result = algorithms[0];
profile_results.emplace_back();
auto& result = profile_results.back();
- result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
- result.mutable_conv()->set_tensor_ops_enabled(
- profile_result.algorithm().tensor_ops_enabled());
+ *result.mutable_algorithm() = profile_result.algorithm().ToProto();
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
@@ -750,8 +747,7 @@
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
instr->backend_config<CudnnConvBackendConfig>());
- backend_config.set_algorithm(best_algo.conv().algorithm());
- backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled());
+ *backend_config.mutable_algorithm() = best_algo.algorithm();
HloInstruction* new_call = computation->AddInstruction(
instr->CloneWithNewOperands(new_call_shape, instr->operands()));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
index f9d08ae..a7393d2 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
@@ -304,13 +304,12 @@
config.output_type = result_shape.element_type();
config.kind = desc.kind;
- // The third field is scratch size stored from conv_algorithm_picker
+ // The second field is the scratch size, stored by conv_algorithm_picker.
// The operand is added to the shape field of the conv instruction
// in GpuConvAlgorithmPicker::RunOnInstruction() call.
config.algorithm = se::dnn::AlgorithmConfig(
- se::dnn::AlgorithmDesc(backend_config.algorithm(),
- backend_config.tensor_ops_enabled()),
- desc.scratch_size);
+ se::dnn::AlgorithmDesc(backend_config.algorithm()), desc.scratch_size);
+
config.conv_result_scale = backend_config.conv_result_scale();
switch (config.kind) {
@@ -349,10 +348,7 @@
const Window& window = desc.window;
const ConvolutionDimensionNumbers& dnums = desc.dnums;
- VLOG(3) << "Convolution Algorithm: "
- << config.algorithm.algorithm()->algo_id();
- VLOG(3) << "tensor_ops_enabled: "
- << config.algorithm.algorithm()->tensor_ops_enabled();
+ VLOG(3) << "Convolution Algorithm: " << config.algorithm.ToString();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
VLOG(3) << "input shape: "
<< ShapeUtil::HumanStringWithLayout(config.input_shape);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 107efaa..158a0e4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1032,10 +1032,14 @@
dim->set_window_reversal(window_reversal.getValue<bool>(index));
}
descriptor.feature_group_count = op.feature_group_count();
- descriptor.backend_config.set_algorithm(
- op.backend_config().algorithm().getInt());
- descriptor.backend_config.set_tensor_ops_enabled(
- op.backend_config().tensor_ops_enabled().getValue());
+ {
+ auto* algorithm = descriptor.backend_config.mutable_algorithm();
+ algorithm->set_algo_id(op.backend_config().algorithm().getInt());
+ algorithm->set_math_type(
+ op.backend_config().tensor_ops_enabled().getValue()
+ ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
+ : se::dnn::AlgorithmProto::DEFAULT_MATH);
+ }
descriptor.backend_config.set_conv_result_scale(
op.result_scale().convertToDouble());
};
@@ -4935,7 +4939,6 @@
se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
-
int smallest_input_dtype_bits = std::numeric_limits<int>::max();
for (mlir::Value operand : fusion.getInputBuffers()) {
smallest_input_dtype_bits =