Fix delete branches (#119399)

Due to PR_WINDOW, if the magic string exists in the body but the pr was not updated recently, the query wouldn't find it and would delete the branch.  Instead, query separately for branches with the no-delete-branch label, which I created recently.

Might as well query for branches with open PRs while we're at it so PRs with the stale label won't get their branches deleted either
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119399
Approved by: https://github.com/huydhn
diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py
index d871f35..601631d 100644
--- a/.github/scripts/delete_old_branches.py
+++ b/.github/scripts/delete_old_branches.py
@@ -3,7 +3,7 @@
 import re
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Set
 
 from github_utils import gh_fetch_json_dict, gh_graphql
 from gitutils import GitRepo
@@ -14,7 +14,6 @@
 PR_WINDOW = 90 * SEC_IN_DAY  # Set to None to look at all PRs (may take a lot of tokens)
 REPO_OWNER = "pytorch"
 REPO_NAME = "pytorch"
-PR_BODY_MAGIC_STRING = "do-not-delete-branch"
 ESTIMATED_TOKENS = [0]
 
 TOKEN = os.environ["GITHUB_TOKEN"]
@@ -23,7 +22,8 @@
 
 REPO_ROOT = Path(__file__).parent.parent.parent
 
-GRAPHQL_PRS_QUERY = """
+# Query for all PRs instead of just closed/merged because it's faster
+GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
 query ($owner: String!, $repo: String!, $cursor: String) {
   repository(owner: $owner, name: $repo) {
     pullRequests(
@@ -41,7 +41,52 @@
         number
         updatedAt
         state
-        body
+      }
+    }
+  }
+}
+"""
+
+GRAPHQL_OPEN_PRS = """
+query ($owner: String!, $repo: String!, $cursor: String) {
+  repository(owner: $owner, name: $repo) {
+    pullRequests(
+      first: 100
+      after: $cursor
+      states: [OPEN]
+    ) {
+      totalCount
+      pageInfo {
+        hasNextPage
+        endCursor
+      }
+      nodes {
+        headRefName
+        number
+        updatedAt
+        state
+      }
+    }
+  }
+}
+"""
+
+GRAPHQL_NO_DELETE_BRANCH_LABEL = """
+query ($owner: String!, $repo: String!, $cursor: String) {
+  repository(owner: $owner, name: $repo) {
+    label(name: "no-delete-branch") {
+      pullRequests(first: 100, after: $cursor) {
+        totalCount
+        pageInfo {
+          hasNextPage
+          endCursor
+        }
+        nodes {
+          headRefName
+          number
+          updatedAt
+          state
+        }
       }
     }
   }
@@ -87,27 +132,41 @@
     return branches_by_base_name
 
 
-def get_prs() -> Dict[str, Any]:
-    now = datetime.now().timestamp()
-
-    pr_infos: List[Dict[str, Any]] = []
-
+def paginate_graphql(
+    query: str,
+    kwargs: Dict[str, Any],
+    termination_func: Callable[[List[Dict[str, Any]]], bool],
+    get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
+    get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
+) -> List[Any]:
     hasNextPage = True
     endCursor = None
+    data: List[Dict[str, Any]] = []
     while hasNextPage:
         ESTIMATED_TOKENS[0] += 1
-        res = gh_graphql(
-            GRAPHQL_PRS_QUERY, owner="pytorch", repo="pytorch", cursor=endCursor
-        )
-        info = res["data"]["repository"]["pullRequests"]
-        pr_infos.extend(info["nodes"])
-        hasNextPage = info["pageInfo"]["hasNextPage"]
-        endCursor = info["pageInfo"]["endCursor"]
-        if (
-            PR_WINDOW
-            and now - convert_gh_timestamp(pr_infos[-1]["updatedAt"]) > PR_WINDOW
-        ):
+        res = gh_graphql(query, cursor=endCursor, **kwargs)
+        data.extend(get_data(res))
+        hasNextPage = get_page_info(res)["hasNextPage"]
+        endCursor = get_page_info(res)["endCursor"]
+        if termination_func(data):
             break
+    return data
+
+
+def get_recent_prs() -> Dict[str, Any]:
+    now = datetime.now().timestamp()
+
+    # Grab all PRs updated in last CLOSED_PR_RETENTION days
+    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+        GRAPHQL_ALL_PRS_BY_UPDATED_AT,
+        {"owner": "pytorch", "repo": "pytorch"},
+        lambda data: (
+            PR_WINDOW is not None
+            and (now - convert_gh_timestamp(data[-1]["updatedAt"]) > PR_WINDOW)
+        ),
+        lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
+        lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
+    )
 
     # Get the most recent PR for each branch base (group gh together)
     prs_by_branch_base = {}
@@ -124,6 +183,35 @@
     return prs_by_branch_base
 
 
+def get_branches_with_magic_label_or_open_pr() -> Set[str]:
+    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+        GRAPHQL_NO_DELETE_BRANCH_LABEL,
+        {"owner": "pytorch", "repo": "pytorch"},
+        lambda data: False,
+        lambda res: res["data"]["repository"]["label"]["pullRequests"]["nodes"],
+        lambda res: res["data"]["repository"]["label"]["pullRequests"]["pageInfo"],
+    )
+
+    pr_infos.extend(
+        paginate_graphql(
+            GRAPHQL_OPEN_PRS,
+            {"owner": "pytorch", "repo": "pytorch"},
+            lambda data: False,
+            lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
+            lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
+        )
+    )
+
+    # Get the most recent PR for each branch base (group gh together)
+    branch_bases = set()
+    for pr in pr_infos:
+        branch_base_name = pr["headRefName"]
+        if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
+            branch_base_name = x.group(1)
+        branch_bases.add(branch_base_name)
+    return branch_bases
+
+
 def delete_branch(repo: GitRepo, branch: str) -> None:
     repo._run_git("push", "origin", "-d", branch)
 
@@ -132,16 +220,17 @@
     now = datetime.now().timestamp()
     git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)
     branches = get_branches(git_repo)
-    prs_by_branch = get_prs()
+    prs_by_branch = get_recent_prs()
+    keep_branches = get_branches_with_magic_label_or_open_pr()
 
     delete = []
     # Do not delete if:
     # * associated PR is open, closed but updated recently, or contains the magic string
     # * no associated PR and branch was updated in last 1.5 years
     # * is protected
-    # Setting different values of PR_WINDOW will change how branches with open
+    # Setting different values of PR_WINDOW will change how branches with closed
     # PRs are treated depending on how old the branch is.  The default value of
-    # 90 will allow branches with open PRs to be deleted if the PR hasn't been
+    # 90 will allow branches with closed PRs to be deleted if the PR hasn't been
     # updated in 90 days and the branch hasn't been updated in 1.5 years
     for base_branch, (date, sub_branches) in branches.items():
         print(f"[{base_branch}] Updated {(now - date) / SEC_IN_DAY} days ago")
@@ -150,9 +239,8 @@
             print(
                 f"[{base_branch}] Has PR {pr['number']}: {pr['state']}, updated {(now - pr['updatedAt']) / SEC_IN_DAY} days ago"
             )
-            if PR_BODY_MAGIC_STRING in pr["body"]:
-                continue
-            if pr["state"] == "OPEN":
+            if base_branch in keep_branches:
+                print(f"[{base_branch}] Has magic label or open PR, skipping")
                 continue
             if (
                 now - pr["updatedAt"] < CLOSED_PR_RETENTION