import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple

from tools.stats.upload_stats_lib import (
download_gha_artifacts,
download_s3_artifacts,
is_rerun_disabled_tests,
unzip,
upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element

TESTCASE_TAG = "testcase"
SEPARATOR = ";"


def process_report(
    report: Path,
) -> Dict[str, Dict[str, int]]:
    """
    Return a dictionary of disabled tests, covering both those that should be re-enabled
    and those that are still flaky (failed or skipped), keyed by test ID and mapped to
    their green and red run counts
    """
root = ET.parse(report)
# All rerun tests from a report are grouped here:
#
    # * Successful tests should be re-enabled if they are green after rerunning on all
    #   platforms where they are currently disabled
    # * Failures from pytest because pytest-flakefinder is used to run the same test
    #   multiple times, some of which could fail
# * Skipped tests from unittest
#
# We want to keep track of how many times the test fails (num_red) or passes (num_green)
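    #
    # For reference, a rerun testcase element roughly looks like the following (the
    # attribute values and counts below are illustrative only, not taken from a real
    # report):
    #
    #   <testcase name="test_foo" classname="TestFooCPU" file="test_foo.py" ...>
    #     <skipped message='{"num_red": 1, "num_green": 29}' .../>
    #   </testcase>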
all_tests: Dict[str, Dict[str, int]] = {}
for test_case in root.iter(TESTCASE_TAG):
parsed_test_case = process_xml_element(test_case)
# Under --rerun-disabled-tests mode, a test is skipped when:
# * it's skipped explicitly inside PyTorch code
# * it's skipped because it's a normal enabled test
        # * or it's flaky (num_red > 0 and num_green > 0)
# * or it's failing (num_red > 0 and num_green == 0)
#
# We care only about the latter two here
skipped = parsed_test_case.get("skipped", None)
        # NB: Regular ONNX tests could return a list of subskips here where each item in
        # the list is a skipped message. In the context of rerunning disabled tests, this
        # case can safely be ignored because a list of subskips only shows up when tests
        # are run normally
if skipped and (
type(skipped) is list or "num_red" not in skipped.get("message", "")
):
continue
name = parsed_test_case.get("name", "")
classname = parsed_test_case.get("classname", "")
filename = parsed_test_case.get("file", "")
if not name or not classname or not filename:
continue
# Check if the test is a failure
failure = parsed_test_case.get("failure", None)
disabled_test_id = SEPARATOR.join([name, classname, filename])
if disabled_test_id not in all_tests:
all_tests[disabled_test_id] = {
"num_green": 0,
"num_red": 0,
}
# Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's
# counted as a success. Otherwise, it's still flaky or failing
if skipped:
try:
stats = json.loads(skipped.get("message", ""))
except json.JSONDecodeError:
stats = {}
all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0)
all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0)
elif failure:
# As a failure, increase the failure count
all_tests[disabled_test_id]["num_red"] += 1
else:
all_tests[disabled_test_id]["num_green"] += 1
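
    # For reference, the returned mapping is keyed by the SEPARATOR-joined test ID, e.g.
    # (illustrative names and counts only):
    #
    #   {"test_foo;TestFooCPU;test_foo.py": {"num_green": 30, "num_red": 0}}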
return all_tests


def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Gather all the test reports from S3 and GHA. It is currently not possible to tell
    which test reports come from the rerun_disabled_tests workflow because the name
    doesn't include the test config, so all reports need to be downloaded and examined
    """
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)
artifact_paths = download_s3_artifacts(
"test-reports", workflow_run_id, workflow_run_attempt
)
for path in artifact_paths:
unzip(path)
artifact_paths = download_gha_artifacts(
"test-report", workflow_run_id, workflow_run_attempt
)
for path in artifact_paths:
unzip(path)
yield from Path(".").glob("**/*.xml")


def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
    """
    Follow the flaky bot naming convention here; if that changes, this will also need
    to be updated
    """
name, classname, filename = test_id.split(SEPARATOR)
return f"{name} (__main__.{classname})", name, classname, filename


def prepare_record(
workflow_id: int,
workflow_run_attempt: int,
name: str,
classname: str,
filename: str,
flaky: bool,
num_red: int = 0,
num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]:
"""
Prepare the record to save onto S3
"""
key = (
workflow_id,
workflow_run_attempt,
name,
classname,
filename,
)
record = {
"workflow_id": workflow_id,
"workflow_run_attempt": workflow_run_attempt,
"name": name,
"classname": classname,
"filename": filename,
"flaky": flaky,
"num_green": num_green,
"num_red": num_red,
}
return key, record


def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: Dict[str, Dict[str, int]],
) -> None:
    """
    Save the results to S3 so that they can be ingested into Rockset
    """
should_be_enabled_tests = {
name: stats
for name, stats in all_tests.items()
if "num_green" in stats
and stats["num_green"]
and "num_red" in stats
and stats["num_red"] == 0
}
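    # Everything else is either still flaky or consistently failing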
still_flaky_tests = {
name: stats
for name, stats in all_tests.items()
if name not in should_be_enabled_tests
}
records = {}
for test_id, stats in all_tests.items():
num_green = stats.get("num_green", 0)
num_red = stats.get("num_red", 0)
disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
key, record = prepare_record(
workflow_id=workflow_id,
workflow_run_attempt=workflow_run_attempt,
name=name,
classname=classname,
filename=filename,
flaky=test_id in still_flaky_tests,
num_green=num_green,
num_red=num_red,
)
records[key] = record
# Log the results
print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
for test_id, stats in should_be_enabled_tests.items():
disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
print(f" {disabled_test_name} from {filename}")
print(f"The following {len(still_flaky_tests)} are still flaky:")
for test_id, stats in still_flaky_tests.items():
num_green = stats.get("num_green", 0)
num_red = stats.get("num_red", 0)
disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
print(
f" {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
)
upload_workflow_stats_to_s3(
workflow_id,
workflow_run_attempt,
"rerun_disabled_tests",
list(records.values()),
)


def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled
    """
    # Aggregated across all jobs
    all_tests: Dict[str, Dict[str, int]] = {}
    for report in get_test_reports(repo, workflow_run_id, workflow_run_attempt):
tests = process_report(report)
# The scheduled workflow has both rerun disabled tests and memory leak check jobs.
# We are only interested in the former here
if not is_rerun_disabled_tests(tests):
continue
for name, stats in tests.items():
if name not in all_tests:
all_tests[name] = stats.copy()
else:
all_tests[name]["num_green"] += stats.get("num_green", 0)
all_tests[name]["num_red"] += stats.get("num_red", 0)
save_results(
workflow_run_id,
workflow_run_attempt,
all_tests,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3")
parser.add_argument(
"--workflow-run-id",
type=int,
required=True,
help="id of the workflow to get artifacts from",
)
parser.add_argument(
"--workflow-run-attempt",
type=int,
required=True,
help="which retry of the workflow this is",
)
parser.add_argument(
"--repo",
type=str,
required=True,
help="which GitHub repo this workflow run belongs to",
)
args = parser.parse_args()
main(args.repo, args.workflow_run_id, args.workflow_run_attempt)
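
# A minimal usage sketch (the script path and argument values below are assumptions,
# not taken from this file), run from the repository root with whatever AWS/GitHub
# credentials the upload_stats_lib helpers require available in the environment:
#
#   python tools/stats/check_disabled_tests.py \
#       --workflow-run-id 1234567890 \
#       --workflow-run-attempt 1 \
#       --repo pytorch/pytorch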