blob: 7fe0d086020a8ff73caaeec6f23274062a64b033 [file] [log] [blame]
#!/usr/bin/env python3
import json
import os
import re
from dataclasses import dataclass
from urllib.request import urlopen, Request
from urllib.error import HTTPError
from typing import cast, Any, Callable, Dict, List, Optional, Tuple
from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo
# GraphQL query that fetches, for one pull request identified by
# ($owner, $name, $number): its closed/cross-repo flags, author login, title,
# body, head/base ref names and repositories (including the base repository's
# default branch), the first 100 commits (author identity, oid and at most one
# check suite filtered to app id 12274), the changed-file count, up to 100
# changed file paths and up to 100 latest reviews.
GH_GET_PR_INFO_QUERY = """
query ($owner: String!, $name: String!, $number: Int!) {
repository(owner: $owner, name: $name) {
pullRequest(number: $number) {
closed
isCrossRepository
author {
login
}
title
body
headRefName
headRepository {
nameWithOwner
}
baseRefName
baseRepository {
nameWithOwner
defaultBranchRef {
name
}
}
commits(first: 100) {
nodes {
commit {
author {
user {
login
}
email
name
}
oid
checkSuites(filterBy: {appId: 12274}, first: 1) {
nodes {
app {
databaseId
}
conclusion
}
}
}
}
totalCount
}
changedFiles,
files(last: 100) {
nodes {
path
}
}
latestReviews(last: 100) {
nodes {
author {
login
},
state
},
totalCount
}
}
}
}
"""
# Matches a head ref of the form "gh/<segment>/<digits>/head"; group 1 captures
# the "gh/<segment>/<digits>/" prefix.
RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
# Matches a whole "ghstack-source-id: ..." line (including its trailing newline,
# if any) anywhere in a multi-line commit message.
RE_GHSTACK_SOURCE_ID = re.compile(r'^ghstack-source-id: (.+)\n?', re.MULTILINE)
# Finds "Pull Request resolved: https://github.com/<owner>/<repo>/pull/<number>"
# and exposes the three parts as the named groups `owner`, `repo` and `number`.
RE_PULL_REQUEST_RESOLVED = re.compile(
r'Pull Request resolved: '
r'https://github.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)',
re.MULTILINE
)
def _fetch_url(url: str, *,
               headers: Optional[Dict[str, str]] = None,
               data: Optional[Dict[str, Any]] = None,
               method: Optional[str] = None,
               reader: Callable[[Any], Any] = lambda x: x.read()) -> Any:
    """Open `url` and return whatever `reader` extracts from the response.

    Args:
        url: Address to request.
        headers: Extra request headers. The mapping is copied, never modified.
        data: Optional payload; serialized to JSON and sent as the request body.
        method: Optional explicit HTTP method, forwarded to `Request`.
        reader: Callable applied to the open response object; its return value
            is the result of this function (defaults to reading the raw bytes).

    Raises:
        HTTPError: Re-raised unchanged. For a 403 response carrying both
            X-RateLimit-* headers, a rate-limit notice is printed first.
    """
    # Work on a private copy: the Authorization header added below must not
    # leak into (or persist in) a dict owned by the caller.
    request_headers = dict(headers) if headers is not None else {}
    token = os.environ.get("GITHUB_TOKEN")
    # The token is only ever attached to requests aimed at the GitHub API.
    if token is not None and url.startswith('https://api.github.com/'):
        request_headers['Authorization'] = f'token {token}'
    data_ = json.dumps(data).encode() if data is not None else None
    try:
        with urlopen(Request(url, headers=request_headers, data=data_, method=method)) as conn:
            return reader(conn)
    except HTTPError as err:
        if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
            print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
        raise
def fetch_json(url: str,
               params: Optional[Dict[str, Any]] = None,
               data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Request `url` with the GitHub v3 JSON Accept header and decode the reply.

    Args:
        url: Endpoint to call.
        params: Optional query parameters; they are percent-encoded and
            appended to `url` after a "?".
        data: Optional JSON payload forwarded to `_fetch_url` as the body.

    Returns:
        The decoded JSON document (annotated as a list of dicts).
    """
    from urllib.parse import urlencode

    headers = {'Accept': 'application/vnd.github.v3+json'}
    if params:
        # urlencode escapes characters such as "&", "=", "#" and spaces that
        # would otherwise corrupt the query string.
        url += '?' + urlencode(params)
    return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load))
def gh_post_comment(org: str, project: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
    """Post `comment` on issue/PR `pr_num` of `org/project`.

    With `dry_run` set, nothing is sent: the comment is printed and an empty
    list is returned. Otherwise the result of `fetch_json` is returned.
    """
    if not dry_run:
        endpoint = f'https://api.github.com/repos/{org}/{project}/issues/{pr_num}/comments'
        return fetch_json(endpoint, data={"body": comment})
    print(comment)
    return []
def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
    """Run a GraphQL `query` against the GitHub API.

    Keyword arguments are sent as the query variables.

    Raises:
        RuntimeError: If the decoded response contains an "errors" key.
    """
    payload = {"query": query, "variables": kwargs}
    response = _fetch_url("https://api.github.com/graphql", data=payload, reader=json.load)
    if "errors" not in response:
        return cast(Dict[str, Any], response)
    raise RuntimeError(f"GraphQL query {query} failed: {response['errors']}")
def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
    """Return the `pullRequest` node for PR `pr_no` of repository `org/proj`."""
    response = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
    repository = response["data"]["repository"]
    return repository["pullRequest"]
def parse_args() -> Any:
    """Parse the command line: a required integer `pr_num` and a `--dry-run` flag."""
    import argparse

    # The first positional argument of ArgumentParser is `prog`, so this text
    # is what appears as the program name in usage/help output.
    parser = argparse.ArgumentParser(prog="Merge PR into default branch")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("pr_num", type=int)
    parsed = parser.parse_args()
    return parsed
class GitHubPR:
    """A pull request, as described by one `gh_get_pr_info` response.

    The response is fetched once in the constructor and kept in `self.info`;
    the accessors below only read from it. `merge_into` applies the PR to a
    local checkout and pushes the result.
    """

    def __init__(self, org: str, project: str, pr_num: int) -> None:
        """Fetch and store the PR description for `org/project#pr_num`."""
        assert isinstance(pr_num, int)
        self.org = org
        self.project = project
        self.pr_num = pr_num
        self.info = gh_get_pr_info(org, project, pr_num)

    def is_closed(self) -> bool:
        """Return the PR's `closed` flag."""
        return bool(self.info["closed"])

    def is_cross_repo(self) -> bool:
        """Return the PR's `isCrossRepository` flag."""
        return bool(self.info["isCrossRepository"])

    def base_ref(self) -> str:
        """Return the name of the ref the PR targets."""
        return cast(str, self.info["baseRefName"])

    def default_branch(self) -> str:
        """Return the name of the base repository's default branch."""
        return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])

    def head_ref(self) -> str:
        """Return the name of the PR's head ref."""
        return cast(str, self.info["headRefName"])

    def is_ghstack_pr(self) -> bool:
        """Return True if the head ref matches RE_GHSTACK_HEAD_REF."""
        return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None

    def get_changed_files_count(self) -> int:
        """Return the number of changed files reported for the PR."""
        return int(self.info["changedFiles"])

    def get_changed_files(self) -> List[str]:
        """Return the paths of all changed files.

        Raises:
            RuntimeError: If the number of fetched paths differs from the
                reported changed-file count (the query fetches at most 100).
        """
        rc = [x["path"] for x in self.info["files"]["nodes"]]
        if len(rc) != self.get_changed_files_count():
            raise RuntimeError("Changed file count mismatch")
        return rc

    def _get_reviewers(self) -> List[Tuple[str, str]]:
        """Return (login, state) for every latest review.

        Raises:
            RuntimeError: If fewer reviews were fetched than `totalCount`.
        """
        reviews_count = int(self.info["latestReviews"]["totalCount"])
        if len(self.info["latestReviews"]["nodes"]) != reviews_count:
            raise RuntimeError("Can't fetch all PR reviews")
        return [(x["author"]["login"], x["state"]) for x in self.info["latestReviews"]["nodes"]]

    def get_approved_by(self) -> List[str]:
        """Return logins of reviewers whose latest review state is APPROVED."""
        return [login for (login, state) in self._get_reviewers() if state == "APPROVED"]

    def get_commit_count(self) -> int:
        """Return the total commit count reported for the PR."""
        return int(self.info["commits"]["totalCount"])

    def get_pr_creator_login(self) -> str:
        """Return the login of the account that opened the PR."""
        return cast(str, self.info["author"]["login"])

    def get_committer_login(self, num: int = 0) -> str:
        """Return the GitHub login of the author of fetched commit `num`.

        Returns an empty string when the commit author is not associated with
        a GitHub account (the API then reports `user` as null).
        """
        user = self.info["commits"]["nodes"][num]["commit"]["author"]["user"]
        if user is None:
            return ""
        return cast(str, user["login"])

    def get_committer_author(self, num: int = 0) -> str:
        """Return "Name <email>" for the author of fetched commit `num`."""
        node = self.info["commits"]["nodes"][num]["commit"]["author"]
        return f"{node['name']} <{node['email']}>"

    def get_check_suite_conclusions(self) -> Dict[int, str]:
        """Map app database id to conclusion for the last fetched commit's check suites."""
        last_commit = self.info["commits"]["nodes"][-1]["commit"]
        rc = {}
        for node in last_commit["checkSuites"]["nodes"]:
            rc[int(node["app"]["databaseId"])] = node["conclusion"]
        return rc

    def get_authors(self) -> Dict[str, str]:
        """Map committer login to "Name <email>" over the fetched commits.

        Only the commit nodes actually present in the response are visited;
        `totalCount` may exceed them because the query fetches at most 100
        commits. Authors without a GitHub account share the "" key.
        """
        rc = {}
        for idx in range(len(self.info["commits"]["nodes"])):
            rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
        return rc

    def get_author(self) -> str:
        """Return the "Name <email>" identity to credit for the PR.

        A sole commit author is used directly. Otherwise the PR creator's
        identity is used, falling back to the first commit's author when the
        creator authored none of the fetched commits.
        """
        authors = self.get_authors()
        if len(authors) == 1:
            return next(iter(authors.values()))
        creator = self.get_pr_creator_login()
        if creator not in authors:
            return self.get_committer_author(0)
        return authors[creator]

    def get_title(self) -> str:
        """Return the PR title."""
        return cast(str, self.info["title"])

    def get_body(self) -> str:
        """Return the PR body."""
        return cast(str, self.info["body"])

    def get_pr_url(self) -> str:
        """Return the github.com URL of the PR."""
        return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"

    def merge_ghstack_into(self, repo: "GitRepo") -> None:
        """Cherry-pick the commits of a ghstack PR (and its stack) onto the current branch.

        Every commit between the default branch and the PR's ".../orig" ref is
        processed oldest-last-listed first. Commits belonging to other PRs of
        the stack are skipped when that PR is closed and must otherwise pass
        `check_if_should_be_merged` themselves.

        Raises:
            RuntimeError: If a commit message lacks a "Pull Request resolved"
                link or the link points at a different owner/repo.
        """
        assert self.is_ghstack_pr()
        # For ghstack, cherry-pick commits based from origin
        orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', self.head_ref())}"
        rev_list = repo.revlist(f"{self.default_branch()}..{orig_ref}")
        for idx, rev in enumerate(reversed(rev_list)):
            msg = repo.commit_message(rev)
            m = RE_PULL_REQUEST_RESOLVED.search(msg)
            if m is None:
                raise RuntimeError(f"Could not find PR-resolved string in {msg} of ghstacked PR {self.pr_num}")
            if self.org != m.group('owner') or self.project != m.group('repo'):
                raise RuntimeError(f"PR {m.group('number')} resolved to wrong owner/repo pair")
            pr_num = int(m.group('number'))
            if pr_num != self.pr_num:
                pr = GitHubPR(self.org, self.project, pr_num)
                if pr.is_closed():
                    print(f"Skipping {idx+1} of {len(rev_list)} PR (#{pr_num}) as its already been merged")
                    continue
                check_if_should_be_merged(pr, repo)
            repo.cherry_pick(rev)
            # Drop the ghstack bookkeeping line from the landed commit message.
            repo.amend_commit_message(re.sub(RE_GHSTACK_SOURCE_ID, "", msg))

    def merge_into(self, repo: "GitRepo", dry_run: bool = False) -> None:
        """Land the PR on the default branch of `repo`.

        The PR must satisfy `check_if_should_be_merged`. Regular PRs are
        squash-merged into a single commit whose message is title, body and a
        "Pull Request resolved" link; ghstack PRs are cherry-picked. The
        default branch is pushed unless `dry_run` is set.
        """
        check_if_should_be_merged(self, repo)
        if repo.current_branch() != self.default_branch():
            repo.checkout(self.default_branch())
        if not self.is_ghstack_pr():
            msg = self.get_title() + "\n\n" + self.get_body()
            msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
            repo._run_git("merge", "--squash", f"{repo.remote}/{self.head_ref()}")
            # NOTE(review): the literal quotes become part of the argument if
            # _run_git does not go through a shell - confirm git accepts them.
            repo._run_git("commit", f"--author=\"{self.get_author()}\"", "-m", msg)
        else:
            self.merge_ghstack_into(repo)
        if not dry_run:
            repo.push(self.default_branch())
@dataclass
class MergeRule:
    """One merge rule, built from a JSON object by `read_merge_rules`."""

    # Label used in the log messages of check_if_should_be_merged.
    name: str
    # Patterns handed to patterns_to_regex; every changed file must match.
    patterns: List[str]
    # Logins of which at least one must have approved; empty means no requirement.
    approved_by: List[str]
    # App whose check suite must have concluded "SUCCESS"; None disables the check.
    mandatory_app_id: Optional[int]
def read_merge_rules(repo: GitRepo) -> List[MergeRule]:
    """Load the merge rules from `.github/merge_rules.json` inside `repo.repo_dir`.

    Every JSON object in the file is turned into a `MergeRule`. When the file
    is absent a notice is printed and an empty list is returned.
    """
    from pathlib import Path

    rules_path = Path(repo.repo_dir).joinpath(".github", "merge_rules.json")
    if rules_path.exists():
        with open(rules_path) as fp:
            parsed = json.load(fp, object_hook=lambda obj: MergeRule(**obj))
        return cast(List[MergeRule], parsed)
    print(f"{rules_path} does not exist, returning empty rules")
    return []
def check_if_should_be_merged(pr: GitHubPR, repo: GitRepo) -> None:
    """Return normally if `pr` satisfies at least one merge rule of `repo`.

    Rules are tried in file order. A rule matches when its approver
    requirement (if any) is met, its mandatory app (if any) reported
    "SUCCESS", and every changed file matches the rule's patterns. The reason
    for skipping each rule is printed.

    Raises:
        RuntimeError: If no rule matches.
    """
    changed_files = pr.get_changed_files()
    pr_approvers = set(pr.get_approved_by())
    for rule in read_merge_rules(repo):
        required_approvers = set(rule.approved_by)
        file_matcher = patterns_to_regex(rule.patterns)
        # A rule that names approvers needs at least one of them on the PR.
        if required_approvers and not (pr_approvers & required_approvers):
            print(f"Skipping rule {rule.name} due to no approvers overlap")
            continue
        app_id = rule.mandatory_app_id
        if app_id is not None:
            conclusions = pr.get_check_suite_conclusions()
            # A missing entry and a non-SUCCESS conclusion are treated alike.
            if conclusions.get(app_id) != "SUCCESS":
                print(f"Skipping rule {rule.name} as mandatory app {app_id} is not in {conclusions}")
                continue
        unmatched = [fname for fname in changed_files if not file_matcher.match(fname)]
        if unmatched:
            print(f"Skipping rule {rule.name} due to non-matching files: {unmatched}")
            continue
        print(f"Matched rule {rule.name} for {pr.pr_num}")
        return
    raise RuntimeError(f"PR {pr.pr_num} does not match merge rules")
def main() -> None:
    """Entry point: merge the PR named on the command line into the default branch.

    Closed and cross-repository PRs are rejected with a comment and exit
    status -1. A failure while merging is reported as a PR comment and its
    traceback is printed to stderr; the exit status is left unchanged.
    """
    import sys
    import traceback

    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)
    if pr.is_closed():
        print(gh_post_comment(org, project, args.pr_num, f"Can't merge closed PR #{args.pr_num}", dry_run=args.dry_run))
        sys.exit(-1)
    if pr.is_cross_repo():
        print(gh_post_comment(org, project, args.pr_num, "Cross-repo merges are not supported at the moment", dry_run=args.dry_run))
        sys.exit(-1)
    try:
        pr.merge_into(repo, dry_run=args.dry_run)
    except Exception as e:
        # Log the full stack first so it survives even if posting the comment
        # fails; the one-line comment alone is not enough to debug a failure.
        traceback.print_exc()
        gh_post_comment(org, project, args.pr_num, f"Merge failed due to {e}", dry_run=args.dry_run)


if __name__ == "__main__":
    main()