blob: 328262559eeed1404c852d7b3b915b3eb555f717 [file] [log] [blame]
"""Setup TensorFlow as external dependency"""
# Names of the environment variables that _tf_pip_impl reads (and that
# tf_configure declares in its `environ` list) to locate TensorFlow.

# Directory holding the TensorFlow header files.
_TF_HEADER_DIR = "TF_HEADER_DIR"

# Directory holding the TensorFlow shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"

# File name of the shared library inside the TF_SHARED_LIBRARY_DIR directory.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands a template from //third_party/tensorflow into the repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the file read is `<tpl>.tpl` in the
        //third_party/tensorflow package.
      substitutions: dict mapping placeholder strings to replacement text.
      out: output path inside the repository; `tpl` is used when falsy.
    """
    template_label = Label("//third_party/tensorflow:%s.tpl" % tpl)
    destination = out if out else tpl
    repository_ctx.template(destination, template_label, substitutions)
def _fail(msg):
    """Aborts the repository rule with a highlighted configuration error.

    Args:
      msg: string, description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"

    # This rule configures TensorFlow, so label the error accordingly; the
    # former "Python Configuration Error" prefix pointed users at the wrong
    # component.
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
def _is_windows(repository_ctx):
    """Reports whether the host operating system is Windows.

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True if the lowercased host OS name contains "windows", else False.
    """
    return "windows" in repository_ctx.os.name.lower()
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes a command and fails the repository rule if it did not succeed.

    The command is given as a list of arguments and run through
    repository_ctx.execute. It is treated as failed when it exits with a
    non-zero return code, writes anything to stderr, or (unless
    empty_stdout_fine is True) produces no stdout.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)

    # Check the exit status too: a command can fail without writing to stderr.
    failed = (
        result.return_code != 0 or
        result.stderr or
        not (empty_stdout_fine or result.stdout)
    )
    if failed:
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            "Command: %s (exit code %d)" % (
                " ".join([str(arg) for arg in cmdline]),
                result.return_code,
            ),
            result.stderr.strip(),
            error_details if error_details else "",
        ]))
    return result
def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory as newline-separated paths.

    Subdirectories are traversed and symlinks are followed. On Windows the
    listing comes from `dir /b /s /a-d` and backslashes are turned into
    forward slashes; elsewhere it comes from `find -follow -type f`.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory whose files are listed.

    Returns:
      A string holding one file path per line (possibly empty).
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The paths end up in genrule `outs`, which require forward slashes.
    return listing.stdout.replace("\\", "/")
def _genrule(genrule_name, command, outs):
    """Renders the BUILD-file text of a single genrule target.

    Args:
      genrule_name: value of the genrule's `name` attribute; must be unique.
      command: text placed verbatim inside the `cmd` attribute.
      outs: pre-formatted, newline-separated entries for the `outs` list.

    Returns:
      The genrule declaration as a string, ending with a newline.
    """
    pieces = [
        "genrule(\n",
        ' name = "',
        genrule_name,
        '",\n',
        " outs = [\n",
        outs,
        "\n ],\n",
        ' cmd = """\n',
        command,
        '\n """,\n',
        ")\n",
    ]
    return "".join(pieces)
def _norm_path(path):
    """Returns `path` using '/' separators, without one trailing '/'.

    Backslashes are converted to '/', then a single trailing '/' (if any) is
    dropped. An empty path is returned unchanged instead of triggering an
    out-of-range index error.

    Args:
      path: string path to normalize.

    Returns:
      The normalized path string.
    """
    path = path.replace("\\", "/")
    if path.endswith("/"):
        path = path[:-1]
    return path
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the build tree.

    Despite the name, the generated genrule copies each file with `cp -f`
    rather than symlinking it, so the result works inside the sandbox.

    If src_dir is not None, every file below src_dir is copied, keeping its
    path relative to src_dir under dest_dir. Otherwise the files are taken
    pairwise from src_files and dest_files.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory, or None to use src_files/dest_files.
      dest_dir: directory (relative to the genrule output dir) to copy into.
      genrule_name: genrule name.
      src_files: list of source file paths, used when src_dir is None.
      dest_files: list of destination paths corresponding to src_files.
      tf_pip_dir_rename_pair: either empty or an [old, new] pair. Every
        occurrence of `old` in the destination paths is replaced by `new`.
        For example, in the TF pip package the source code is under
        "tensorflow_core", which we may want to replace with "tensorflow"
        to match the header includes.

    Returns:
      genrule target (as BUILD-file text) that creates the copies.
    """

    # Check that tf_pip_dir_rename_pair has the right length.
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len != 0 and rename_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))

        # Create a list with the src_dir stripped to use for outputs.
        stripped = files.replace(src_dir, "")
        if rename_len:
            stripped = stripped.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
        dest_files = stripped.splitlines()
        src_files = files.splitlines()
    elif len(src_files) != len(dest_files):
        _fail("src_files and dest_files must have the same length, but got %d and %d." % (len(src_files), len(dest_files)))

    # If we have only one file to copy we do not want to use dest_dir, as
    # $(@D) will include the full path to the file.
    single_file = len(dest_files) == 1

    command = []
    outs = []
    for src, dst in zip(src_files, dest_files):
        if dst == "":
            continue
        if single_file:
            dest = "$(@D)/" + dst
        else:
            dest = "$(@D)/" + dest_dir + dst

        # Copy the headers to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src, dest))

        # dest_dir is used as-is for every file; it must not be reassigned
        # inside this loop or later outputs would land in the wrong place.
        outs.append(' "' + dest_dir + dst + '",')

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
def _get_required_env(repository_ctx, name):
    """Returns environment variable `name`, failing clearly if unset or empty."""
    value = repository_ctx.os.environ.get(name)
    if not value:
        _fail("Environment variable %s must be set to configure TensorFlow." % name)
    return value

def _tf_pip_impl(repository_ctx):
    """Implementation of the tf_configure repository rule.

    Builds one genrule that copies the TensorFlow headers and one that copies
    the TensorFlow shared library, both located via the TF_* environment
    variables, and substitutes them into the BUILD template.

    Args:
      repository_ctx: the repository_ctx object.
    """
    tf_header_dir = _get_required_env(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })
# Repository rule that exposes a TensorFlow installation, located through the
# environment variables below, as an external Bazel repository. Listing them in
# `environ` makes Bazel refetch the repository when any of them changes.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
)