blob: ed0147e4a62a9fd0dc5ffa03f95c12fb69f7da9e [file] [log] [blame]
#!/usr/bin/env python3
import argparse
import json
import re
import subprocess
from bisect import bisect_right
from collections import defaultdict
from typing import (Callable, DefaultDict, Generic, List, Optional, Pattern,
Sequence, TypeVar, cast)
from typing_extensions import TypedDict
class Hunk(TypedDict):
    """Line ranges taken from one unified-diff hunk header.

    Mirrors the four numbers in ``@@ -old_start,old_count
    +new_start,new_count @@``; parse_diff fills a count with 1 when the
    header omits it.
    """
    # first number after '-' in the header (old-file side)
    old_start: int
    # optional second number after '-'; 1 when omitted
    old_count: int
    # first number after '+' in the header (new-file side)
    new_start: int
    # optional second number after '+'; 1 when omitted
    new_count: int
class Diff(TypedDict):
    """Parsed result of a single-file unified diff, as built by parse_diff."""
    # path captured from the '--- a/<path>' header; None when the header
    # is '--- /dev/null' or when no '---' header was found at all
    old_filename: Optional[str]
    # one entry per hunk header, in the order they appear in the diff
    hunks: List[Hunk]
# hunk header: @@ -start,count +start,count @@
# (each ",count" part is optional)
hunk_pattern = r'^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@'


def parse_diff(diff: str) -> Diff:
    """Pull the old filename and the hunk line ranges out of a unified diff.

    The text is scanned in two phases. Up to and including the first
    ``--- /dev/null`` or ``--- a/<path>`` line, only that header is
    looked for (a hunk header showing up this early trips an assertion).
    After it, every line matching ``hunk_pattern`` becomes a ``Hunk``; a
    count that the header leaves out is recorded as 1.

    Returns a ``Diff`` whose ``old_filename`` is the ``<path>`` from the
    ``--- a/<path>`` header, or ``None`` if the header was
    ``--- /dev/null`` or never appeared.
    """
    hunk_re = re.compile(hunk_pattern)
    old_name_re = re.compile(r'^--- (?:(?:/dev/null)|(?:a/(.*)))$')

    old_filename = None
    collected: List[Hunk] = []
    # one shared iterator so the second loop resumes right after the
    # '---' header that ended the first loop
    remaining = iter(diff.splitlines())

    # phase 1: find the old-file header
    for text in remaining:
        assert not hunk_re.match(text)
        header = old_name_re.match(text)
        if header:
            old_filename = header.group(1)
            break

    # phase 2: collect hunk headers, ignoring every other line
    for text in remaining:
        found = hunk_re.match(text)
        if not found:
            continue
        old_start, old_count, new_start, new_count = found.groups()
        collected.append({
            'old_start': int(old_start),
            'old_count': int(old_count or '1'),
            'new_start': int(new_start),
            'new_count': int(new_count or '1'),
        })

    return {
        'old_filename': old_filename,
        'hunks': collected,
    }
T = TypeVar('T')
U = TypeVar('U')


# bisect.bisect_right is what we want for locating the hunk closest to
# a given line number, but it only gains a key function in Python 3.10
# (https://github.com/python/cpython/pull/20556). Until then, wrap the
# list of hunks in an O(1) adapter that presents it as a list of line
# numbers.
# https://gist.github.com/ericremoreynolds/2d80300dabc70eebc790
class KeyifyList(Generic[T, U]):
    """Sequence-like view over a list that yields ``key(item)`` per index.

    Nothing is copied: the wrapped list is kept by reference and the key
    function is applied on every lookup.
    """

    def __init__(self, inner: List[T], key: Callable[[T], U]) -> None:
        self.key = key
        self.inner = inner

    def __len__(self) -> int:
        """Return the length of the wrapped list."""
        wrapped = self.inner
        return len(wrapped)

    def __getitem__(self, k: int) -> U:
        """Return the key of the element at index ``k``."""
        element = self.inner[k]
        return self.key(element)
def translate(diff: Diff, line_number: int) -> Optional[int]:
    """Map a line number on the new side of ``diff`` to the old side.

    Returns ``None`` when ``line_number`` is below 1 or falls within the
    new-side range of a hunk (there is no old line to point at).
    Otherwise returns the old-side line number: unchanged if the line
    precedes every hunk, else shifted by the distance past the end of
    the nearest preceding hunk.

    The bisect below assumes ``diff['hunks']`` is ordered by ascending
    ``new_start`` (parse_diff appends hunks in the order the diff lists
    them).
    """
    if line_number < 1:
        return None

    hunks = diff['hunks']
    if not hunks:
        return line_number

    # Sort key per hunk: the first new-side line at which the hunk takes
    # effect. A hunk with new_count == 0 contributes no new-side lines,
    # so it only affects lines from new_start + 1 onward.
    first_affected = KeyifyList(
        hunks,
        lambda h: h['new_start'] if h['new_count'] > 0 else h['new_start'] + 1
    )
    preceding = bisect_right(cast(Sequence[int], first_affected), line_number)
    if preceding < 1:
        # the line comes before the first hunk, so nothing has shifted
        return line_number

    nearest = hunks[preceding - 1]
    # first line past the hunk on each side; a zero count is treated as
    # one because the start then names the line just before the hunk
    new_end = nearest['new_start'] + (nearest['new_count'] or 1)
    old_end = nearest['old_start'] + (nearest['old_count'] or 1)

    offset = line_number - new_end
    if offset < 0:
        # inside the hunk's new-side range
        return None
    return old_end + offset
# we use camelCase here because this will be output as JSON and so the
# field names need to match the group names from here:
# https://github.com/pytorch/add-annotations-github-action/blob/3ab7d7345209f5299d53303f7aaca7d3bc09e250/action.yml#L23
class Annotation(TypedDict):
    """One linter message, as parsed from a line of linter output.

    The keys are camelCase because they are the named groups that
    parse_annotation reads from the regex and they are emitted as-is in
    the JSON output.
    """
    # text of the 'filename' group
    filename: str
    # 'lineNumber' group converted to int
    lineNumber: int
    # 'columnNumber' group converted to int
    columnNumber: int
    # text of the 'errorCode' group
    errorCode: str
    # text of the 'errorDesc' group
    errorDesc: str
def parse_annotation(regex: Pattern[str], line: str) -> Optional[Annotation]:
    """Parse one line of linter output into an ``Annotation``.

    ``regex`` is matched at the start of ``line`` and must define the
    named groups ``filename``, ``lineNumber``, ``columnNumber``,
    ``errorCode`` and ``errorDesc``.

    Returns ``None`` when the line does not match, or when the line or
    column group cannot be converted to ``int``. That covers both
    non-numeric text and an optional group that took no part in the
    match.
    """
    m = re.match(regex, line)
    if m is None:
        return None

    try:
        line_number = int(m.group('lineNumber'))
        column_number = int(m.group('columnNumber'))
    except (TypeError, ValueError):
        # ValueError: the group matched text that is not an integer.
        # TypeError: the group is optional and did not participate, so
        # m.group() returned None.
        return None

    return {
        'filename': m.group('filename'),
        'lineNumber': line_number,
        'columnNumber': column_number,
        'errorCode': m.group('errorCode'),
        'errorDesc': m.group('errorDesc'),
    }
def translate_all(
    *,
    lines: List[str],
    regex: Pattern[str],
    commit: str
) -> List[Annotation]:
    """Parse linter output and rewrite each annotation to refer to ``commit``.

    Every line that ``parse_annotation`` accepts is grouped by filename.
    For each file, ``git diff-index --unified=0 <commit> -- <filename>``
    is run once and its output decides what happens to that file's
    annotations:

    * empty diff: annotations are kept unchanged;
    * diff with no old filename: the file is absent in ``commit``, so
      all of its annotations are dropped;
    * otherwise: the filename is replaced by the old filename and the
      line number is mapped through ``translate``; annotations whose
      line has no counterpart in ``commit`` are dropped.

    An annotation is also dropped when its resulting line number is 0.

    The annotation dicts are updated in place and returned in a flat
    list, grouped by file in first-seen order.

    Raises ``subprocess.CalledProcessError`` if git exits non-zero.
    """
    ann_dict: DefaultDict[str, List[Annotation]] = defaultdict(list)
    for line in lines:
        annotation = parse_annotation(regex, line)
        if annotation is not None:
            ann_dict[annotation['filename']].append(annotation)

    ann_list: List[Annotation] = []
    for filename, annotations in ann_dict.items():
        # '--' marks everything after it as a path, so a filename taken
        # from linter output cannot be read by git as an option or as a
        # revision.
        raw_diff = subprocess.check_output(
            ['git', 'diff-index', '--unified=0', commit, '--', filename],
            encoding='utf-8',
        )
        diff = parse_diff(raw_diff) if raw_diff.strip() else None

        # if there is a diff but it doesn't list an old filename, that
        # means the file is absent in the commit we're targeting, so we
        # skip it
        if diff and not diff['old_filename']:
            continue

        for annotation in annotations:
            line_number: Optional[int] = annotation['lineNumber']
            if diff:
                annotation['filename'] = cast(str, diff['old_filename'])
                line_number = translate(diff, cast(int, line_number))
            # None (untranslatable) and 0 are both dropped here
            if line_number:
                annotation['lineNumber'] = line_number
                ann_list.append(annotation)

    return ann_list
def main() -> None:
    """Command-line entry point.

    Reads linter output from the path given by ``--file``, translates
    the annotations matched by ``--regex`` so they refer to ``--commit``,
    and prints the result to stdout as a JSON array.
    """
    parser = argparse.ArgumentParser()
    for flag in ('--file', '--regex', '--commit'):
        parser.add_argument(flag)
    args = parser.parse_args()

    with open(args.file, 'r') as handle:
        lines = handle.readlines()

    translated = translate_all(
        lines=lines,
        regex=args.regex,
        commit=args.commit
    )
    print(json.dumps(translated))
# run the command-line interface only when this file is executed as a
# script, not when it is imported as a module
if __name__ == '__main__':
    main()