Improve the genrules for cubin kernel headers.
Add an additional build macro to build a library of headers.
Also change the output of bin2c to use type char instead of type int.
PiperOrigin-RevId: 314321754
Change-Id: I81ee3c7c962e807d28bb0f580eea8032f2a390ee
diff --git a/tensorflow/core/kernels/cubin_headers/BUILD b/tensorflow/core/kernels/cubin_headers/BUILD
index 509ac00..1d9b985 100644
--- a/tensorflow/core/kernels/cubin_headers/BUILD
+++ b/tensorflow/core/kernels/cubin_headers/BUILD
@@ -1,5 +1,10 @@
# Generates headers containing cubin for CUDA kernels.
-load("//tensorflow/core/kernels/cubin_headers:build_defs.bzl", "gen_kernel_image_hdr")
+load("//tensorflow/core/kernels/cubin_headers:build_defs.bzl", "gen_kernel_library")
+
+package(
+ default_visibility = ["//tensorflow/core/kernels:__subpackages__"],
+ licenses = ["notice"], # Apache 2.0
+)
bias_add_kernel = """
func @bias_add(%arg0: tensor<?x?xf99>,
@@ -10,19 +15,17 @@
}
"""
-[
- gen_kernel_image_hdr(
- name = "bias_add_{type}_kernel".format(type = type),
- op = bias_add_kernel.replace("f99", type).replace("DT_TYPE", dtype),
- same_shape = "0,2",
- tile_size = "16x16",
- )
- for (type, dtype) in [
- ("f16", "DT_HALF"),
- ("f32", "DT_FLOAT"),
- ("f64", "DT_DOUBLE"),
- ]
-]
+gen_kernel_library(
+ name = "bias_add",
+ op = bias_add_kernel,
+ same_shape = "0,2",
+ tile_size = "16x16",
+ types = [
+ "f16",
+ "f32",
+ "f64",
+ ],
+)
relu_kernel = """
func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> {
@@ -32,19 +35,17 @@
}
"""
-[
- gen_kernel_image_hdr(
- name = "relu_{type}_kernel".format(type = type),
- op = relu_kernel.replace("f99", type).replace("DT_TYPE", dtype),
- same_shape = "0,1",
- tile_size = "256",
- )
- for (type, dtype) in [
- ("f16", "DT_HALF"),
- ("f32", "DT_FLOAT"),
- ("f64", "DT_DOUBLE"),
- ]
-]
+gen_kernel_library(
+ name = "relu",
+ op = relu_kernel,
+ same_shape = "0,1",
+ tile_size = "256",
+ types = [
+ "f16",
+ "f32",
+ "f64",
+ ],
+)
tanh_kernel = """
func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> {
@@ -54,14 +55,12 @@
}
"""
-[
- gen_kernel_image_hdr(
- name = "tanh_{type}_kernel".format(type = type),
- op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype),
- tile_size = "256",
- )
- for (type, dtype) in [
- ("f32", "DT_FLOAT"),
- ("f64", "DT_DOUBLE"),
- ]
-]
+gen_kernel_library(
+ name = "tanh",
+ op = tanh_kernel,
+ tile_size = "256",
+ types = [
+ "f32",
+ "f64",
+ ],
+)
diff --git a/tensorflow/core/kernels/cubin_headers/build_defs.bzl b/tensorflow/core/kernels/cubin_headers/build_defs.bzl
index f9dac50..bd19d7e 100644
--- a/tensorflow/core/kernels/cubin_headers/build_defs.bzl
+++ b/tensorflow/core/kernels/cubin_headers/build_defs.bzl
@@ -1,6 +1,6 @@
"""Generates cubin headers for TF dialect ops."""
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures")
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures", "if_cuda")
def _lookup_file(filegroup, path):
"""Extracts file at (relative) path in filegroup."""
@@ -61,12 +61,12 @@
outputs = [ctx.outputs.out],
inputs = [fatbin],
tools = [bin2c],
- command = "%s --static --const --type=int --name=%s %s 1> %s" %
+ command = "%s --static --const --type=char --name=%s %s 1> %s" %
(bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
mnemonic = "bin2c",
)
-_gen_kernel_image_hdr = rule(
+_gen_kernel_image_hdr_rule = rule(
implementation = _gen_kernel_image_hdr_impl,
output_to_genfiles = True,
attrs = {
@@ -87,10 +87,10 @@
},
)
-def gen_kernel_image_hdr(name, op, tile_size, tags = [], same_shape = None):
+def _gen_kernel_image_hdr(name, op, tile_size, tags = [], same_shape = None):
"""Generates a C header with fatbin data from a Tensorflow op."""
if cuda_gpu_architectures():
- _gen_kernel_image_hdr(
+ _gen_kernel_image_hdr_rule(
name = name,
op = op,
tile_size = tile_size,
@@ -100,3 +100,25 @@
gpu_archs = cuda_gpu_architectures(),
tags = tags,
)
+
+def gen_kernel_library(name, op, types, tile_size, tags = [], same_shape = None):
+ if cuda_gpu_architectures():
+ type_to_dtype = {
+ "f16": "DT_HALF",
+ "f32": "DT_FLOAT",
+ "f64": "DT_DOUBLE",
+ }
+ for type in types:
+ _gen_kernel_image_hdr(
+ name = "{name}_{type}_kernel".format(name = name, type = type),
+ op = op.replace("f99", type).replace("DT_TYPE", type_to_dtype[type]),
+ tile_size = tile_size,
+ tags = tags,
+ same_shape = same_shape,
+ )
+
+ native.cc_library(
+ name = name + "_kernels",
+ hdrs = if_cuda(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
+ tags = tags,
+ )