| #!/usr/bin/env python3 | 
 |  | 
 | import argparse | 
 | import asyncio | 
 | import collections | 
 | import csv | 
 | import hashlib | 
 | import itertools | 
 | import os | 
 | import pathlib | 
 | import re | 
 | import shlex | 
 | import shutil | 
 | import subprocess | 
 | import sys | 
 | import time | 
 | from typing import Awaitable, cast, DefaultDict, Dict, List, Match, Optional, Set | 
 |  | 
 | from typing_extensions import TypedDict | 
 |  | 
 | help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]... | 
 |  | 
 | Run the commands given by nvcc --dryrun, in parallel. | 
 |  | 
 | All flags for this script itself (see the "optional arguments" section | 
 | of --help) must be passed before the first "--". Everything after that | 
 | first "--" is passed directly to nvcc, with the --dryrun argument added. | 
 |  | 
 | This script only works with the "normal" execution path of nvcc, so for | 
 | instance passing --help (after "--") doesn't work since the --help | 
 | execution path doesn't compile anything, so adding --dryrun there gives | 
 | nothing in stderr. | 
 | """ | 
 | parser = argparse.ArgumentParser(help_msg) | 
 | parser.add_argument( | 
 |     "--faithful", | 
 |     action="store_true", | 
 |     help="don't modify the commands given by nvcc (slower)", | 
 | ) | 
 | parser.add_argument( | 
 |     "--graph", | 
 |     metavar="FILE.gv", | 
 |     help="write Graphviz DOT file with execution graph", | 
 | ) | 
 | parser.add_argument( | 
 |     "--nvcc", | 
 |     metavar="PATH", | 
 |     default="nvcc", | 
 |     help='path to nvcc (default is just "nvcc")', | 
 | ) | 
 | parser.add_argument( | 
 |     "--save", | 
 |     metavar="DIR", | 
 |     help="copy intermediate files from each command into DIR", | 
 | ) | 
 | parser.add_argument( | 
 |     "--sequential", | 
 |     action="store_true", | 
 |     help="sequence commands instead of using the graph (slower)", | 
 | ) | 
 | parser.add_argument( | 
 |     "--table", | 
 |     metavar="FILE.csv", | 
 |     help="write CSV with times and intermediate file sizes", | 
 | ) | 
 | parser.add_argument( | 
 |     "--verbose", | 
 |     metavar="FILE.txt", | 
 |     help="like nvcc --verbose, but expanded and into a file", | 
 | ) | 
 | default_config = parser.parse_args([]) | 
 |  | 
 |  | 
# docs about temporary directories used by NVCC
url_base = "https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html"
# section about intermediate files; linked from the TMPDIR/Windows warnings
url_vars = f"{url_base}#keeping-intermediate-phase-files"


# regex for temporary file names
# - (?<![\w\-/]) : must not be preceded by a word character, "-" or "/", so
#   "tmp" inside a longer word or under some other directory is not matched
# - (?:/tmp/)?   : optional /tmp/ directory prefix (not captured)
# - group 1      : the file's basename, starting with "tmp" and running
#   until a space, a quote, or a backslash
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
 |  | 
 |  | 
 | def fast_nvcc_warn(warning: str) -> None: | 
 |     """ | 
 |     Warn the user about something regarding fast_nvcc. | 
 |     """ | 
 |     print(f"warning (fast_nvcc): {warning}", file=sys.stderr) | 
 |  | 
 |  | 
 | def warn_if_windows() -> None: | 
 |     """ | 
 |     Warn the user that using fast_nvcc on Windows might not work. | 
 |     """ | 
 |     # use os.name instead of platform.system() because there is a | 
 |     # platform.py file in this directory, making it very difficult to | 
 |     # import the platform module from the Python standard library | 
 |     if os.name == "nt": | 
 |         fast_nvcc_warn("untested on Windows, might not work; see this URL:") | 
 |         fast_nvcc_warn(url_vars) | 
 |  | 
 |  | 
 | def warn_if_tmpdir_flag(args: List[str]) -> None: | 
 |     """ | 
 |     Warn the user that using fast_nvcc with some flags might not work. | 
 |     """ | 
 |     file_path_specs = "file-and-path-specifications" | 
 |     guiding_driver = "options-for-guiding-compiler-driver" | 
 |     scary_flags = { | 
 |         "--objdir-as-tempdir": file_path_specs, | 
 |         "-objtemp": file_path_specs, | 
 |         "--keep": guiding_driver, | 
 |         "-keep": guiding_driver, | 
 |         "--keep-dir": guiding_driver, | 
 |         "-keep-dir": guiding_driver, | 
 |         "--save-temps": guiding_driver, | 
 |         "-save-temps": guiding_driver, | 
 |     } | 
 |     for arg in args: | 
 |         for flag, frag in scary_flags.items(): | 
 |             if re.match(rf"^{re.escape(flag)}(?:=.*)?$", arg): | 
 |                 fast_nvcc_warn(f"{flag} not supported since it interacts with") | 
 |                 fast_nvcc_warn("TMPDIR, so fast_nvcc may break; see this URL:") | 
 |                 fast_nvcc_warn(f"{url_base}#{frag}") | 
 |  | 
 |  | 
class DryunData(TypedDict):
    # Parsed result of `nvcc --dryrun`, as returned by nvcc_dryrun_data.
    # (The misspelled class name is kept because that function's signature
    # refers to it.)
    # variables from stderr lines of the form "#$ NAME=value"
    env: Dict[str, str]
    # every other "#$ ..." stderr line, with the "#$ " prefix removed, in order
    commands: List[str]
    # return code of the nvcc --dryrun process itself
    exit_code: int
 |  | 
 |  | 
 | def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData: | 
 |     """ | 
 |     Return parsed environment variables and commands from nvcc --dryrun. | 
 |     """ | 
 |     result = subprocess.run(  # type: ignore[call-overload] | 
 |         [binary, "--dryrun"] + args, | 
 |         capture_output=True, | 
 |         encoding="ascii",  # this is just a guess | 
 |     ) | 
 |     print(result.stdout, end="") | 
 |     env = {} | 
 |     commands = [] | 
 |     for line in result.stderr.splitlines(): | 
 |         match = re.match(r"^#\$ (.*)$", line) | 
 |         if match: | 
 |             (stripped,) = match.groups() | 
 |             mapping = re.match(r"^(\w+)=(.*)$", stripped) | 
 |             if mapping: | 
 |                 name, val = mapping.groups() | 
 |                 env[name] = val | 
 |             else: | 
 |                 commands.append(stripped) | 
 |         else: | 
 |             print(line, file=sys.stderr) | 
 |     return {"env": env, "commands": commands, "exit_code": result.returncode} | 
 |  | 
 |  | 
 | def warn_if_tmpdir_set(env: Dict[str, str]) -> None: | 
 |     """ | 
 |     Warn the user that setting TMPDIR with fast_nvcc might not work. | 
 |     """ | 
 |     if os.getenv("TMPDIR") or "TMPDIR" in env: | 
 |         fast_nvcc_warn("TMPDIR is set, might not work; see this URL:") | 
 |         fast_nvcc_warn(url_vars) | 
 |  | 
 |  | 
 | def contains_non_executable(commands: List[str]) -> bool: | 
 |     for command in commands: | 
 |         # This is to deal with special command dry-run result from NVCC such as: | 
 |         # ``` | 
 |         # #$ "/lib64/ccache"/c++ -std=c++11 -E -x c++ -D__CUDACC__ -D__NVCC__  -fPIC -fvisibility=hidden -O3 \ | 
 |         #   -I ... -m64 "reduce_scatter.cu" > "/tmp/tmpxft_0037fae3_00000000-5_reduce_scatter.cpp4.ii | 
 |         # #$ -- Filter Dependencies -- > ... pytorch/build/nccl/obj/collectives/device/reduce_scatter.dep.tmp | 
 |         # ``` | 
 |         if command.startswith("--"): | 
 |             return True | 
 |     return False | 
 |  | 
 |  | 
 | def module_id_contents(command: List[str]) -> str: | 
 |     """ | 
 |     Guess the contents of the .module_id file contained within command. | 
 |     """ | 
 |     if command[0] == "cicc": | 
 |         path = command[-3] | 
 |     elif command[0] == "cudafe++": | 
 |         path = command[-1] | 
 |     middle = pathlib.PurePath(path).name.replace("-", "_").replace(".", "_") | 
 |     # this suffix is very wrong (the real one is far less likely to be | 
 |     # unique), but it seems difficult to find a rule that reproduces the | 
 |     # real suffixes, so here's one that, while inaccurate, is at least | 
 |     # hopefully as straightforward as possible | 
 |     suffix = hashlib.md5(str.encode(middle)).hexdigest()[:8] | 
 |     return f"_{len(middle)}_{middle}_{suffix}" | 
 |  | 
 |  | 
 | def unique_module_id_files(commands: List[str]) -> List[str]: | 
 |     """ | 
 |     Give each command its own .module_id filename instead of sharing. | 
 |     """ | 
 |     module_id = None | 
 |     uniqueified = [] | 
 |     for i, line in enumerate(commands): | 
 |         arr = [] | 
 |  | 
 |         def uniqueify(s: Match[str]) -> str: | 
 |             filename = re.sub(r"\-(\d+)", r"-\1-" + str(i), s.group(0)) | 
 |             arr.append(filename) | 
 |             return filename | 
 |  | 
 |         line = re.sub(re_tmp + r".module_id", uniqueify, line) | 
 |         line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line) | 
 |         if arr: | 
 |             (filename,) = arr | 
 |             if not module_id: | 
 |                 module_id = module_id_contents(shlex.split(line)) | 
 |             uniqueified.append(f"echo -n '{module_id}' > '{filename}'") | 
 |         uniqueified.append(line) | 
 |     return uniqueified | 
 |  | 
 |  | 
 | def make_rm_force(commands: List[str]) -> List[str]: | 
 |     """ | 
 |     Add --force to all rm commands. | 
 |     """ | 
 |     return [f"{c} --force" if c.startswith("rm ") else c for c in commands] | 
 |  | 
 |  | 
 | def print_verbose_output( | 
 |     *, | 
 |     env: Dict[str, str], | 
 |     commands: List[List[str]], | 
 |     filename: str, | 
 | ) -> None: | 
 |     """ | 
 |     Human-readably write nvcc --dryrun data to stderr. | 
 |     """ | 
 |     padding = len(str(len(commands) - 1)) | 
 |     with open(filename, "w") as f: | 
 |         for name, val in env.items(): | 
 |             print(f'#{" "*padding}$ {name}={val}', file=f) | 
 |         for i, command in enumerate(commands): | 
 |             prefix = f"{str(i).rjust(padding)}$ " | 
 |             print(f"#{prefix}{command[0]}", file=f) | 
 |             for part in command[1:]: | 
 |                 print(f'#{" "*len(prefix)}{part}', file=f) | 
 |  | 
 |  | 
 | Graph = List[Set[int]] | 
 |  | 
 |  | 
 | def straight_line_dependencies(commands: List[str]) -> Graph: | 
 |     """ | 
 |     Return a straight-line dependency graph. | 
 |     """ | 
 |     return [({i - 1} if i > 0 else set()) for i in range(len(commands))] | 
 |  | 
 |  | 
 | def files_mentioned(command: str) -> List[str]: | 
 |     """ | 
 |     Return fully-qualified names of all tmp files referenced by command. | 
 |     """ | 
 |     return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)] | 
 |  | 
 |  | 
 | def nvcc_data_dependencies(commands: List[str]) -> Graph: | 
 |     """ | 
 |     Return a list of the set of dependencies for each command. | 
 |     """ | 
 |     # fatbin needs to be treated specially because while the cicc steps | 
 |     # do refer to .fatbin.c files, they do so through the | 
 |     # --include_file_name option, since they're generating files that | 
 |     # refer to .fatbin.c file(s) that will later be created by the | 
 |     # fatbinary step; so for most files, we make a data dependency from | 
 |     # the later step to the earlier step, but for .fatbin.c files, the | 
 |     # data dependency is sort of flipped, because the steps that use the | 
 |     # files generated by cicc need to wait for the fatbinary step to | 
 |     # finish first | 
 |     tmp_files: Dict[str, int] = {} | 
 |     fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set) | 
 |     graph = [] | 
 |     for i, line in enumerate(commands): | 
 |         deps = set() | 
 |         for tmp in files_mentioned(line): | 
 |             if tmp in tmp_files: | 
 |                 dep = tmp_files[tmp] | 
 |                 deps.add(dep) | 
 |                 if dep in fatbins: | 
 |                     for filename in fatbins[dep]: | 
 |                         if filename in tmp_files: | 
 |                             deps.add(tmp_files[filename]) | 
 |             if tmp.endswith(".fatbin.c") and not line.startswith("fatbinary"): | 
 |                 fatbins[i].add(tmp) | 
 |             else: | 
 |                 tmp_files[tmp] = i | 
 |         if line.startswith("rm ") and not deps: | 
 |             deps.add(i - 1) | 
 |         graph.append(deps) | 
 |     return graph | 
 |  | 
 |  | 
 | def is_weakly_connected(graph: Graph) -> bool: | 
 |     """ | 
 |     Return true iff graph is weakly connected. | 
 |     """ | 
 |     if not graph: | 
 |         return True | 
 |     neighbors: List[Set[int]] = [set() for _ in graph] | 
 |     for node, predecessors in enumerate(graph): | 
 |         for pred in predecessors: | 
 |             neighbors[pred].add(node) | 
 |             neighbors[node].add(pred) | 
 |     # assume nonempty graph | 
 |     stack = [0] | 
 |     found = {0} | 
 |     while stack: | 
 |         node = stack.pop() | 
 |         for neighbor in neighbors[node]: | 
 |             if neighbor not in found: | 
 |                 found.add(neighbor) | 
 |                 stack.append(neighbor) | 
 |     return len(found) == len(graph) | 
 |  | 
 |  | 
 | def warn_if_not_weakly_connected(graph: Graph) -> None: | 
 |     """ | 
 |     Warn the user if the execution graph is not weakly connected. | 
 |     """ | 
 |     if not is_weakly_connected(graph): | 
 |         fast_nvcc_warn("execution graph is not (weakly) connected") | 
 |  | 
 |  | 
 | def print_dot_graph( | 
 |     *, | 
 |     commands: List[List[str]], | 
 |     graph: Graph, | 
 |     filename: str, | 
 | ) -> None: | 
 |     """ | 
 |     Print a DOT file displaying short versions of the commands in graph. | 
 |     """ | 
 |  | 
 |     def name(k: int) -> str: | 
 |         return f'"{k} {os.path.basename(commands[k][0])}"' | 
 |  | 
 |     with open(filename, "w") as f: | 
 |         print("digraph {", file=f) | 
 |         # print all nodes, in case it's disconnected | 
 |         for i in range(len(graph)): | 
 |             print(f"    {name(i)};", file=f) | 
 |         for i, deps in enumerate(graph): | 
 |             for j in deps: | 
 |                 print(f"    {name(j)} -> {name(i)};", file=f) | 
 |         print("}", file=f) | 
 |  | 
 |  | 
class Result(TypedDict, total=False):
    # Outcome of one command as produced by run_command. total=False because
    # a command skipped after a failed dependency yields an empty dict.
    exit_code: int  # return code of the command's process
    stdout: bytes  # captured standard output
    stderr: bytes  # captured standard error
    time: float  # elapsed time.monotonic() seconds (only when gathering data)
    # tmp file path -> size in bytes after the command ran, 0 if the file
    # did not exist (only when gathering data)
    files: Dict[str, int]
 |  | 
 |  | 
 | async def run_command( | 
 |     command: str, | 
 |     *, | 
 |     env: Dict[str, str], | 
 |     deps: Set[Awaitable[Result]], | 
 |     gather_data: bool, | 
 |     i: int, | 
 |     save: Optional[str], | 
 | ) -> Result: | 
 |     """ | 
 |     Run the command with the given env after waiting for deps. | 
 |     """ | 
 |     for task in deps: | 
 |         dep_result = await task | 
 |         # abort if a previous step failed | 
 |         if "exit_code" not in dep_result or dep_result["exit_code"] != 0: | 
 |             return {} | 
 |     if gather_data: | 
 |         t1 = time.monotonic() | 
 |     proc = await asyncio.create_subprocess_shell( | 
 |         command, | 
 |         env=env, | 
 |         stdout=asyncio.subprocess.PIPE, | 
 |         stderr=asyncio.subprocess.PIPE, | 
 |     ) | 
 |     stdout, stderr = await proc.communicate() | 
 |     code = cast(int, proc.returncode) | 
 |     results: Result = {"exit_code": code, "stdout": stdout, "stderr": stderr} | 
 |     if gather_data: | 
 |         t2 = time.monotonic() | 
 |         results["time"] = t2 - t1 | 
 |         sizes = {} | 
 |         for tmp_file in files_mentioned(command): | 
 |             if os.path.exists(tmp_file): | 
 |                 sizes[tmp_file] = os.path.getsize(tmp_file) | 
 |             else: | 
 |                 sizes[tmp_file] = 0 | 
 |         results["files"] = sizes | 
 |     if save: | 
 |         dest = pathlib.Path(save) / str(i) | 
 |         dest.mkdir() | 
 |         for src in map(pathlib.Path, files_mentioned(command)): | 
 |             if src.exists(): | 
 |                 shutil.copy2(src, dest / (src.name)) | 
 |     return results | 
 |  | 
 |  | 
 | async def run_graph( | 
 |     *, | 
 |     env: Dict[str, str], | 
 |     commands: List[str], | 
 |     graph: Graph, | 
 |     gather_data: bool = False, | 
 |     save: Optional[str] = None, | 
 | ) -> List[Result]: | 
 |     """ | 
 |     Return outputs/errors (and optionally time/file info) from commands. | 
 |     """ | 
 |     tasks: List[Awaitable[Result]] = [] | 
 |     for i, (command, indices) in enumerate(zip(commands, graph)): | 
 |         deps = {tasks[j] for j in indices} | 
 |         tasks.append( | 
 |             asyncio.create_task( | 
 |                 run_command(  # type: ignore[attr-defined] | 
 |                     command, | 
 |                     env=env, | 
 |                     deps=deps, | 
 |                     gather_data=gather_data, | 
 |                     i=i, | 
 |                     save=save, | 
 |                 ) | 
 |             ) | 
 |         ) | 
 |     return [await task for task in tasks] | 
 |  | 
 |  | 
 | def print_command_outputs(command_results: List[Result]) -> None: | 
 |     """ | 
 |     Print captured stdout and stderr from commands. | 
 |     """ | 
 |     for result in command_results: | 
 |         sys.stdout.write(result.get("stdout", b"").decode("ascii")) | 
 |         sys.stderr.write(result.get("stderr", b"").decode("ascii")) | 
 |  | 
 |  | 
 | def write_log_csv( | 
 |     command_parts: List[List[str]], | 
 |     command_results: List[Result], | 
 |     *, | 
 |     filename: str, | 
 | ) -> None: | 
 |     """ | 
 |     Write a CSV file of the times and /tmp file sizes from each command. | 
 |     """ | 
 |     tmp_files: List[str] = [] | 
 |     for result in command_results: | 
 |         tmp_files.extend(result.get("files", {}).keys()) | 
 |     with open(filename, "w", newline="") as csvfile: | 
 |         fieldnames = ["command", "seconds"] + list(dict.fromkeys(tmp_files)) | 
 |         writer = csv.DictWriter(csvfile, fieldnames=fieldnames) | 
 |         writer.writeheader() | 
 |         for i, result in enumerate(command_results): | 
 |             command = f"{i} {os.path.basename(command_parts[i][0])}" | 
 |             row = {"command": command, "seconds": result.get("time", 0)} | 
 |             writer.writerow({**row, **result.get("files", {})}) | 
 |  | 
 |  | 
 | def exit_code(results: List[Result]) -> int: | 
 |     """ | 
 |     Aggregate individual exit codes into a single code. | 
 |     """ | 
 |     for result in results: | 
 |         code = result.get("exit_code", 0) | 
 |         if code != 0: | 
 |             return code | 
 |     return 0 | 
 |  | 
 |  | 
 | def wrap_nvcc( | 
 |     args: List[str], | 
 |     config: argparse.Namespace = default_config, | 
 | ) -> int: | 
 |     return subprocess.call([config.nvcc] + args) | 
 |  | 
 |  | 
 | def fast_nvcc( | 
 |     args: List[str], | 
 |     *, | 
 |     config: argparse.Namespace = default_config, | 
 | ) -> int: | 
 |     """ | 
 |     Emulate the result of calling the given nvcc binary with args. | 
 |  | 
 |     Should run faster than plain nvcc. | 
 |     """ | 
 |     warn_if_windows() | 
 |     warn_if_tmpdir_flag(args) | 
 |     dryrun_data = nvcc_dryrun_data(config.nvcc, args) | 
 |     env = dryrun_data["env"] | 
 |     warn_if_tmpdir_set(env) | 
 |     commands = dryrun_data["commands"] | 
 |     if not config.faithful: | 
 |         commands = make_rm_force(unique_module_id_files(commands)) | 
 |  | 
 |     if contains_non_executable(commands): | 
 |         return wrap_nvcc(args, config) | 
 |  | 
 |     command_parts = list(map(shlex.split, commands)) | 
 |     if config.verbose: | 
 |         print_verbose_output( | 
 |             env=env, | 
 |             commands=command_parts, | 
 |             filename=config.verbose, | 
 |         ) | 
 |     graph = nvcc_data_dependencies(commands) | 
 |     warn_if_not_weakly_connected(graph) | 
 |     if config.graph: | 
 |         print_dot_graph( | 
 |             commands=command_parts, | 
 |             graph=graph, | 
 |             filename=config.graph, | 
 |         ) | 
 |     if config.sequential: | 
 |         graph = straight_line_dependencies(commands) | 
 |     results = asyncio.run( | 
 |         run_graph(  # type: ignore[attr-defined] | 
 |             env=env, | 
 |             commands=commands, | 
 |             graph=graph, | 
 |             gather_data=bool(config.table), | 
 |             save=config.save, | 
 |         ) | 
 |     ) | 
 |     print_command_outputs(results) | 
 |     if config.table: | 
 |         write_log_csv(command_parts, results, filename=config.table) | 
 |     return exit_code([dryrun_data] + results)  # type: ignore[arg-type, operator] | 
 |  | 
 |  | 
 | def our_arg(arg: str) -> bool: | 
 |     return arg != "--" | 
 |  | 
 |  | 
 | if __name__ == "__main__": | 
 |     argv = sys.argv[1:] | 
 |     us = list(itertools.takewhile(our_arg, argv)) | 
 |     them = list(itertools.dropwhile(our_arg, argv)) | 
 |     sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us))) |