blob: 8b5669c49616d6a5b83e4b07a40ddfd4a9f78b43 [file] [log] [blame]
#!/usr/bin/env python
"""
A script that runs clang-format on changes detected via git. It will
report if running clang-format generated any changes.
In CI, the script considers it a failure if running clang-format makes a change.
In the pre-commit hook, the user is prompted to apply any clang-format changes.
Running tools/clang_format.py manually with no arguments should replicate the pre-commit hook behavior.
Only files that are in CLANG_FORMAT_WHITELIST are checked.
"""
import subprocess
import glob
import itertools
import os
import argparse
import difflib
import sys
# Python 2 compatibility: on that interpreter, route `input` to `raw_input`
# so the interactive prompt receives the typed text as a plain string.
PY2 = sys.version_info.major == 2
if PY2:

    def input(foo):
        """Prompt with `foo` and return the line the user typed."""
        return raw_input(foo)  # noqa: F821 -- only reached on Python 2


# Glob patterns selecting the files that may be checked. Recursive globs
# ("**") are not supported, because making them work on Python 2 would
# take extra effort.
CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/passes/alias_analysis*"]
def parse_args():
    """
    Build the command-line parser and parse the process arguments.

    Returns the argparse namespace carrying `diff` (git revision, "HEAD" by
    default), `non_interactive` (bool) and `verbose` (bool).
    """
    cli = argparse.ArgumentParser(
        description="Execute clang-format on your working copy changes."
    )
    # One (flags, keyword settings) pair per option, registered in this order.
    option_specs = (
        (
            ("-d", "--diff"),
            {
                "default": "HEAD",
                "help": "Git revision to diff against to get changes",
            },
        ),
        (
            ("--non-interactive",),
            {
                "action": "store_true",
                "default": False,
                "help": "Don't prompt the user to apply clang-format changes. If there are any changes, just exit non-zero",
            },
        ),
        (
            ("--verbose", "-v"),
            {"action": "store_true", "default": False},
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def get_whitelisted_files():
    """
    Expand every glob pattern listed in CLANG_FORMAT_WHITELIST.

    Returns the set of matching paths that are files; matches that are
    not files (e.g. directories) are left out.
    """
    # Lazily concatenate the per-pattern match lists into one stream.
    matches = itertools.chain.from_iterable(
        glob.glob(pattern) for pattern in CLANG_FORMAT_WHITELIST
    )
    return {match for match in matches if os.path.isfile(match)}
def get_changed_files(rev):
    """
    Get all changed files between the working tree and `rev`.

    Runs `git diff-index --diff-filter=AMU --name-only <rev>` and returns
    the reported paths as a set of strings. The set is empty when git
    reports no changes.

    Raises subprocess.CalledProcessError if git exits non-zero.
    """
    output = subprocess.check_output(
        ["git", "diff-index", "--diff-filter=AMU", "--name-only", rev]
    ).decode("utf-8")
    # git terminates every name with a newline, so a plain split("\n") would
    # leave a bogus empty "filename" behind; keep only non-empty lines.
    return {name for name in output.splitlines() if name}
def get_diffs(files):
    """
    Run clang-format on all `files` and report if it changed anything.

    Each file's on-disk bytes are compared with the bytes clang-format
    writes to stdout, so the verdict does not depend on the locale encoding
    or on newline translation done by text-mode reads.

    Returns a mapping of filename => diff generator (unified diff lines
    without trailing newlines). Files clang-format leaves untouched are not
    included in the mapping.

    Raises subprocess.CalledProcessError if clang-format exits non-zero.
    """
    name_to_diffs = {}
    for f in files:
        formatted_bytes = subprocess.check_output(["clang-format", f])
        with open(f, "rb") as orig:
            orig_bytes = orig.read()
        if formatted_bytes == orig_bytes:
            continue
        # Decoding is only for display; undecodable bytes must not abort the
        # run, hence "replace".
        orig_lines = orig_bytes.decode("utf-8", "replace").split("\n")
        formatted_lines = formatted_bytes.decode("utf-8", "replace").split("\n")
        # The lines carry no trailing newline, so the diff's control lines
        # must not either (lineterm defaults to "\n").
        name_to_diffs[f] = difflib.unified_diff(
            orig_lines, formatted_lines, "original", "formatted", lineterm=""
        )
    return name_to_diffs
def _confirm(question):
    """
    Ask `question` until the reply starts with "y" or "n".

    An empty reply counts as yes. End-of-input (no stdin to read from)
    counts as no, so the caller fails cleanly instead of with a traceback.
    """
    while True:
        try:
            reply = input(question).lower()
        except EOFError:
            return False
        if not reply or reply[0] == "y":
            return True
        if reply[0] == "n":
            return False


def main():
    """
    Check the changed, whitelisted files with clang-format.

    Returns normally when clang-format would change nothing. Otherwise the
    diffs are printed and:
      * with --non-interactive, the process exits with status 1;
      * without it, the user is asked whether to apply the changes;
        declining exits with status 1, accepting rewrites the files in
        place with `clang-format -i`.
    """
    args = parse_args()
    changed_files = get_changed_files(args.diff)
    whitelisted_files = get_whitelisted_files()
    files_to_check = changed_files & whitelisted_files
    if args.verbose:
        print("Running clang-format on whitelisted files: ")
        for f in files_to_check:
            print(f)

    name_to_diffs = get_diffs(files_to_check)
    if not name_to_diffs:
        return

    print("ERROR: Running clang-format created changes: ")
    for name, diff in name_to_diffs.items():
        # Single-argument print: a two-argument print(...) call would print a
        # tuple on Python 2, where print is a statement.
        print("In {}".format(name))
        for line in diff:
            print(line)
        print("\n")

    # sys.exit rather than the bare exit() helper, which is only injected by
    # the site module and is not guaranteed to exist.
    if args.non_interactive or not _confirm("Accept these changes? [Y/n] "):
        sys.exit(1)

    # Apply the formatting in place to exactly the files that differed.
    format_cmd = ["clang-format", "-i"]
    format_cmd.extend(name_to_diffs)
    subprocess.check_call(format_cmd)
if __name__ == "__main__":
main()