| #!/usr/bin/env python3 |
| """Check whether a PR has required labels.""" |
| |
| import sys |
| from typing import Any |
| |
| from github_utils import gh_delete_comment, gh_post_pr_comment |
| from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo |
| from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG |
| from trymerge import GitHubPR |
| |
| |
| def delete_all_label_err_comments(pr: "GitHubPR") -> None: |
| for comment in pr.get_comments(): |
| if is_label_err_comment(comment): |
| gh_delete_comment(pr.org, pr.project, comment.database_id) |
| |
| |
| def add_label_err_comment(pr: "GitHubPR") -> None: |
| # Only make a comment if one doesn't exist already |
| if not any(is_label_err_comment(comment) for comment in pr.get_comments()): |
| gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG) |
| |
| |
| def parse_args() -> Any: |
| from argparse import ArgumentParser |
| |
| parser = ArgumentParser("Check PR labels") |
| parser.add_argument("pr_num", type=int) |
| # add a flag to return a non-zero exit code if the PR does not have the required labels |
| parser.add_argument( |
| "--exit-non-zero", |
| action="store_true", |
| help="Return a non-zero exit code if the PR does not have the required labels", |
| ) |
| |
| return parser.parse_args() |
| |
| |
| def main() -> None: |
| args = parse_args() |
| repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) |
| org, project = repo.gh_owner_and_name() |
| pr = GitHubPR(org, project, args.pr_num) |
| |
| try: |
| if not has_required_labels(pr): |
| print(LABEL_ERR_MSG, flush=True) |
| add_label_err_comment(pr) |
| if args.exit_non_zero: |
| raise RuntimeError("PR does not have required labels") |
| else: |
| delete_all_label_err_comments(pr) |
| except Exception as e: |
| if args.exit_non_zero: |
| raise RuntimeError(f"Error checking labels: {e}") from e |
| |
| sys.exit(0) |
| |
| |
| if __name__ == "__main__": |
| main() |