blob: e94c4dfd08923454018bc4f5c09a41c443fc2a86 [file] [log] [blame]
#!/usr/bin/env python3
# Tests implemented in this file are relying on GitHub GraphQL APIs
# In order to avoid test flakiness, results of the queries
# are cached in gql_mocks.json
# PyTorch Lint workflow does not have GITHUB_TOKEN defined to avoid
# flakiness, so if you are making changes to merge_rules or
# GraphQL queries in trymerge.py, please make sure to delete `gql_mocks.json`
# And re-run the test locally with ones PAT
import gzip
import json
import os
import warnings
from hashlib import sha256
from typing import Any, Dict, List, Optional
from unittest import main, mock, skip, TestCase
from urllib.error import HTTPError
from github_utils import gh_graphql
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import (
categorize_checks,
DRCI_CHECKRUN_NAME,
find_matching_merge_rule,
get_classifications,
get_drci_classifications,
get_rockset_results,
gh_get_team_members,
GitHubPR,
JobCheckState,
main as trymerge_main,
MandatoryChecksMissingError,
MergeRule,
RE_GHSTACK_DESC,
read_merge_rules,
remove_job_name_suffix,
validate_revert,
)
if "GIT_REMOTE_URL" not in os.environ:
os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"
GQL_MOCKS = "gql_mocks.json.gz"
ROCKSET_MOCKS = "rockset_mocks.json.gz"
DRCI_MOCKS = "drci_mocks.json.gz"
def mock_query(
fallback_function: Any,
file_name: str,
key_function: Any,
*args: Any,
) -> Any:
gql_db_fname = os.path.join(os.path.dirname(__file__), file_name)
def get_mocked_queries() -> Any:
if not os.path.exists(gql_db_fname):
return {}
with gzip.open(gql_db_fname, encoding="utf-8", mode="rt") as f:
return json.load(f)
def save_mocked_queries(obj: Any) -> None:
with gzip.open(gql_db_fname, encoding="utf-8", mode="wt") as f:
json.dump(obj, f, indent=2)
f.write("\n")
key = key_function(*args)
mocked_queries = get_mocked_queries()
if key in mocked_queries:
return mocked_queries[key]
try:
rc = fallback_function(*args)
except HTTPError as err:
if err.code == 401 or err.code == 403:
err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with"
err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN,"
err_msg += " the rockset api key passed via ROCKSET_API_KEY,"
err_msg += " and drci api key passed via DRCI_BOT_KEY environment variables"
if (
os.getenv("GITHUB_TOKEN") is None
or os.getenv("ROCKSET_API_KEY") is None
or os.getenv("DRCI_BOT_KEY") is None
):
err_msg = (
"Failed to update cached queries as GITHUB_TOKEN or ROCKSET_API_KEY or DRCI_BOT_KEY "
+ "is not defined. "
+ err_msg
)
raise RuntimeError(err_msg) from err
mocked_queries[key] = rc
save_mocked_queries(mocked_queries)
return rc
def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
def key_function(query: str, kwargs: Any) -> str:
return f"query_sha={sha256(query.encode('utf-8')).hexdigest()} " + " ".join(
[f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]
)
def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
return gh_graphql(query, **kwargs)
return mock_query(gh_graphql_wrapper, GQL_MOCKS, key_function, query, kwargs)
def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3) -> Any:
return mock_query(
get_rockset_results,
ROCKSET_MOCKS,
lambda x, y: f"{x} {y}",
head_sha,
merge_base,
)
def mocked_drci_classifications(pr_num: int, project: str, num_retries: int = 3) -> Any:
return mock_query(
get_drci_classifications,
DRCI_MOCKS,
lambda x, y: f"{x} {y}",
pr_num,
project,
)
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
class Object:
def __init__(self) -> None:
self.revert = revert
self.force = force
self.pr_num = 76123
self.dry_run = True
self.comment_id = 0
self.reason = "this is for testing"
self.ignore_current = False
self.check_mergeability = False
return Object()
def mock_remove_label(
org: str, repo: str, pr_num: str, label: str, dry_run: bool
) -> None:
pass
def mock_revert(
repo: GitRepo,
pr: GitHubPR,
*,
dry_run: bool = False,
comment_id: Optional[int] = None,
reason: Optional[str] = None,
) -> None:
pass
def mock_merge(
pr: GitHubPR,
repo: GitRepo,
dry_run: bool = False,
skip_mandatory_checks: bool = False,
comment_id: Optional[int] = None,
timeout_minutes: int = 400,
stale_pr_days: int = 3,
ignore_current: bool = False,
) -> None:
pass
def mock_gh_get_info() -> Any:
return {
"closed": False,
"isCrossRepository": False,
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
"changedFiles": 0,
}
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
return [
MergeRule(
name="mock with nonexistent check",
patterns=["*"],
approved_by=[],
mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
ignore_flaky_failures=True,
),
]
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
return [
MergeRule(
name="super",
patterns=["*"],
approved_by=["pytorch/metamates", "ngimel"],
mandatory_checks_name=[
"Lint",
"Facebook CLA Check",
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
],
ignore_flaky_failures=True,
),
MergeRule(
name="xla",
patterns=[".github/ci_commit_pins/xla.txt"],
approved_by=["pytorchbot"],
mandatory_checks_name=[
"Lint",
"EasyCLA",
"pull / linux-focal-py3_8-clang9-xla / build",
"pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge)",
],
ignore_flaky_failures=True,
),
]
def mocked_read_merge_rules_approvers(
repo: Any, org: str, project: str
) -> List[MergeRule]:
return [
MergeRule(
name="Core Reviewers",
patterns=["*"],
approved_by=["1", "2", "3", "4", "5", "6"],
mandatory_checks_name=[
"Lint",
"pull",
],
),
MergeRule(
name="Core Maintainers",
patterns=["*"],
approved_by=["1", "2", "malfet"],
mandatory_checks_name=[
"Lint",
"pull",
],
),
]
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
raise RuntimeError("testing")
def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
return [
MergeRule(
name=" OSS CI / pytorchbot / XLA",
patterns=[".github/ci_commit_pins/xla.txt"],
approved_by=["pytorchbot"],
mandatory_checks_name=[
"Lint",
"EasyCLA",
"pull / linux-bionic-py3_8-clang8-xla / build",
"pull / linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge)",
"inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)",
],
ignore_flaky_failures=False,
),
]
def empty_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
return []
class DummyGitRepo(GitRepo):
def __init__(self) -> None:
super().__init__(get_git_repo_dir(), get_git_remote_name())
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
return ["FakeCommitSha"]
def commit_message(self, ref: str) -> str:
return "super awsome commit message"
@mock.patch("trymerge.get_rockset_results", side_effect=empty_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestTryMerge(TestCase):
def test_merge_rules_valid(self, *args: Any) -> None:
"Test that merge_rules.yaml can be parsed"
repo = DummyGitRepo()
merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
self.assertGreater(len(merge_rules), 1)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_match_rules(self, *args: Any) -> None:
"Tests that PR passes merge rules"
pr = GitHubPR("pytorch", "pytorch", 109999)
repo = DummyGitRepo()
self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
def test_read_merge_rules_fails(self, *args: Any) -> None:
"Tests that PR fails to read the merge rules"
pr = GitHubPR("pytorch", "pytorch", 77700)
repo = DummyGitRepo()
self.assertRaisesRegex(
RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
)
@mock.patch(
"trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_approvers
)
def test_match_rules_approvers(self, *args: Any) -> None:
"Tests that PR has the necessary approvers"
repo = DummyGitRepo()
pr = GitHubPR("pytorch", "pytorch", 115329)
# Test that all potential approvers across all rules are listed if the
# PR doesn't have one of them
for mock_rule in ["Core Reviewers", "Core Maintainers"]:
self.assertRaisesRegex(
RuntimeError,
mock_rule,
lambda: find_matching_merge_rule(pr, repo),
)
pr = GitHubPR("pytorch", "pytorch", 115495)
# Test that PR with the correct approvers doesn't raise any exception
self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_lint_fails(self, *args: Any) -> None:
"Tests that PR fails mandatory lint check"
pr = GitHubPR("pytorch", "pytorch", 90791)
repo = DummyGitRepo()
self.assertRaises(RuntimeError, lambda: find_matching_merge_rule(pr, repo))
def test_get_last_comment(self, *args: Any) -> None:
"Tests that last comment can be fetched"
pr = GitHubPR("pytorch", "pytorch", 71759)
comment = pr.get_last_comment()
self.assertEqual(comment.author_login, "github-actions")
self.assertIsNone(comment.editor_login)
self.assertTrue("You've committed this PR" in comment.body_text)
def test_get_author_null(self, *args: Any) -> None:
"""Tests that PR author can be computed
If reply contains NULL
"""
pr = GitHubPR("pytorch", "pytorch", 71759)
author = pr.get_author()
self.assertTrue(author is not None)
self.assertTrue("@" in author)
self.assertTrue(pr.get_diff_revision() is None)
# PR with multiple contributors, but creator id is not among authors
pr = GitHubPR("pytorch", "pytorch", 75095)
self.assertEqual(pr.get_pr_creator_login(), "mruberry")
author = pr.get_author()
self.assertTrue(author is not None)
def test_large_diff(self, *args: Any) -> None:
"Tests that PR with 100+ files can be fetched"
pr = GitHubPR("pytorch", "pytorch", 73099)
self.assertTrue(pr.get_changed_files_count() > 100)
flist = pr.get_changed_files()
self.assertEqual(len(flist), pr.get_changed_files_count())
def test_internal_changes(self, *args: Any) -> None:
"Tests that PR with internal changes is detected"
pr = GitHubPR("pytorch", "pytorch", 110140)
self.assertTrue(pr.has_internal_changes())
def test_comments_pagination(self, *args: Any) -> None:
"Tests that PR with 50+ comments can be fetched"
pr = GitHubPR("pytorch", "pytorch", 31093)
self.assertGreater(len(pr.get_comments()), 50)
def test_gql_complexity(self, *args: Any) -> None:
"Fetch comments and conclusions for PR with 60 commits"
# Previous version of GrapQL query used to cause HTTP/502 error
# see https://gist.github.com/malfet/9b93bc7eeddeaf1d84546efc4f0c577f
pr = GitHubPR("pytorch", "pytorch", 68111)
self.assertGreater(len(pr.get_comments()), 20)
# NS(09/27/2023): GitHub seems to recycle older checkruns
# https://github.com/pytorch/pytorch/pull/68111/checks shows 0 runs
# self.assertGreater(len(pr.get_checkrun_conclusions()), 3)
self.assertGreater(pr.get_commit_count(), 60)
def test_gql_retrieve_checksuites(self, *args: Any) -> None:
"Fetch comments and conclusions for PR with 60 commits"
pr = GitHubPR("pytorch", "pytorch", 94787)
self.assertEqual(len(pr.get_checkrun_conclusions()), 183)
def test_team_members(self, *args: Any) -> None:
"Test fetching team members works"
dev_infra_team = gh_get_team_members("pytorch", "pytorch-dev-infra")
self.assertGreater(len(dev_infra_team), 2)
with self.assertWarns(Warning):
non_existing_team = gh_get_team_members("pytorch", "qwertyuiop")
self.assertEqual(len(non_existing_team), 0)
def test_get_author_many_commits(self, *args: Any) -> None:
"""Tests that authors for all commits can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 76118)
authors = pr.get_authors()
self.assertGreater(pr.get_commit_count(), 100)
self.assertGreater(len(authors), 50)
self.assertTrue("@" in pr.get_author())
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
def test_pending_status_check(self, *args: Any) -> None:
"""Tests that PR with nonexistent/pending status checks fails with the right reason."""
pr = GitHubPR("pytorch", "pytorch", 76118)
repo = DummyGitRepo()
self.assertRaisesRegex(
MandatoryChecksMissingError,
".*are pending/not yet run.*",
lambda: find_matching_merge_rule(pr, repo),
)
def test_get_author_many_reviews(self, *args: Any) -> None:
"""Tests that all reviews can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 76123)
approved_by = pr.get_approved_by()
self.assertGreater(len(approved_by), 0)
assert pr._reviews is not None # to pacify mypy
self.assertGreater(len(pr._reviews), 100)
def get_co_authors(self, *args: Any) -> None:
"""Tests that co-authors are recognized"""
pr = GitHubPR("pytorch", "pytorch", 118347)
authors = pr.get_authors()
self.assertIn("kit1980", authors)
self.assertIn("Co-authored-by:", pr.gen_commit_message())
def test_get_checkruns_many_runs(self, *args: Any) -> None:
"""Tests that all checkruns can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 105260)
conclusions = pr.get_checkrun_conclusions()
self.assertEqual(len(conclusions), 221)
self.assertTrue(
"pull / linux-docs / build-docs-cpp-false" in conclusions.keys()
)
def test_cancelled_gets_ignored(self, *args: Any) -> None:
"""Tests that cancelled workflow does not override existing successfull status"""
pr = GitHubPR("pytorch", "pytorch", 110367)
conclusions = pr.get_checkrun_conclusions()
lint_checks = [name for name in conclusions.keys() if "Lint" in name]
self.assertTrue(len(lint_checks) > 0)
self.assertTrue(
all(conclusions[name].status == "SUCCESS" for name in lint_checks)
)
def test_get_review_comment_by_id(self, *args: Any) -> None:
"""Tests that even if the comment requested was actually a review instead of a simple comment, we can still find it"""
pr = GitHubPR("pytorch", "pytorch", 107070)
review_comment_id = 1582767635
comment = pr.get_comment_by_id(review_comment_id)
self.assertIsNotNone(comment)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
@mock.patch("trymerge.try_revert", side_effect=mock_revert)
def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
trymerge_main()
mock_revert.assert_called_once()
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
@mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
@mock.patch("trymerge.merge", side_effect=mock_merge)
def test_main_force(
self, mock_merge: Any, mock_parse_args: Any, *args: Any
) -> None:
trymerge_main()
mock_merge.assert_called_once_with(
mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=True,
comment_id=mock.ANY,
ignore_current=False,
)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
@mock.patch("trymerge.gh_remove_label", side_effect=mock_remove_label)
@mock.patch("trymerge.merge", side_effect=mock_merge)
def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
trymerge_main()
mock_merge.assert_called_once_with(
mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=False,
comment_id=mock.ANY,
ignore_current=False,
)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_revert_rules(self, *args: Any) -> None:
"""Tests that reverts from collaborators are allowed"""
pr = GitHubPR("pytorch", "pytorch", 79694)
repo = DummyGitRepo()
self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))
def test_get_changed_files(self, *args: Any) -> None:
"""
Tests that the list changed files in a PR doesn't include duplicates
"""
pr = GitHubPR("pytorch", "pytorch", 95233)
try:
changed_files = pr.get_changed_files()
except RuntimeError as error:
self.fail(f"get_changed_files throws an exception: {error}")
self.assertEqual(len(changed_files), pr.get_changed_files_count())
def test_revert_codev_abandoned_diff_succeeds(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 100652)
class GitRepoCoDev(DummyGitRepo):
def commit_message(self, ref: str) -> str:
return pr.get_body()
repo = GitRepoCoDev()
validate_revert(repo, pr, comment_id=1588195237)
def test_pr_changed_submodule_detection(self, *args: Any) -> None:
# Updates submodule during dev-cycle but reverts it later
pr = GitHubPR("pytorch", "pytorch", 95045)
self.assertEqual(pr.get_changed_submodules(), [])
self.assertFalse(pr.has_invalid_submodule_updates())
# PR updates ideep
pr = GitHubPR("pytorch", "pytorch", 94939)
self.assertEqual(pr.get_changed_submodules(), ["third_party/ideep"])
self.assertTrue(pr.has_invalid_submodule_updates())
# Automated submodule update
pr = GitHubPR("pytorch", "pytorch", 91051)
self.assertEqual(pr.get_changed_submodules(), ["third_party/kineto"])
self.assertFalse(pr.has_invalid_submodule_updates())
def test_remove_job_name_suffix(self, *args: Any) -> None:
test_cases = [
{
"name": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default, 1, 5, linux.g5.4xlarge.nvidia.gpu)",
"expected": "linux-bionic-cuda12.1-py3.10-gcc9-sm86 / test (default)",
},
{
"name": "android-emulator-build-test / build-and-test (default, 1, 1, ubuntu-20.04-16x)",
"expected": "android-emulator-build-test / build-and-test (default)",
},
{
"name": "linux-focal-rocm5.4.2-py3.8 / build",
"expected": "linux-focal-rocm5.4.2-py3.8 / build",
},
{
"name": "libtorch-cpu-shared-with-deps-release-build",
"expected": "libtorch-cpu-shared-with-deps-release-build",
},
{
"name": "manywheel-py3_8-cuda11_8-test / test",
"expected": "manywheel-py3_8-cuda11_8-test / test",
},
{
"name": "lintrunner / linux-job",
"expected": "lintrunner / linux-job",
},
{
"name": "Test `run_test.py` is usable without boto3/rockset",
"expected": "Test `run_test.py` is usable without boto3/rockset",
},
]
for case in test_cases:
self.assertEqual(case["expected"], remove_job_name_suffix(case["name"]))
def test_get_merge_base(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 104121)
mock_merge_base = "mocked-sha"
with mock.patch(
"trymerge.gh_fetch_merge_base", return_value=mock_merge_base
) as mocked_gh_fetch_merge_base:
self.assertEqual(mock_merge_base, pr.get_merge_base())
# Make sure that consecutive calls will use the same merge base instead of
# making another query
self.assertEqual(mock_merge_base, pr.get_merge_base())
mocked_gh_fetch_merge_base.assert_called_once()
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestBypassFailures(TestCase):
def test_get_classifications(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 109584)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
self.assertTrue(
checks[
"pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
].classification
== "BROKEN_TRUNK"
)
self.assertTrue(
checks[
"trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)"
].classification
== "FLAKY"
)
self.assertTrue(
checks[
"pull / linux-jammy-py3.8-gcc11 / test (distributed, 1, 2, linux.2xlarge)"
].classification
== "FLAKY"
)
self.assertTrue(
checks[
"pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
].classification
== "FLAKY"
)
# Set the threshold larger or equal to the number of ok failures
pending, failed, ignorable = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=6
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Not set any threshold, defaults to -1 to ignore all flaky and broken trunk failures
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Set the threshold lower than the number of ok failures
pending, failed, ignorable = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 6)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
# Set the threshold to 0 like when ignore_flaky_failures is on
pending, failed, ignorable = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 6)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
def test_get_classifications_flaky_fullname(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 110362)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 1)
def test_get_classifications_invalid_cancel(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 110367)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 0)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 0)
self.assertTrue(len(ignorable["UNSTABLE"]) == 3)
def test_get_classifications_similar_failures(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 109750)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 1)
def test_get_classifications_unstable(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 104312)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
workflow_name = "linux-bionic-cuda12.1-py3.10-gcc9-bazel-test"
job_name = "build-and-test (default, 1, 1, linux.4xlarge.nvidia.gpu, unstable)"
self.assertTrue(
checks[f"pull / {workflow_name} / {job_name}"].classification == "UNSTABLE"
)
pending, failed, ignorable = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["UNSTABLE"]) == 1)
def test_get_classifications_broken_trunk(self, *args: Any) -> None:
# The mock merge base is the actual value returned by gh_fetch_merge_base
test_cases = [
{
# This PR had one broken trunk failure but it was run on a different shard
# than the one on the base commit. This should still count as broken trunk
"pr_num": 104214,
"related_failure_count": 0,
"unrelated_failure_count": 1,
},
{
# This PR had one broken trunk failure and it used ghstack
"pr_num": 105145,
"related_failure_count": 0,
"unrelated_failure_count": 1,
},
{
# The failure on the merge base was retried successfully and
# its conclusion changed from failure to success. We want to
# keep the failure record from the merge base so that it can
# be used to detect broken trunk
"pr_num": 107160,
"related_failure_count": 0,
"unrelated_failure_count": 4,
},
{
# This PR used Dr.CI broken trunk classification
"pr_num": 111253,
"related_failure_count": 1,
"unrelated_failure_count": 2,
},
]
for case in test_cases:
pr_num = case["pr_num"]
related_failure_count = case["related_failure_count"]
unrelated_failure_count = case["unrelated_failure_count"]
pr = GitHubPR("pytorch", "pytorch", pr_num)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, _ = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == related_failure_count)
# When the ok_failed_checks_threshold is set to 0, the broken trunk failure
# won't be ignored
pending, failed, _ = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=0
)
self.assertTrue(len(pending) == 0)
self.assertTrue(
len(failed) == unrelated_failure_count + related_failure_count
)
def test_ignore_current(self, *args: Any) -> None:
# Test various interactions of the failure classifier to ensure that ignore
# current checks takes place after other classifications: flaky, unstable,
# or broken trunk. Only actual new failures should be kept in the list of
# ignore current checks to use to record force merge with actual failures
flaky = "pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)"
broken_trunk = (
"pull / linux-focal-py3.11-clang10 / test (dynamo, 1, 2, linux.2xlarge)"
)
pr = GitHubPR("pytorch", "pytorch", 109584)
checks = pr.get_checkrun_conclusions()
# Known flaky failure takes precedence over ignore current (need to set the
# merge base here to get the results from Rockset, and that categorize the
# broken trunk failure too
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[broken_trunk, flaky],
)
self.assertTrue(checks[flaky].classification == "FLAKY")
self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
_, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["IGNORE_CURRENT_CHECK"]) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 4)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 2)
@mock.patch("trymerge.read_merge_rules", side_effect=xla_merge_rules)
def test_dont_ignore_flaky_failures(self, *args: Any) -> None:
"""
Regression test for https://github.com/pytorch/test-infra/issues/4126
"""
pr = GitHubPR("pytorch", "pytorch", 105312)
repo = DummyGitRepo()
# Check that failure is classified as flaky but still raises exception
with warnings.catch_warnings(record=True) as w, self.assertRaises(RuntimeError):
rule = find_matching_merge_rule(pr, repo)
self.assertEqual(len(w), 1)
self.assertIn(
"1 checks failed but were likely due flakiness or broken trunk",
str(w[0].message),
)
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch("trymerge.get_drci_classifications", return_value={})
class TestBypassFailuresOnSandCastle(TestCase):
def test_get_classifications(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 111467)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
self.assertTrue(len(ignorable["FLAKY"]) == 1)
self.assertTrue(len(ignorable["BROKEN_TRUNK"]) == 1)
def test_get_classifications_drci_checkrun_not_found(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 111467)
# No summary
checks = pr.get_checkrun_conclusions()
checks[DRCI_CHECKRUN_NAME] = JobCheckState(
DRCI_CHECKRUN_NAME,
"",
"NEUTRAL",
None,
1,
"",
None,
)
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 2)
# Empty summary
checks = pr.get_checkrun_conclusions()
checks[DRCI_CHECKRUN_NAME] = JobCheckState(
DRCI_CHECKRUN_NAME,
"",
"NEUTRAL",
None,
1,
"",
"",
)
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 2)
# No Dr.CI checkrun
checks = pr.get_checkrun_conclusions()
del checks[DRCI_CHECKRUN_NAME]
checks = get_classifications(
pr.pr_num,
pr.project,
checks,
[],
)
pending, failed, ignorable = categorize_checks(checks, list(checks.keys()))
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 2)
@mock.patch("trymerge.get_rockset_results", side_effect=mocked_rockset_results)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_fetch_merge_base", return_value="")
@mock.patch(
"trymerge.get_drci_classifications", side_effect=mocked_drci_classifications
)
class TestGitHubPRGhstackDependencies(TestCase):
def test_pr_dependencies(self, *args: Any) -> None:
pr = GitHubPR("pytorch", "pytorch", 106068)
msg = pr.gen_commit_message(filter_ghstack=True)
self.assertEqual(
msg,
f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
"Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
"Approved by: https://github.com/ezyang, https://github.com/fegin\n",
)
def test_pr_dependencies_ghstack(self, *args: Any) -> None:
pr0 = GitHubPR("pytorch", "pytorch", 106032)
pr1 = GitHubPR("pytorch", "pytorch", 106033)
pr2 = GitHubPR("pytorch", "pytorch", 106034)
pr = GitHubPR("pytorch", "pytorch", 106068)
msg = pr.gen_commit_message(filter_ghstack=True, ghstack_deps=[pr0, pr1, pr2])
self.assertEqual(
msg,
f"{pr.get_title()} (#106068)\n\n{RE_GHSTACK_DESC.sub('', pr.get_body())}\n"
"Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068\n"
"Approved by: https://github.com/ezyang, https://github.com/fegin\n"
"ghstack dependencies: #106032, #106033, #106034\n",
)
@skip(
reason="This test is run against a mutalbe PR that has changed, so it no longer works. The test should be changed"
)
@mock.patch("trymerge.read_merge_rules")
@mock.patch("trymerge.GitRepo")
@mock.patch("trymerge.get_ghstack_prs")
def test_merge_ghstack_into(
self,
mock_get_ghstack_prs: mock.MagicMock,
mock_repo: mock.MagicMock,
mock_merge_rules: mock.MagicMock,
*args: Any,
) -> None:
"""
Test that the merge_ghstack_into method works correctly
"""
pr0 = GitHubPR("pytorch", "pytorch", 106032)
pr1 = GitHubPR("pytorch", "pytorch", 106033)
pr2 = GitHubPR("pytorch", "pytorch", 106034)
pr = GitHubPR("pytorch", "pytorch", 106068)
# note: in reverse order (e.g. self.pr is the last commit, top of the stack)
mock_get_ghstack_prs.return_value = [
(pr0, "rev0"),
(pr1, "rev1"),
(pr2, "rev2"),
(pr, "rev123"),
]
mock_merge_rules.return_value = [
MergeRule(
"Mock title", patterns=["*"], approved_by=[], mandatory_checks_name=None
)
]
mock_repo.cherry_pick.return_value = None
mock_repo.amend_commit_message.return_value = None
# Call the method under test
res = pr.merge_ghstack_into(mock_repo, True)
self.assertEqual(res, [pr2, pr])
mock_repo.cherry_pick.assert_any_call("rev2")
mock_repo.cherry_pick.assert_any_call("rev123")
self.assertTrue(mock.call("rev1") not in mock_repo.cherry_pick.call_args_list)
# Verify the first call
message = mock_repo.amend_commit_message.call_args_list[0].args[0]
prefix = (
"[FSDP] Optimize away intermediate `div_` for HSDP (#106034)\n\n\r\n"
"### Background: Gradient Pre-Divide"
)
suffix = (
"\nPull Request resolved: https://github.com/pytorch/pytorch/pull/106034\nApproved by: \nghstack "
"dependencies: #106032, #106033\n"
)
self.assertTrue(message.startswith(prefix))
self.assertTrue(message.endswith(suffix))
# Verify the second call
mock_repo.amend_commit_message.assert_any_call(
"[FSDP] Break up `_post_backward_hook` into smaller funcs (#106068)\n\n\n"
"Differential Revision: ["
"D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)\n"
"Pull Request resolved: "
"https://github.com/pytorch/pytorch/pull/106068\n"
"Approved by: \n"
"ghstack dependencies: #106032, #106033, #106034\n"
)
if __name__ == "__main__":
main()