"""Sets up TensorFlow (headers and shared library) as an external dependency.

The repository rule defined here is configured entirely through the
environment variables named by the constants below.
"""

# Name of the env var holding the directory whose files are copied into the
# generated repository's include/ directory.
_TF_HEADER_DIR = "TF_HEADER_DIR"
# Name of the env var holding the directory that contains the TensorFlow
# shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
# Name of the env var holding the file name of the shared library inside
# the TF_SHARED_LIBRARY_DIR directory.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
| |
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands a template from //third_party/tensorflow into this repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the label read is
        //third_party/tensorflow:<tpl>.tpl.
      substitutions: dict mapping placeholder strings to replacement text.
      out: path of the generated file; falls back to `tpl` when falsy.
    """
    template_label = Label("//third_party/tensorflow:%s.tpl" % tpl)
    destination = out if out else tpl
    repository_ctx.template(destination, template_label, substitutions)
| |
def _fail(msg):
    """Aborts the repository rule with a highlighted configuration error.

    The banner previously read "Python Configuration Error" (a leftover from
    the Python configure rule); it now names TensorFlow, which is what this
    file configures.

    Args:
      msg: human-readable description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
| |
def _is_windows(repository_ctx):
    """Reports whether the host OS name contains "windows" (case-insensitive).

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True on a Windows host, False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()
| |
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command and aborts the rule if it looks like it failed.

    A run counts as failed when it wrote anything to stderr, or when it
    produced no stdout and `empty_stdout_fine` is False.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True an empty stdout is acceptable,
        otherwise it is treated as an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)
    stdout_ok = empty_stdout_fine or result.stdout
    if not result.stderr and stdout_ok:
        return result

    summary = error_msg.strip() if error_msg else "Repository command failed"
    details = error_details if error_details else ""
    _fail("\n".join([summary, result.stderr.strip(), details]))
    return result
| |
def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory, following symlinks.

    Subdirectories are traversed recursively. On Windows, backslashes in the
    listing are converted to forward slashes.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory to list.

    Returns:
      A string with one full file path per line.
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The paths end up in genrule `outs`, which must use forward slashes.
    return listing.stdout.replace("\\", "/")
| |
def _genrule(genrule_name, command, outs):
    """Renders the BUILD-file text of a genrule.

    Args:
      genrule_name: unique name for the genrule target.
      command: shell command inserted verbatim into the rule's `cmd`.
      outs: pre-formatted STRING forming the body of the `outs` list
        (typically one quoted, comma-terminated path per line); this is
        concatenated as text, so it must not be a list.

    Returns:
      A string containing the genrule declaration.
    """
    pieces = [
        "genrule(\n",
        ' name = "',
        genrule_name,
        '",\n',
        " outs = [\n",
        outs,
        "\n ],\n",
        ' cmd = """\n',
        command,
        '\n """,\n',
        ")\n",
    ]
    return "".join(pieces)
| |
def _norm_path(path):
    """Returns `path` using '/' separators and without a trailing slash.

    Backslashes are converted to '/', then one trailing '/' is removed.
    An empty path is returned unchanged; the previous `path[-1]` check
    failed on it.

    Args:
      path: a filesystem path string.

    Returns:
      The normalized path string.
    """
    path = path.replace("\\", "/")
    if path.endswith("/"):
        path = path[:-1]
    return path
| |
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the repository.

    Despite the name, the generated command always copies with `cp -f`, on
    every platform, so the outputs work in a sandboxed build.

    If src_dir is not None, the files are discovered by listing src_dir
    recursively, and each destination is the file's path relative to
    src_dir. Otherwise src_files and dest_files are used as given.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory to list, or None to use src_files/dest_files.
      dest_dir: directory, under the genrule output dir, to copy into.
      genrule_name: genrule name.
      src_files: list of source files; used when src_dir is None.
      dest_files: list of corresponding destination files; used when src_dir
        is None and must have the same length as src_files.
      tf_pip_dir_rename_pair: empty, or an [old, new] pair of strings that is
        replaced in every derived destination path. For example, the TF pip
        package keeps its sources under "tensorflow_core", which may need to
        become "tensorflow" to match the header includes.

    Returns:
      A string containing the genrule declaration.
    """

    # Check that tf_pip_dir_rename_pair has the right length.
    rename_pair_len = len(tf_pip_dir_rename_pair)
    if rename_pair_len != 0 and rename_pair_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_pair_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        src_files = sorted(_read_dir(repository_ctx, src_dir).splitlines())
        dest_files = []
        for src_file in src_files:
            # Strip only the leading src_dir. Using replace() on the whole
            # listing would also drop src_dir where it recurs deeper in a path.
            if src_file.startswith(src_dir):
                rel_path = src_file[len(src_dir):]
            else:
                rel_path = src_file
            if rename_pair_len:
                rel_path = rel_path.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
            dest_files.append(rel_path)
    elif len(src_files) != len(dest_files):
        _fail("src_files and dest_files must have the same length, but %d and %d are given." % (len(src_files), len(dest_files)))

    # If we have only one file to link we do not want to use the dest_dir, as
    # $(@D) will include the full path to the file.
    single_file = len(dest_files) == 1

    command = []
    outs = []
    for src_file, dest_file in zip(src_files, dest_files):
        if dest_file == "":
            continue
        if single_file:
            dest = "$(@D)/" + dest_file
        else:
            dest = "$(@D)/" + dest_dir + dest_file

        # Copy the files (rather than symlink) to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src_file, dest))

        # dest_dir is deliberately NOT modified here: the previous in-loop
        # `dest_dir = "abc"` sent every file after the first to a bogus dir.
        outs.append(' "' + dest_dir + dest_file + '",')

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
| |
def _get_required_env(repository_ctx, name):
    """Returns environment variable `name`, failing clearly if unset or empty.

    Args:
      repository_ctx: the repository_ctx object.
      name: name of the environment variable to read.

    Returns:
      The variable's string value.
    """
    value = repository_ctx.os.environ.get(name)
    if not value:
        _fail("Environment variable %s must be set to configure TensorFlow." % name)
    return value

def _tf_pip_impl(repository_ctx):
    """Implementation of the tf_configure repository rule.

    Emits the repository BUILD file from the BUILD template, filling in two
    genrules:
      * %{TF_HEADER_GENRULE}: copies every file under TF_HEADER_DIR into
        include/, renaming "tensorflow_core" to "tensorflow" in the paths.
      * %{TF_SHARED_LIBRARY_GENRULE}: copies
        TF_SHARED_LIBRARY_DIR/TF_SHARED_LIBRARY_NAME to
        _pywrap_tensorflow_internal.lib on Windows, or to
        libtensorflow_framework.so elsewhere.

    Missing variables now fail with a message naming the variable, instead
    of an opaque dictionary-lookup error.

    Args:
      repository_ctx: the repository_ctx object.
    """
    tf_header_dir = _get_required_env(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })
| |
# Public entry point; re-runs whenever one of the listed env vars changes.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
    doc = "Creates a repository exposing TensorFlow headers (copied from " +
          "$TF_HEADER_DIR into include/) and the TensorFlow shared library " +
          "($TF_SHARED_LIBRARY_DIR/$TF_SHARED_LIBRARY_NAME). All three " +
          "environment variables must be set.",
)