| #!/usr/bin/env python |
| """ |
| A script that runs clang-format on changes detected via git. It will |
| report if running clang-format generated any changes. |
| |
| In CI, the script considers it a failure if running clang-format makes a change. |
| In the pre-commit hook, the user is prompted to apply any clang-format changes. |
| Running tools/clang_format.py manually with no arguments should replicate the pre-commit hook behavior. |
| |
| Only files that are in CLANG_FORMAT_WHITELIST are checked. |
| """ |
| import subprocess |
| import glob |
| import itertools |
| import os |
| import argparse |
| import difflib |
| import sys |
| |
| |
| # for python2 compatability |
| PY2 = sys.version_info[0] == 2 |
| if PY2: |
| |
| def input(foo): |
| return raw_input(foo) |
| |
| |
| # Whitelist of files to check. Takes a glob syntax. Does not support |
| # recursive globs ("**") because I am lazy and don't want to make that |
| # work with Python 2. |
| CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/passes/alias_analysis*"] |
| |
| |
| def parse_args(): |
| parser = argparse.ArgumentParser( |
| description="Execute clang-format on your working copy changes." |
| ) |
| parser.add_argument( |
| "-d", |
| "--diff", |
| default="HEAD", |
| help="Git revision to diff against to get changes", |
| ) |
| parser.add_argument( |
| "--non-interactive", |
| action="store_true", |
| default=False, |
| help="Don't prompt the user to apply clang-format changes. If there are any changes, just exit non-zero", |
| ) |
| parser.add_argument("--verbose", "-v", action="store_true", default=False) |
| return parser.parse_args() |
| |
| |
| def get_whitelisted_files(): |
| """ |
| Parse CLANG_FORMAT_WHITELIST and resolve all globs. |
| Returns the set of all whitelisted filenames. |
| """ |
| paths = [glob.glob(entry) for entry in CLANG_FORMAT_WHITELIST] |
| # flatten the files list |
| paths = itertools.chain(*paths) |
| # filter out directories |
| filenames = filter(lambda path: os.path.isfile(path), paths) |
| return set(filenames) |
| |
| |
| def get_changed_files(rev): |
| """ |
| Get all changed files between the working tree and `rev` |
| """ |
| changed_files = ( |
| subprocess.check_output( |
| ["git", "diff-index", "--diff-filter=AMU", "--name-only", rev] |
| ) |
| .decode() |
| .split("\n") |
| ) |
| return set(changed_files) |
| |
| |
| def get_diffs(files): |
| """ |
| Run clang-format on all `files` and report if it changed anything. |
| Returns a mapping of filename => diff generator |
| """ |
| name_to_diffs = {} |
| for f in files: |
| formatted_text = subprocess.check_output(["clang-format", f]).decode() |
| with open(f) as orig: |
| orig_text = orig.read() |
| if formatted_text != orig_text: |
| orig_lines = orig_text.split("\n") |
| formatted_lines = formatted_text.split("\n") |
| diff = difflib.unified_diff( |
| orig_lines, formatted_lines, "original", "formatted" |
| ) |
| name_to_diffs[f] = diff |
| |
| return name_to_diffs |
| |
| |
| def main(): |
| args = parse_args() |
| |
| changed_files = get_changed_files(args.diff) |
| whitelisted_files = get_whitelisted_files() |
| |
| files_to_check = changed_files & whitelisted_files |
| |
| if args.verbose: |
| print("Running clang-format on whitelisted files: ") |
| for f in files_to_check: |
| print(f) |
| |
| name_to_diffs = get_diffs(files_to_check) |
| |
| if len(name_to_diffs) != 0: |
| print("ERROR: Running clang-format created changes: ") |
| for name, diff in name_to_diffs.items(): |
| print("In ", name) |
| for line in diff: |
| print(line) |
| print("\n") |
| |
| if args.non_interactive: |
| exit(1) |
| else: |
| choice = None |
| # Loop until we choose y or n |
| while choice is None: |
| choice = input("Accept these changes? [Y/n] ").lower() |
| if choice != "" and choice[0] != "y" and choice[0] != "n": |
| choice = None |
| |
| if choice == "" or choice[0] == "y": |
| # run clang-format on the necessary files |
| args = ["clang-format", "-i"] |
| args.extend(name_to_diffs.keys()) |
| subprocess.check_output(args) |
| else: |
| exit(1) |
| |
| |
| if __name__ == "__main__": |
| main() |