Unit test to exercise gemm thunk.
PiperOrigin-RevId: 368051681
Change-Id: I89e7c639e2418f2a7360e2c3543ca8b96d4e21ae
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 6cd46e5..e7bd107 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -83,6 +83,18 @@
)
tf_cc_test(
+ name = "mlir_gemm_test",
+ srcs = ["mlir_gemm_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":mlir_gpu_test_base",
+ "//tensorflow/compiler/jit:xla_gpu_jit",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
name = "mlir_sorting_test",
srcs = ["mlir_sorting_test.cc"],
tags = tf_cuda_tests_tags(),
diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc
new file mode 100644
index 0000000..1acaa66
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace gpu {
+
+using ::testing::ElementsAreArray;
+
+class GemmTest : public MlirGpuTestBase {};
+
+TEST_F(GemmTest, SimpleCase1) {
+ const char* mlir_text = R"(
+ module attributes {hlo.unique_id = 0 : i32} {
+ func @main(%arg0: memref<2x2xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 0 : index},
+ %arg1: memref<2x2xf32> {lmhlo.alloc = 2 : index, lmhlo.params = 1 : index},
+ %arg2: memref<2x2xf32> {lmhlo.alloc = 0 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes {
+ result_xla_shape = "(f32[4]) "
+ } {
+ "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = 7 : i64, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
+ "lmhlo.terminator"() : () -> ()
+ }
+ })";
+ std::vector<float> arg0 = {2, 3, 4, 5};
+ std::vector<float> arg1 = {1, 2, 3, 4};
+ auto outputs = RunMlirTextWithHostBuffers(
+ mlir_text, {ToUint8Span(&arg0), ToUint8Span(&arg1)})
+ .ConsumeValueOrDie();
+ ASSERT_EQ(1, outputs.size());
+ EXPECT_THAT(FromUint8Span<float>(outputs[0]),
+ ElementsAreArray<float>({11, 16, 19, 28}));
+}
+
+} // namespace gpu
+} // namespace xla