"""Repository rule for TensorRT configuration.
`tensorrt_configure` depends on the following environment variables:
* `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
* `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
"""

load(
    "//third_party/gpus:cuda_configure.bzl",
    "find_cuda_config",
    "get_cpu_value",
    "lib_name",
    "make_copy_files_rule",
)

_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_TF_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"

def _at_least_version(actual_version, required_version):
    """Returns whether actual_version is at least required_version."""
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]
    return actual >= required
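
# For example, _at_least_version("7.1.3", "6") is True: Starlark compares
# the integer lists lexicographically, i.e. [7, 1, 3] >= [6].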

def _get_tensorrt_headers(tensorrt_version):
    """Returns the list of TensorRT headers to copy for the given version."""
    if _at_least_version(tensorrt_version, "6"):
        return _TF_TENSORRT_HEADERS_V6
    return _TF_TENSORRT_HEADERS
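
# TensorRT 6 split the public API across more headers (NvInferVersion.h and
# the NvInferRuntime*/plugin-utility headers), hence the per-version sets.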

def _tpl(repository_ctx, tpl, substitutions):
    """Instantiates a template from //third_party/tensorrt into this repo."""
    repository_ctx.template(
        tpl,
        Label("//third_party/tensorrt:%s.tpl" % tpl),
        substitutions,
    )
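
# For example, _tpl(repository_ctx, "BUILD", {"%{copy_rules}": "..."})
# expands //third_party/tensorrt:BUILD.tpl into this repository's BUILD
# file, replacing each %{...} placeholder with its substitution.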

def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository."""
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
    })
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": "",
    })

def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support."""
    return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
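
# TF_NEED_TENSORRT is expected to hold an integer string such as "0" or
# "1"; when unset, the False default makes int() return 0 (disabled).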

def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule."""
    if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
        # Forward to the pre-configured remote repository.
        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
        repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
        repository_ctx.template(
            "build_defs.bzl",
            Label(remote_config_repo + ":build_defs.bzl"),
            {},
        )
        repository_ctx.template(
            "tensorrt/include/tensorrt_config.h",
            Label(remote_config_repo + ":tensorrt/include/tensorrt_config.h"),
            {},
        )
        repository_ctx.template(
            "LICENSE",
            Label(remote_config_repo + ":LICENSE"),
            {},
        )
        return

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    config = find_cuda_config(repository_ctx, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
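    # lib_name derives the platform-specific file name for each library; on
    # Linux this yields names like "libnvinfer.so.7" (illustrative; the exact
    # suffix depends on the version string find_cuda_config reports).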
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    # Set up config file.
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})

    # Set up BUILD file.
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "\n".join(copy_rules),
    })

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
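    # The substitution below embeds trt_version so the DSO loader can pick a
    # libnvinfer build matching the configured version at runtime.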
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": trt_version,
    })

tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = [
        _TENSORRT_INSTALL_PATH,
        _TF_TENSORRT_VERSION,
        _TF_TENSORRT_CONFIG_REPO,
        _TF_NEED_TENSORRT,
        "TF_CUDA_PATHS",
    ],
)
"""Detects and configures the local CUDA toolchain.
Add the following to your WORKSPACE FILE:
```python
tensorrt_configure(name = "local_config_tensorrt")
```
Args:
name: A unique name for this workspace rule.
"""