Use `AlgorithmProto` directly in `BackendConfig`.

This avoids the need to add duplicates of every new field of `AlgorithmProto`
here in the next few CLs.

PiperOrigin-RevId: 407402953
Change-Id: I07f930b51e36f455c791174e6a3464d720136faf
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index bc29def..867e85c 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -359,7 +359,7 @@
   ROOT %custom-call.1 = (f32[4,256,2,2]{3,2, 1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter),
                         window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01,
                         custom_call_target="__cudnn$convForward",
-                        backend_config="{\"algorithm\":\"2\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
+                        backend_config="{\"algorithm\": {\"algo_id\":\"2\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
 }
 
 // -----
@@ -386,7 +386,7 @@
   %input = f16[1,17,9,9]{1,3,2,0} parameter(0)
   %filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
   %bias = f16[32]{0} parameter(2)
-  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
+  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\": {\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
 }
 
 // -----
@@ -415,7 +415,7 @@
   %filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
   %bias = f16[32]{0} parameter(2)
   %side = f16[32]{0} parameter(3)
-  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
+  ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 6cdbc26..c133ad9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -902,9 +902,14 @@
         &custom_call->precision_config(), &builder_));
     op.result_scaleAttr(
         builder_.getF64FloatAttr(backend_config.conv_result_scale()));
+
+    const auto& algorithm = backend_config.algorithm();
+
     auto config = mlir::lmhlo_gpu::ConvolutionBackendConfig::get(
-        builder_.getI64IntegerAttr(backend_config.algorithm()),
-        builder_.getBoolAttr(backend_config.tensor_ops_enabled()),
+        builder_.getI64IntegerAttr(algorithm.algo_id()),
+        builder_.getBoolAttr(
+            algorithm.math_type() ==
+            stream_executor::dnn::AlgorithmProto::TENSOR_OP_MATH),
         get_layout_attribute(custom_call->operand(0)->shape().layout()),
         get_layout_attribute(custom_call->operand(1)->shape().layout()),
         get_layout_attribute(custom_call->shape().tuple_shapes(0).layout()),
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 0d49461..811b107 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -110,7 +110,10 @@
     name = "backend_configs",
     srcs = ["backend_configs.proto"],
     cc_api_version = 2,
-    protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
+    protodeps = [
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/stream_executor:dnn_proto",
+    ],
 )
 
 cc_library(
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 5840faa..9f17803 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -3,6 +3,7 @@
 package xla.gpu;
 
 import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/stream_executor/dnn.proto";
 
 // Backend configs for XLA:GPU.
 //
@@ -19,13 +20,10 @@
 
 // Backend config for a convolution that runs through cudnn.
 message CudnnConvBackendConfig {
-  // Opaque algorithm number of cudnn algorithm chosen for this conv.
-  int64 algorithm = 1;
+  reserved 1, 2;
 
-  // Whether we may use tensor cores when running this conv.  Even if this is
-  // true, cudnn may choose not to use tensor cores, e.g. because the GPU or
-  // selected algorithm doesn't support it.
-  bool tensor_ops_enabled = 2;
+  // Opaque algorithm number and tuning knobs chosen for this conv.
+  stream_executor.dnn.AlgorithmProto algorithm = 6;
 
   // The scaling factor multiplied with the convolution result.
   double conv_result_scale = 4;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index b758821..8695395 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -414,8 +414,7 @@
 
     profile_results.emplace_back();
     AutotuneResult& result = profile_results.back();
-    result.mutable_conv()->set_algorithm(alg.algo_id());
-    result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
+    *result.mutable_algorithm() = alg.ToProto();
 
     if (absl::c_linear_search(disabled_algos, alg)) {
       LOG(INFO) << "Omitted potentially buggy algorithm " << alg.ToString()
@@ -650,9 +649,7 @@
     auto profile_result = algorithms[0];
     profile_results.emplace_back();
     auto& result = profile_results.back();
-    result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
-    result.mutable_conv()->set_tensor_ops_enabled(
-        profile_result.algorithm().tensor_ops_enabled());
+    *result.mutable_algorithm() = profile_result.algorithm().ToProto();
 
     result.set_scratch_bytes(profile_result.scratch_size());
     *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
@@ -750,8 +747,7 @@
 
   TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                       instr->backend_config<CudnnConvBackendConfig>());
-  backend_config.set_algorithm(best_algo.conv().algorithm());
-  backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled());
+  *backend_config.mutable_algorithm() = best_algo.algorithm();
 
   HloInstruction* new_call = computation->AddInstruction(
       instr->CloneWithNewOperands(new_call_shape, instr->operands()));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
index f9d08ae..a7393d2 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
@@ -304,13 +304,12 @@
   config.output_type = result_shape.element_type();
   config.kind = desc.kind;
 
-  // The third field is scratch size stored from conv_algorithm_picker
+  // The second field is scratch size stored from conv_algorithm_picker
   // The operand is added to the shape field of the conv instruction
   // in GpuConvAlgorithmPicker::RunOnInstruction() call.
   config.algorithm = se::dnn::AlgorithmConfig(
-      se::dnn::AlgorithmDesc(backend_config.algorithm(),
-                             backend_config.tensor_ops_enabled()),
-      desc.scratch_size);
+      se::dnn::AlgorithmDesc(backend_config.algorithm()), desc.scratch_size);
+
   config.conv_result_scale = backend_config.conv_result_scale();
 
   switch (config.kind) {
@@ -349,10 +348,7 @@
   const Window& window = desc.window;
   const ConvolutionDimensionNumbers& dnums = desc.dnums;
 
-  VLOG(3) << "Convolution Algorithm: "
-          << config.algorithm.algorithm()->algo_id();
-  VLOG(3) << "tensor_ops_enabled: "
-          << config.algorithm.algorithm()->tensor_ops_enabled();
+  VLOG(3) << "Convolution Algorithm: " << config.algorithm.ToString();
   VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
   VLOG(3) << "input shape: "
           << ShapeUtil::HumanStringWithLayout(config.input_shape);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 107efaa..158a0e4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1032,10 +1032,14 @@
       dim->set_window_reversal(window_reversal.getValue<bool>(index));
     }
     descriptor.feature_group_count = op.feature_group_count();
-    descriptor.backend_config.set_algorithm(
-        op.backend_config().algorithm().getInt());
-    descriptor.backend_config.set_tensor_ops_enabled(
-        op.backend_config().tensor_ops_enabled().getValue());
+    {
+      auto* algorithm = descriptor.backend_config.mutable_algorithm();
+      algorithm->set_algo_id(op.backend_config().algorithm().getInt());
+      algorithm->set_math_type(
+          op.backend_config().tensor_ops_enabled().getValue()
+              ? se::dnn::AlgorithmProto::TENSOR_OP_MATH
+              : se::dnn::AlgorithmProto::DEFAULT_MATH);
+    }
     descriptor.backend_config.set_conv_result_scale(
         op.result_scale().convertToDouble());
   };
@@ -4935,7 +4939,6 @@
 
   se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
 
-
   int smallest_input_dtype_bits = std::numeric_limits<int>::max();
   for (mlir::Value operand : fusion.getInputBuffers()) {
     smallest_input_dtype_bits =