| import json |
| import os |
| import re |
| import subprocess |
| from collections import defaultdict |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import cast, Dict, List, Optional, Set, Union |
| from urllib.request import Request, urlopen |
| from warnings import warn |
| |
| from tools.testing.test_run import TestRun |
| |
| |
| REPO_ROOT = Path(__file__).resolve().parents[4] |
| |
| |
| def python_test_file_to_test_name(tests: Set[str]) -> Set[str]: |
| prefix = f"test{os.path.sep}" |
| valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")} |
| valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests} |
| |
| return valid_tests |
| |
| |
| @lru_cache(maxsize=None) |
| def get_pr_number() -> Optional[int]: |
| pr_number = os.environ.get("PR_NUMBER", "") |
| if pr_number == "": |
| re_match = re.match(r"^refs/tags/.*/(\d+)$", os.environ.get("GITHUB_REF", "")) |
| if re_match is not None: |
| pr_number = re_match.group(1) |
| if pr_number != "": |
| return int(pr_number) |
| return None |
| |
| |
| @lru_cache(maxsize=None) |
| def get_merge_base() -> str: |
| pr_number = get_pr_number() |
| if pr_number is not None: |
| github_token = os.environ.get("GITHUB_TOKEN") |
| headers = { |
| "Accept": "application/vnd.github.v3+json", |
| "Authorization": f"token {github_token}", |
| } |
| url = f"https://api.github.com/repos/pytorch/pytorch/pulls/{pr_number}" |
| with urlopen(Request(url, headers=headers)) as conn: |
| pr_info = json.loads(conn.read().decode()) |
| base = f"origin/{pr_info['base']['ref']}" |
| merge_base = ( |
| subprocess.check_output(["git", "merge-base", base, "HEAD"]) |
| .decode() |
| .strip() |
| ) |
| return merge_base |
| default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}" |
| merge_base = ( |
| subprocess.check_output(["git", "merge-base", default_branch, "HEAD"]) |
| .decode() |
| .strip() |
| ) |
| |
| head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() |
| |
| if merge_base == head: |
| # We are on the default branch, so check for changes since the last commit |
| merge_base = "HEAD^" |
| return merge_base |
| |
| |
| def query_changed_files() -> List[str]: |
| base_commit = get_merge_base() |
| |
| proc = subprocess.run( |
| ["git", "diff", "--name-only", base_commit, "HEAD"], |
| capture_output=True, |
| check=False, |
| ) |
| print(f"base_commit: {base_commit}") |
| |
| if proc.returncode != 0: |
| raise RuntimeError("Unable to get changed files") |
| |
| lines = proc.stdout.decode().strip().split("\n") |
| lines = [line.strip() for line in lines] |
| print(f"Changed files: {lines}") |
| return lines |
| |
| |
| @lru_cache(maxsize=None) |
| def get_git_commit_info() -> str: |
| """Gets the commit info since the last commit on the default branch.""" |
| base_commit = get_merge_base() |
| |
| return ( |
| subprocess.check_output( |
| ["git", "log", f"{base_commit}..HEAD"], |
| ) |
| .decode() |
| .strip() |
| ) |
| |
| |
| @lru_cache(maxsize=None) |
| def get_issue_or_pr_body(number: int) -> str: |
| """Gets the body of an issue or PR""" |
| github_token = os.environ.get("GITHUB_TOKEN") |
| headers = { |
| "Accept": "application/vnd.github.v3+json", |
| "Authorization": f"token {github_token}", |
| } |
| # Despite the 'issues' in the link, this also works for PRs |
| url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}" |
| with urlopen(Request(url, headers=headers)) as conn: |
| body: str = json.loads(conn.read().decode())["body"] |
| return body |
| |
| |
| def normalize_ratings( |
| ratings: Dict[TestRun, float], max_value: float, min_value: float = 0 |
| ) -> Dict[TestRun, float]: |
| # Takse the ratings, makes the max value into max_value, and proportionally |
| # distributes the rest of the ratings. |
| # Ex [1,2,3,4] and max_value 8 gets converted to [2,4,6,8] |
| # Assumes all rankings are >= 0 |
| # min_value is what 0 gets mapped to and shifts the values accordingly. Ex |
| # [1,2,3,4], min_value 1, max_value 5 gets converted to [2,3,4,5] |
| # Don't modify in place |
| if len(ratings) == 0: |
| return ratings |
| min_rating = min(ratings.values()) |
| assert min_rating > 0 |
| max_rating = max(ratings.values()) |
| assert max_rating > 0 |
| normalized_ratings = {} |
| for tf, rank in ratings.items(): |
| normalized_ratings[tf] = rank / max_rating * (max_value - min_value) + min_value |
| return normalized_ratings |
| |
| |
| def get_ratings_for_tests(file: Union[str, Path]) -> Dict[str, float]: |
| path = REPO_ROOT / file |
| if not os.path.exists(path): |
| print(f"could not find path {path}") |
| return {} |
| with open(path) as f: |
| test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f)) |
| try: |
| changed_files = query_changed_files() |
| except Exception as e: |
| warn(f"Can't query changed test files due to {e}") |
| return {} |
| ratings: Dict[str, float] = defaultdict(float) |
| for file in changed_files: |
| for test_file, score in test_file_ratings.get(file, {}).items(): |
| ratings[test_file] += score |
| return ratings |
| |
| |
| def get_correlated_tests(file: Union[str, Path]) -> List[str]: |
| ratings = get_ratings_for_tests(file) |
| prioritize = sorted(ratings, key=lambda x: -ratings[x]) |
| return prioritize |