Update `syncbranches` workflow (#71420)
Summary:
Use `pytorchmergebot` credentials to do the merge
Infer sync branch name from the workflow rather than hardcode it
Move common functions from `syncbranches.py` to `gitutils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71420
Reviewed By: bigfootjon
Differential Revision: D33638846
Pulled By: malfet
fbshipit-source-id: a568fd9ca04f4f142a7f5f64363e9516f5f4ef1c
diff --git a/.github/scripts/__init__.py b/.github/scripts/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/.github/scripts/__init__.py
+++ /dev/null
diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py
new file mode 100644
index 0000000..38c80b5
--- /dev/null
+++ b/.github/scripts/gitutils.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+
+from collections import defaultdict
+from datetime import datetime
+from typing import cast, Any, Dict, List, Optional, Tuple, Union
+import os
+import re
+
+
+RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
+
+
+def get_git_remote_name() -> str:
+ return os.getenv("GIT_REMOTE_NAME", "origin")
+
+
+def get_git_repo_dir() -> str:
+ from pathlib import Path
+ return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent))
+
+
+def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+ """
+ Converts list to dict preserving elements with duplicate keys
+ """
+ rc: Dict[str, List[str]] = defaultdict(lambda: [])
+ for (key, val) in items:
+ rc[key].append(val)
+ return dict(rc)
+
+
+def _check_output(items: List[str], encoding: str = "utf-8") -> str:
+ from subprocess import check_output
+ return check_output(items).decode(encoding)
+
+
+class GitCommit:
+ commit_hash: str
+ title: str
+ body: str
+ author: str
+ author_date: datetime
+ commit_date: Optional[datetime]
+
+ def __init__(self,
+ commit_hash: str,
+ author: str,
+ author_date: datetime,
+ title: str,
+ body: str,
+ commit_date: Optional[datetime] = None) -> None:
+ self.commit_hash = commit_hash
+ self.author = author
+ self.author_date = author_date
+ self.commit_date = commit_date
+ self.title = title
+ self.body = body
+
+ def __repr__(self) -> str:
+ return f"{self.title} ({self.commit_hash})"
+
+ def __contains__(self, item: Any) -> bool:
+ return item in self.body or item in self.title
+
+
+def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
+ """
+ Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
+ commit <sha1>
+ Author: <author>
+ AuthorDate: <author date>
+ Commit: <committer>
+ CommitDate: <committer date>
+
+ <title line>
+
+ <full commit message>
+
+ """
+ if isinstance(lines, str):
+ lines = lines.split("\n")
+ # TODO: Handle merge commits correctly
+ if len(lines) > 1 and lines[1].startswith("Merge:"):
+ del lines[1]
+ assert len(lines) > 7
+ assert lines[0].startswith("commit")
+ assert lines[1].startswith("Author: ")
+ assert lines[2].startswith("AuthorDate: ")
+ assert lines[3].startswith("Commit: ")
+ assert lines[4].startswith("CommitDate: ")
+ assert len(lines[5]) == 0
+ return GitCommit(commit_hash=lines[0].split()[1].strip(),
+ author=lines[1].split(":", 1)[1].strip(),
+ author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
+ commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
+ title=lines[6].strip(),
+ body="\n".join(lines[7:]),
+ )
+
+
+class GitRepo:
+ def __init__(self, path: str, remote: str = "origin") -> None:
+ self.repo_dir = path
+ self.remote = remote
+
+ def _run_git(self, *args: Any) -> str:
+ return _check_output(["git", "-C", self.repo_dir] + list(args))
+
+ def revlist(self, revision_range: str) -> List[str]:
+ rc = self._run_git("rev-list", revision_range, "--", ".").strip()
+ return rc.split("\n") if len(rc) > 0 else []
+
+ def current_branch(self) -> str:
+ return self._run_git("symbolic-ref", "--short", "HEAD").strip()
+
+ def checkout(self, branch: str) -> None:
+ self._run_git('checkout', branch)
+
+ def show_ref(self, name: str) -> str:
+ refs = self._run_git('show-ref', '-s', name).strip().split('\n')
+ if not all(refs[i] == refs[0] for i in range(1, len(refs))):
+ raise RuntimeError(f"referce {name} is ambigous")
+ return refs[0]
+
+ def rev_parse(self, name: str) -> str:
+ return self._run_git('rev-parse', '--verify', name).strip()
+
+ def get_merge_base(self, from_ref: str, to_ref: str) -> str:
+ return self._run_git('merge-base', from_ref, to_ref).strip()
+
+ def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
+ is_list = isinstance(ref, list)
+ if is_list:
+ if len(ref) == 0:
+ return []
+ ref = " ".join(ref)
+ rc = _check_output(['sh', '-c', f'git -C {self.repo_dir} show {ref}|git patch-id --stable']).strip()
+ return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
+
+ def get_commit(self, ref: str) -> GitCommit:
+ return parse_fuller_format(self._run_git('show', '--format=fuller', '--date=unix', '--shortstat', ref))
+
+ def cherry_pick(self, ref: str) -> None:
+ self._run_git('cherry-pick', '-x', ref)
+
+ def compute_branch_diffs(self, from_branch: str, to_branch: str) -> Tuple[List[str], List[str]]:
+ """
+ Returns list of commmits that are missing in each other branch since their merge base
+ Might be slow if merge base is between two branches is pretty far off
+ """
+ from_ref = self.rev_parse(from_branch)
+ to_ref = self.rev_parse(to_branch)
+ merge_base = self.get_merge_base(from_ref, to_ref)
+ from_commits = self.revlist(f'{merge_base}..{from_ref}')
+ to_commits = self.revlist(f'{merge_base}..{to_ref}')
+ from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
+ to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
+ for patch_id in set(from_ids).intersection(set(to_ids)):
+ from_values = from_ids[patch_id]
+ to_values = to_ids[patch_id]
+ if len(from_values) != len(to_values):
+ # Eliminate duplicate commits+reverts from the list
+ while len(from_values) > 0 and len(to_values) > 0:
+ frc = self.get_commit(from_values.pop())
+ toc = self.get_commit(to_values.pop())
+ if frc.title != toc.title or frc.author_date != toc.author_date:
+ raise RuntimeError(f"Unexpected differences between {frc} and {toc}")
+ from_commits.remove(frc.commit_hash)
+ to_commits.remove(toc.commit_hash)
+ continue
+ for commit in from_values:
+ from_commits.remove(commit)
+ for commit in to_values:
+ to_commits.remove(commit)
+ return (from_commits, to_commits)
+
+ def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None:
+ orig_branch = self.current_branch()
+ self.checkout(to_branch)
+ from_commits, to_commits = self.compute_branch_diffs(from_branch, to_branch)
+ if len(from_commits) == 0:
+ print("Nothing to do")
+ self.checkout(orig_branch)
+ return
+ for commit in reversed(from_commits):
+ self.cherry_pick(commit)
+ self.checkout(orig_branch)
+
+ def push(self, branch: str) -> None:
+ self._run_git("push", self.remote, branch)
+
+ def head_hash(self) -> str:
+ return self._run_git("show-ref", "--hash", "HEAD").strip()
+
+ def remote_url(self) -> str:
+ return self._run_git("remote", "get-url", self.remote)
+
+ def gh_owner_and_name(self) -> Tuple[str, str]:
+ url = os.getenv("GIT_REMOTE_URL", None)
+ if url is None:
+ url = self.remote_url()
+ rc = RE_GITHUB_URL_MATCH.match(url)
+ if rc is None:
+ raise RuntimeError(f"Unexpected url format {url}")
+ return cast(Tuple[str, str], rc.groups())
+
+ def commit_message(self, ref: str) -> str:
+ return self._run_git("log", "-1", "--format=%B", ref)
+
+ def amend_commit_message(self, msg: str) -> None:
+ self._run_git("commit", "--amend", "-m", msg)
+
+
+def parse_args() -> Any:
+ from argparse import ArgumentParser
+ parser = ArgumentParser("Merge PR/branch into default branch")
+ parser.add_argument("--sync-branch", default="sync")
+ parser.add_argument("--default-branch", type=str, default="main")
+ parser.add_argument("--dry-run", action="store_true")
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = parse_args()
+ repo = GitRepo(get_git_repo_dir())
+ repo.cherry_pick_commits(args.sync_branch, args.default_branch)
+ if not args.dry_run:
+ repo.push(args.default_branch)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.github/scripts/syncbranches.py b/.github/scripts/syncbranches.py
index c838fa7..e4b321e 100755
--- a/.github/scripts/syncbranches.py
+++ b/.github/scripts/syncbranches.py
@@ -1,186 +1,25 @@
#!/usr/bin/env python3
-from collections import defaultdict
-from datetime import datetime
-from typing import cast, Any, Dict, List, Optional, Tuple, Union
-import os
+from gitutils import get_git_repo_dir, GitRepo
+from typing import Any
-def _check_output(items: List[str], encoding: str = "utf-8") -> str:
- from subprocess import check_output
- return check_output(items).decode(encoding)
+def parse_args() -> Any:
+ from argparse import ArgumentParser
+ parser = ArgumentParser("Merge PR/branch into default branch")
+ parser.add_argument("--sync-branch", default="sync")
+ parser.add_argument("--default-branch", type=str, default="main")
+ parser.add_argument("--dry-run", action="store_true")
+ return parser.parse_args()
-def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
- """
- Converts list to dict preserving elements with duplicate keys
- """
- rc: Dict[str, List[str]] = defaultdict(lambda: [])
- for (key, val) in items:
- rc[key].append(val)
- return dict(rc)
-
-
-class GitCommit:
- commit_hash: str
- title: str
- body: str
- author: str
- author_date: datetime
- commit_date: Optional[datetime]
-
- def __init__(self,
- commit_hash: str,
- author: str,
- author_date: datetime,
- title: str,
- body: str,
- commit_date: Optional[datetime] = None) -> None:
- self.commit_hash = commit_hash
- self.author = author
- self.author_date = author_date
- self.commit_date = commit_date
- self.title = title
- self.body = body
-
- def __repr__(self) -> str:
- return f"{self.title} ({self.commit_hash})"
-
- def __contains__(self, item: Any) -> bool:
- return item in self.body or item in self.title
-
-
-def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
- """
- Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
- commit <sha1>
- Author: <author>
- AuthorDate: <author date>
- Commit: <committer>
- CommitDate: <committer date>
-
- <title line>
-
- <full commit message>
-
- """
- if isinstance(lines, str):
- lines = lines.split("\n")
- # TODO: Handle merge commits correctly
- if len(lines) > 1 and lines[1].startswith("Merge:"):
- del lines[1]
- assert len(lines) > 7
- assert lines[0].startswith("commit")
- assert lines[1].startswith("Author: ")
- assert lines[2].startswith("AuthorDate: ")
- assert lines[3].startswith("Commit: ")
- assert lines[4].startswith("CommitDate: ")
- assert len(lines[5]) == 0
- return GitCommit(commit_hash=lines[0].split()[1].strip(),
- author=lines[1].split(":", 1)[1].strip(),
- author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
- commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
- title=lines[6].strip(),
- body="\n".join(lines[7:]),
- )
-
-
-class GitRepo:
- def __init__(self, path: str, remote: str = "origin") -> None:
- self.repo_dir = path
- self.remote = remote
-
- def _run_git(self, *args: Any) -> str:
- return _check_output(["git", "-C", self.repo_dir] + list(args))
-
- def revlist(self, revision_range: str) -> List[str]:
- rc = self._run_git("rev-list", revision_range, "--", ".").strip()
- return rc.split("\n") if len(rc) > 0 else []
-
- def current_branch(self) -> str:
- return self._run_git("symbolic-ref", "--short", "HEAD").strip()
-
- def checkout(self, branch: str) -> None:
- self._run_git('checkout', branch)
-
- def show_ref(self, name: str) -> str:
- refs = self._run_git('show-ref', '-s', name).strip().split('\n')
- if not all(refs[i] == refs[0] for i in range(1, len(refs))):
- raise RuntimeError(f"referce {name} is ambigous")
- return refs[0]
-
- def rev_parse(self, name: str) -> str:
- return self._run_git('rev-parse', '--verify', name).strip()
-
- def get_merge_base(self, from_ref: str, to_ref: str) -> str:
- return self._run_git('merge-base', from_ref, to_ref).strip()
-
- def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
- is_list = isinstance(ref, list)
- if is_list:
- if len(ref) == 0:
- return []
- ref = " ".join(ref)
- rc = _check_output(['sh', '-c', f'git -C {self.repo_dir} show {ref}|git patch-id --stable']).strip()
- return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
-
- def get_commit(self, ref: str) -> GitCommit:
- return parse_fuller_format(self._run_git('show', '--format=fuller', '--date=unix', '--shortstat', ref))
-
- def cherry_pick(self, ref: str) -> None:
- self._run_git('cherry-pick', '-x', ref)
-
- def compute_branch_diffs(self, from_branch: str, to_branch: str) -> Tuple[List[str], List[str]]:
- """
- Returns list of commmits that are missing in each other branch since their merge base
- Might be slow if merge base is between two branches is pretty far off
- """
- from_ref = self.rev_parse(from_branch)
- to_ref = self.rev_parse(to_branch)
- merge_base = self.get_merge_base(from_ref, to_ref)
- from_commits = self.revlist(f'{merge_base}..{from_ref}')
- to_commits = self.revlist(f'{merge_base}..{to_ref}')
- from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
- to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
- for patch_id in set(from_ids).intersection(set(to_ids)):
- from_values = from_ids[patch_id]
- to_values = to_ids[patch_id]
- if len(from_values) != len(to_values):
- # Eliminate duplicate commits+reverts from the list
- while len(from_values) > 0 and len(to_values) > 0:
- frc = self.get_commit(from_values.pop())
- toc = self.get_commit(to_values.pop())
- if frc.title != toc.title or frc.author_date != toc.author_date:
- raise RuntimeError(f"Unexpected differences between {frc} and {toc}")
- from_commits.remove(frc.commit_hash)
- to_commits.remove(toc.commit_hash)
- continue
- for commit in from_values:
- from_commits.remove(commit)
- for commit in to_values:
- to_commits.remove(commit)
- return (from_commits, to_commits)
-
- def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None:
- orig_branch = self.current_branch()
- self.checkout(to_branch)
- from_commits, to_commits = self.compute_branch_diffs(from_branch, to_branch)
- if len(from_commits) == 0:
- print("Nothing to do")
- self.checkout(orig_branch)
- return
- for commit in reversed(from_commits):
- self.cherry_pick(commit)
- self.checkout(orig_branch)
-
- def push(self, branch: str) -> None:
- self._run_git("push", self.remote, branch)
+def main() -> None:
+ args = parse_args()
+ repo = GitRepo(get_git_repo_dir())
+ repo.cherry_pick_commits(args.sync_branch, args.default_branch)
+ if not args.dry_run:
+ repo.push(args.default_branch)
if __name__ == '__main__':
- repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
- default_branch = 'master'
- sync_branch = 'fbsync'
- repo = GitRepo(repo_dir)
- repo.cherry_pick_commits(sync_branch, default_branch)
- repo.push(default_branch)
+ main()
diff --git a/.github/workflows/syncbranches.yml b/.github/workflows/syncbranches.yml
index 1ff91b2..8ed4f01 100644
--- a/.github/workflows/syncbranches.yml
+++ b/.github/workflows/syncbranches.yml
@@ -18,11 +18,12 @@
uses: actions/checkout@v2
with:
fetch-depth: 0
+ token: ${{ secrets.MERGEBOT_TOKEN }}
- - name: Setup commiter id
+ - name: Setup committer id
run: |
- git config --global user.email "pytorchbot@users.noreply.github.com"
- git config --global user.name "PyTorch Bot"
+ git config --global user.email "pytorchmergebot@users.noreply.github.com"
+ git config --global user.name "PyTorch MergeBot"
- name: Sync branches
run: |
- python3 .github/scripts/syncbranches.py
+ python3 .github/scripts/syncbranches.py --sync-branch="${GITHUB_REF#refs/heads/}" --default-branch=master