Make some kernel generator tests externally visible.

PiperOrigin-RevId: 333279975
Change-Id: Ia3dde126bf932bb7f5ee0be27cf41fe5d77d556a
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index 4e0eedc..590a2b5 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -163,7 +163,7 @@
 tf_cc_binary(
     name = "kernel-gen-opt",
     srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"],
-    visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"],
+    visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__subpackages__"],
     deps = [
         "//tensorflow/compiler/mlir/hlo:all_passes",
         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD
new file mode 100644
index 0000000..8c7ee72
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD
@@ -0,0 +1,28 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+    data = [
+        ":test_utilities",
+        "@llvm-project//mlir:run_lit.sh",
+    ],
+    default_tags = [
+        # MSAN does not work with JIT: b/139082472
+        "nomsan",
+    ],
+    driver = "//tensorflow/compiler/mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir:tf-opt",
+        "//tensorflow/compiler/mlir/hlo:mlir-hlo-opt",
+        "//tensorflow/compiler/mlir/tools/kernel_gen:kernel-gen-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
new file mode 100644
index 0000000..762962f
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
@@ -0,0 +1,62 @@
+// RUN: kernel-gen-opt %s --bufferize | FileCheck %s
+
+// CHECK-LABEL: @extract_element
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
+func @extract_element(%arg : tensor<?xf32>) -> f32 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RESULT:.*]] = load %[[ARG]][%[[C0]]]
+  // CHECK: return %[[RESULT]]
+  %c0 = constant 0 : index
+  %result = extract_element %arg[%c0] : tensor<?xf32>
+  return %result : f32
+}
+
+// CHECK-LABEL: @tensor_load
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> memref<?xf32>
+func @tensor_load(%arg : memref<?xf32>) -> tensor<?xf32> {
+  // CHECK: return %[[ARG]] : memref<?xf32>
+  %result = tensor_load %arg : memref<?xf32>
+  return %result : tensor<?xf32>
+}
+
+// CHECK-LABEL: @tensor_from_elements
+// CHECK-SAME: (%[[A:.*]]: f32) -> memref<3xf32>
+func @tensor_from_elements(%a : f32) -> tensor<3xf32> {
+  // CHECK: %[[B:.*]] = constant 1.2
+  // CHECK: %[[C:.*]] = constant 2.3
+  // CHECK: %[[MEM:.*]] = alloca() : memref<3xf32>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: store %[[A]], %[[MEM]][%[[C0]]] : memref<3xf32>
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: store %[[B]], %[[MEM]][%[[C1]]] : memref<3xf32>
+  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK: store %[[C]], %[[MEM]][%[[C2]]] : memref<3xf32>
+  // CHECK: return %[[MEM]] : memref<3xf32>
+  %b = constant 1.2 : f32
+  %c = constant 2.3 : f32
+  %result = tensor_from_elements %a, %b, %c : tensor<3xf32>
+  return %result : tensor<3xf32>
+}
+
+// CHECK-LABEL: @dynamic_tensor_from_elements
+// CHECK-SAME: (%[[ARG:.*]]: memref<*xf32>) -> memref<?xindex>
+func @dynamic_tensor_from_elements(%arg : tensor<*xf32>) -> tensor<?xindex> {
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  // CHECK: %[[MEM:.*]] = alloca(%[[C3]]) : memref<?xindex>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C3]]) step (%[[C1]]) {
+  // CHECK:   %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : memref<*xf32>
+  // CHECK:   store %[[ELEM]], %[[MEM]][%[[I]]] : memref<?xindex>
+  // CHECK:   scf.yield
+  // CHECK: }
+  // CHECK: return %[[MEM]] : memref<?xindex>
+  %c3 = constant 3 : index
+  %result = dynamic_tensor_from_elements %c3 {
+  ^bb0(%i : index):
+    %elem = dim %arg, %i : tensor<*xf32>
+    yield %elem : index
+  } : tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir
new file mode 100644
index 0000000..bb0f192
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir
@@ -0,0 +1,37 @@
+// RUN: kernel-gen-opt %s -embed-tf-framework -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @tf_entry(
+// CHECK-SAME:    [[CTX:%.*]]: !tf_framework.op_kernel_context,
+// CHECK-SAME:    [[SIZE_0:%.*]]: index,
+// CHECK-SAME:    [[SIZE_2:%.*]]: index) -> index attributes {tf_entry} {
+func @tf_entry(%size_0 : index , %size_2 : index) -> index
+    attributes {tf_entry} {
+  %buf = alloc(%size_0, %size_2)[] : memref<?x10x?xf32>
+  dealloc %buf : memref<?x10x?xf32>
+  std.return %size_0 : index
+}
+// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw
+// CHECK-SAME:   ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
+// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
+// CHECK-NEXT: return [[SIZE_0]] : index
+
+// -----
+
+// CHECK-LABEL: func @non_tf_entry(
+// CHECK-SAME:    [[SIZE_0:%.*]]: index, [[SIZE_2:%.*]]: index) -> index
+func @non_tf_entry(%size_0 : index , %size_2 : index) -> index {
+  std.return %size_0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tf_entry(
+func @tf_entry(%size : index) attributes {tf_entry} {
+  %buf = alloc()[%size] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  dealloc %buf : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  std.return
+}
+// CHECK-NOT: alloc_raw
+// CHECK: alloc()
+// CHECK-NOT: dealloc_raw
+// CHECK: dealloc %
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
new file mode 100644
index 0000000..1d1b331
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
@@ -0,0 +1,7 @@
+// RUN: kernel-gen-opt %s -split-input-file -verify-diagnostics
+
+func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
+  // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
+  %buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8>
+  return
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
new file mode 100644
index 0000000..fc8e7c9
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
@@ -0,0 +1,25 @@
+// RUN: kernel-gen-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: kernel-gen-opt %s | kernel-gen-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
+
+// CHECK-LABEL: func @alloc_raw
+func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
+                   %size_0 : index , %size_2 : index) {
+  %buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8>
+  %buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
+  return
+}
+
+// CHECK-LABEL: func @dealloc_raw
+func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) {
+  tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
+  return
+}
+
+// CHECK-LABEL: func @null_context
+func @null_context() {
+  tf_framework.null_context() : !tf_framework.op_kernel_context
+  return
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir
new file mode 100644
index 0000000..df05975
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/parallel_loops_to_sequential.mlir
@@ -0,0 +1,17 @@
+// RUN: kernel-gen-opt %s --parallel-loops-to-sequential | FileCheck %s
+
+// CHECK-LABEL: @parallel_loop
+func @parallel_loop(%lb_0 : index, %lb_1 : index,
+                     %ub_0 : index, %ub_1 : index,
+                     %s_0 : index, %s_1 : index,
+                     %buf: memref<?x?xindex>) {
+  scf.parallel (%i0, %i1) = (%lb_0, %lb_1) to (%ub_0, %ub_1) step (%s_0, %s_1) {
+    %sum_elem = addi %i0, %i1 : index
+    store %sum_elem, %buf[%i0, %i1] : memref<?x?xindex>
+  }
+  return
+}
+// CHECK: scf.for [[I_0:%.*]] = [[LB_0:%.*]] to [[UB_0:%.*]] step [[S_0:%.*]]
+// CHECK:   scf.for [[I_1:%.*]] = [[LB_1:%.*]] to [[UB_1:%.*]] step [[S_1:%.*]]
+// CHECK:     [[SUM:%.*]] = addi [[I_0]], [[I_1]] : index
+// CHECK:     store [[SUM]], {{%.*}}{{\[}}[[I_0]], [[I_1]]] : memref<?x?xindex>
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
new file mode 100644
index 0000000..53d0232
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir
@@ -0,0 +1,20 @@
+// RUN: tf-opt %s --xla-legalize-tf | mlir-hlo-opt --transform-unranked-hlo | kernel-gen-opt -allow-unregistered-dialect --shape-to-descriptors --canonicalize --bufferize | FileCheck %s
+
+// Test whether all shape computations required for tanh can be lowered to
+// the standard dialect, scf and descriptors. We check for a sparse pattern here,
+// as each lowering pattern is already tested and we just care for the
+// integration.
+// TODO: Expand this pattern once things have stabilized.
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: alloca
+  // CHECK: scf.parallel
+  // CHECK-NOT: tensor_load
+  // CHECK: scf.for
+  // CHECK-NOT: tensor_from_elements
+  // CHECK: mhlo.reshape_memref_cast
+  // CHECK: lmhlo.tanh
+  // CHECK: mhlo.reshape_memref_cast
+  %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir
new file mode 100644
index 0000000..64c7ff2
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir
@@ -0,0 +1,26 @@
+// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --mhlo-test-chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize
+
+func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+func @tan(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Tan"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+func @sin(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Sin"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+func @sinh(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Sinh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
new file mode 100644
index 0000000..0b2834a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
@@ -0,0 +1,75 @@
+// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s
+
+// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw
+// CHECK-SAME:  (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
+
+// CHECK-LABEL: llvm.func @alloc_raw(
+// CHECK-SAME:    [[TF_CTX:%.*]]: !llvm.ptr<i8>,
+// CHECK-SAME:    [[SIZE_0:%.*]]: !llvm.i64,
+// CHECK-SAME:    [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
+func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
+                %size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
+  %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
+  std.return %buf : memref<?x10x?xf32>
+}
+// Compute number of elements.
+// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
+// CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1]] : !llvm.i64
+// CHECK: [[NUM_ELEM_1:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : !llvm.i64
+
+// Compute the size of an individual element.
+// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<float>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}[[C1]]]
+// CHECK-SAME:            (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
+// CHECK-SAME:            !llvm.ptr<float> to !llvm.i64
+
+// Allocate memory.
+// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
+// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]])
+// CHECK-SAME:                  (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
+
+// Build memref descriptor.
+// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
+
+// Set pointers and offset.
+// CHECK: [[FLOAT_PTR:%.*]] = llvm.bitcast [[BYTES_PTR]]
+// CHECK-SAME:                  !llvm.ptr<i8> to !llvm.ptr<float>
+// CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_0]][0]
+// CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_1]][1]
+// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[DESC_3:%.*]] = llvm.insertvalue [[C0]], [[DESC_2]][2] : [[DESC_TY]]
+
+// Set sizes and strides.
+// CHECK: [[STRIDE_2:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[DESC_4:%.*]] = llvm.insertvalue [[SIZE_2]], [[DESC_3]][3, 2]
+// CHECK: [[DESC_5:%.*]] = llvm.insertvalue [[STRIDE_2]], [[DESC_4]][4, 2]
+// CHECK: [[STRIDE_1:%.*]] = llvm.mul [[STRIDE_2]], [[SIZE_2]] : !llvm.i64
+// CHECK: [[DESC_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[DESC_5]][3, 1]
+// CHECK: [[DESC_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[DESC_6]][4, 1]
+// CHECK: [[STRIDE_0:%.*]] = llvm.mul [[STRIDE_1]], [[SIZE_1]] : !llvm.i64
+// CHECK: [[DESC_8:%.*]] = llvm.insertvalue [[SIZE_0]], [[DESC_7]][3, 0]
+// CHECK: [[DESC_9:%.*]] = llvm.insertvalue [[STRIDE_0]], [[DESC_8]][4, 0]
+// CHECK: llvm.return [[DESC_9]] : [[DESC_TY]]
+
+// -----
+
+// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>)
+
+// CHECK-LABEL: llvm.func @dealloc_raw(
+// CHECK-SAME:    [[TF_CTX:%.*]]: !llvm.ptr<i8>,
+func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
+                  %memref : memref<?x10xf32>) {
+  tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
+  return
+}
+// Extract allocated ptr from the memref descriptor.
+// CHECK: %{{.*}} = llvm.mlir.undef : [[DESC_TY:!.*]]
+// CHECK: [[FLOAT_PTR:%.*]] = llvm.extractvalue %{{.*}}[0] : [[DESC_TY]]
+// CHECK-NEXT: [[VOID_PTR:%.*]] = llvm.bitcast [[FLOAT_PTR]]
+// CHECK-SAME:                   !llvm.ptr<float> to !llvm.ptr<i8>
+
+// Deallocate.
+// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw(
+// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()