Remove the algorithm attribute from mlir_gemm_test, as requested in the PR feedback, and drop the ROCm-only #if guard in gemm_thunk.cc that existed to work around that test
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index edde24e..a719dc9 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -123,12 +123,6 @@
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
- // Ignore the "algorithm" field on the ROCm platform. This is because
- // autotuning for GEMM is not yet available on the ROCm platform
- // The "algorithm" field does not get populated in the "normal" flow
- // on the ROCm platform, but atleast one unittest directly populates it
- // and hence the need for this check
-#if !defined(TENSORFLOW_USE_ROCM)
if (algorithm) {
// Autotuning is disabled for batch_size != 1.
CHECK_EQ(1, batch_size);
@@ -143,7 +137,6 @@
/*leading dim of output=*/output_matrix.num_rows, computation_type,
*algorithm, output_profile_result);
}
-#endif // !defined(TENSORFLOW_USE_ROCM)
if (batch_size != 1) {
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc
index 2c8edcc..d3f4110 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc
@@ -31,7 +31,7 @@
%arg2: memref<2x2xf32> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes {
result_xla_shape = "(f32[4]) "
} {
- "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = 7 : i64, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
+ "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
})";