blob: e28d33c642be1b60a1e7d7790ece822dc42b4f39 [file] [log] [blame]
# Delete old branches
import os
import re
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Set
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo
SEC_IN_DAY = 24 * 60 * 60
CLOSED_PR_RETENTION = 30 * SEC_IN_DAY
NO_PR_RETENTION = 1.5 * 365 * SEC_IN_DAY
PR_WINDOW = 90 * SEC_IN_DAY # Set to None to look at all PRs (may take a lot of tokens)
REPO_OWNER = "pytorch"
REPO_NAME = "pytorch"
ESTIMATED_TOKENS = [0]
TOKEN = os.environ["GITHUB_TOKEN"]
if not TOKEN:
raise Exception("GITHUB_TOKEN is not set") # noqa: TRY002
REPO_ROOT = Path(__file__).parents[2]
# Query for all PRs instead of just closed/merged because it's faster
GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
pullRequests(
first: 100
after: $cursor
orderBy: {field: UPDATED_AT, direction: DESC}
) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
"""
GRAPHQL_OPEN_PRS = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
pullRequests(
first: 100
after: $cursor
states: [OPEN]
) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
"""
GRAPHQL_NO_DELETE_BRANCH_LABEL = """
query ($owner: String!, $repo: String!, $cursor: String) {
repository(owner: $owner, name: $repo) {
label(name: "no-delete-branch") {
pullRequests(first: 100, after: $cursor) {
totalCount
pageInfo {
hasNextPage
endCursor
}
nodes {
headRefName
number
updatedAt
state
}
}
}
}
}
"""
def is_protected(branch: str) -> bool:
try:
ESTIMATED_TOKENS[0] += 1
res = gh_fetch_json_dict(
f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/branches/{branch}"
)
return bool(res["protected"])
except Exception as e:
print(f"[{branch}] Failed to fetch branch protections: {e}")
return True
def convert_gh_timestamp(date: str) -> float:
return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()
def get_branches(repo: GitRepo) -> Dict[str, Any]:
# Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
git_response = repo._run_git(
"for-each-ref",
"--sort=creatordate",
"--format=%(refname) %(committerdate:iso-strict)",
"refs/remotes/origin",
)
branches_by_base_name: Dict[str, Any] = {}
for line in git_response.splitlines():
branch, date = line.split(" ")
re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
assert re_branch
branch = branch_base_name = re_branch.group(1)
if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch):
branch_base_name = x.group(1)
date = datetime.fromisoformat(date).timestamp()
if branch_base_name not in branches_by_base_name:
branches_by_base_name[branch_base_name] = [date, [branch]]
else:
branches_by_base_name[branch_base_name][1].append(branch)
if date > branches_by_base_name[branch_base_name][0]:
branches_by_base_name[branch_base_name][0] = date
return branches_by_base_name
def paginate_graphql(
query: str,
kwargs: Dict[str, Any],
termination_func: Callable[[List[Dict[str, Any]]], bool],
get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> List[Any]:
hasNextPage = True
endCursor = None
data: List[Dict[str, Any]] = []
while hasNextPage:
ESTIMATED_TOKENS[0] += 1
res = gh_graphql(query, cursor=endCursor, **kwargs)
data.extend(get_data(res))
hasNextPage = get_page_info(res)["hasNextPage"]
endCursor = get_page_info(res)["endCursor"]
if termination_func(data):
break
return data
def get_recent_prs() -> Dict[str, Any]:
now = datetime.now().timestamp()
# Grab all PRs updated in last CLOSED_PR_RETENTION days
pr_infos: List[Dict[str, Any]] = paginate_graphql(
GRAPHQL_ALL_PRS_BY_UPDATED_AT,
{"owner": "pytorch", "repo": "pytorch"},
lambda data: (
PR_WINDOW is not None
and (now - convert_gh_timestamp(data[-1]["updatedAt"]) > PR_WINDOW)
),
lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
)
# Get the most recent PR for each branch base (group gh together)
prs_by_branch_base = {}
for pr in pr_infos:
pr["updatedAt"] = convert_gh_timestamp(pr["updatedAt"])
branch_base_name = pr["headRefName"]
if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
branch_base_name = x.group(1)
if branch_base_name not in prs_by_branch_base:
prs_by_branch_base[branch_base_name] = pr
else:
if pr["updatedAt"] > prs_by_branch_base[branch_base_name]["updatedAt"]:
prs_by_branch_base[branch_base_name] = pr
return prs_by_branch_base
@lru_cache(maxsize=1)
def get_open_prs() -> List[Dict[str, Any]]:
return paginate_graphql(
GRAPHQL_OPEN_PRS,
{"owner": "pytorch", "repo": "pytorch"},
lambda data: False,
lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
)
def get_branches_with_magic_label_or_open_pr() -> Set[str]:
pr_infos: List[Dict[str, Any]] = paginate_graphql(
GRAPHQL_NO_DELETE_BRANCH_LABEL,
{"owner": "pytorch", "repo": "pytorch"},
lambda data: False,
lambda res: res["data"]["repository"]["label"]["pullRequests"]["nodes"],
lambda res: res["data"]["repository"]["label"]["pullRequests"]["pageInfo"],
)
pr_infos.extend(get_open_prs())
# Get the most recent PR for each branch base (group gh together)
branch_bases = set()
for pr in pr_infos:
branch_base_name = pr["headRefName"]
if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
branch_base_name = x.group(1)
branch_bases.add(branch_base_name)
return branch_bases
def delete_branch(repo: GitRepo, branch: str) -> None:
repo._run_git("push", "origin", "-d", branch)
def delete_branches() -> None:
now = datetime.now().timestamp()
git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)
branches = get_branches(git_repo)
prs_by_branch = get_recent_prs()
keep_branches = get_branches_with_magic_label_or_open_pr()
delete = []
# Do not delete if:
# * associated PR is open, closed but updated recently, or contains the magic string
# * no associated PR and branch was updated in last 1.5 years
# * is protected
# Setting different values of PR_WINDOW will change how branches with closed
# PRs are treated depending on how old the branch is. The default value of
# 90 will allow branches with closed PRs to be deleted if the PR hasn't been
# updated in 90 days and the branch hasn't been updated in 1.5 years
for base_branch, (date, sub_branches) in branches.items():
print(f"[{base_branch}] Updated {(now - date) / SEC_IN_DAY} days ago")
if base_branch in keep_branches:
print(f"[{base_branch}] Has magic label or open PR, skipping")
continue
pr = prs_by_branch.get(base_branch)
if pr:
print(
f"[{base_branch}] Has PR {pr['number']}: {pr['state']}, updated {(now - pr['updatedAt']) / SEC_IN_DAY} days ago"
)
if (
now - pr["updatedAt"] < CLOSED_PR_RETENTION
or (now - date) < CLOSED_PR_RETENTION
):
continue
elif now - date < NO_PR_RETENTION:
continue
print(f"[{base_branch}] Checking for branch protections")
if any(is_protected(sub_branch) for sub_branch in sub_branches):
print(f"[{base_branch}] Is protected")
continue
for sub_branch in sub_branches:
print(f"[{base_branch}] Deleting {sub_branch}")
delete.append(sub_branch)
if ESTIMATED_TOKENS[0] > 400:
print("Estimated tokens exceeded, exiting")
break
print(f"To delete ({len(delete)}):")
for branch in delete:
print(f"About to delete branch {branch}")
delete_branch(git_repo, branch)
def delete_old_ciflow_tags() -> None:
# Deletes ciflow tags if they are associated with a closed PR or a specific
# commit. Lightweight tags don't have information about the date they were
# created, so we can't check how old they are. The script just assumes that
# ciflow tags should be deleted regardless of creation date.
git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)
def delete_tag(tag: str) -> None:
print(f"Deleting tag {tag}")
ESTIMATED_TOKENS[0] += 1
delete_branch(git_repo, f"refs/tags/{tag}")
tags = git_repo._run_git("tag").splitlines()
open_pr_numbers = [x["number"] for x in get_open_prs()]
for tag in tags:
try:
if ESTIMATED_TOKENS[0] > 400:
print("Estimated tokens exceeded, exiting")
break
if not tag.startswith("ciflow/"):
continue
re_match_pr = re.match(r"^ciflow\/.*\/(\d{5,6})$", tag)
re_match_sha = re.match(r"^ciflow\/.*\/([0-9a-f]{40})$", tag)
if re_match_pr:
pr_number = int(re_match_pr.group(1))
if pr_number in open_pr_numbers:
continue
delete_tag(tag)
elif re_match_sha:
delete_tag(tag)
except Exception as e:
print(f"Failed to check tag {tag}: {e}")
if __name__ == "__main__":
delete_branches()
delete_old_ciflow_tags()