Reordering tests experiment (#106347)
Companion with https://github.com/pytorch/test-infra/pull/4424
Uses the file rating generated by the test infra PR to re order tests. For each test file, sum the file ratings from the changed files in the PR, and put the tests in order of sum.
A lot of tests are probably going to end up as "prioritized" since it takes anything with a rating > 0 right now.
Sharding is done twice, once on the prioritized tests, and once on the general/non prioritized tests. Prioritized tests have an order, so they should be sharded according to that order, while general tests don't have an order and are sharded by test time, which should result in more balanced shards.
I'll change the metric name before I merge, i want to quarantine my testing stuff from actual results
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347
Approved by: https://github.com/ZainRizvi
diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat
index 5e74062..57f59fe 100644
--- a/.ci/pytorch/win-test-helpers/build_pytorch.bat
+++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat
@@ -128,6 +128,7 @@
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
python tools/stats/export_test_times.py
copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
+ copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
:: Also save build/.ninja_log as an artifact
copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
diff --git a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
index c18151d..7277de3 100644
--- a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
+++ b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
@@ -2,6 +2,7 @@
echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
pushd test
diff --git a/.ci/pytorch/win-test-helpers/test_python_shard.bat b/.ci/pytorch/win-test-helpers/test_python_shard.bat
index 5313bc0..ec7e78b 100644
--- a/.ci/pytorch/win-test-helpers/test_python_shard.bat
+++ b/.ci/pytorch/win-test-helpers/test_python_shard.bat
@@ -23,6 +23,7 @@
echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
echo Run nn tests
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5cb89ac..36149c4 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -652,7 +652,7 @@
- run:
name: Archive artifacts into zip
command: |
- zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+ zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
cp artifacts.zip /Users/distiller/workspace
- persist_to_workspace:
diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
index f03e173..0a2aee3 100644
--- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
+++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
@@ -177,7 +177,7 @@
- run:
name: Archive artifacts into zip
command: |
- zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+ zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
cp artifacts.zip /Users/distiller/workspace
- persist_to_workspace:
diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml
index 269260c3..7031e4e 100644
--- a/.github/workflows/_linux-build.yml
+++ b/.github/workflows/_linux-build.yml
@@ -170,7 +170,7 @@
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
- zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
+ zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on S3
uses: seemethere/upload-artifact-s3@v5
diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml
index 9ba093e..2585709 100644
--- a/.github/workflows/_mac-build.yml
+++ b/.github/workflows/_mac-build.yml
@@ -182,7 +182,7 @@
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
- zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+ zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on GHA
uses: actions/upload-artifact@v3
diff --git a/.gitignore b/.gitignore
index 9ffab8f..424cc4b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@
**/.pytorch-disabled-tests.json
**/.pytorch-slow-tests.json
**/.pytorch-test-times.json
+**/.pytorch-test-file-ratings.json
*/*.pyc
*/*.so*
*/**/__pycache__
diff --git a/test/run_test.py b/test/run_test.py
index cd62621..2c27197 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -11,9 +11,10 @@
import subprocess
import sys
import tempfile
+import time
from datetime import datetime
from distutils.version import LooseVersion
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
import pkg_resources
@@ -40,11 +41,11 @@
# using tools/ to optimize test run.
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.export_test_times import TEST_TIMES_FILE
+ from tools.stats.upload_stats_lib import emit_metric
from tools.testing.test_selections import (
calculate_shards,
get_reordered_tests,
get_test_case_configs,
- log_time_savings,
NUM_PROCS,
ShardedTest,
THRESHOLD,
@@ -1278,7 +1279,9 @@
return selected_tests
-def must_serial(file: str) -> bool:
+def must_serial(file: Union[str, ShardedTest]) -> bool:
+ if isinstance(file, ShardedTest):
+ file = file.name
return (
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@@ -1408,20 +1411,10 @@
)
selected_tests = [parse_test_module(x) for x in selected_tests]
+ return selected_tests
- # sharding
- which_shard, num_shards = 1, 1
- if options.shard:
- assert len(options.shard) == 2, "Unexpected shard format"
- assert min(options.shard) > 0, "Shards must be positive numbers"
- which_shard, num_shards = options.shard
- assert (
- which_shard <= num_shards
- ), "Selected shard must be less than or equal to total number of shards"
- assert num_shards <= len(
- selected_tests
- ), f"Number of shards must be less than {len(selected_tests)}"
+def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
@@ -1434,14 +1427,35 @@
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
+ return {}
else:
print("Found test time stats from artifacts")
+ return test_file_times[test_config]
+
+
+def do_sharding(
+ options,
+ selected_tests: List[str],
+ test_file_times: Dict[str, float],
+ sort_by_time: bool = True,
+) -> List[ShardedTest]:
+ which_shard, num_shards = 1, 1
+ if options.shard:
+ assert len(options.shard) == 2, "Unexpected shard format"
+ assert min(options.shard) > 0, "Shards must be positive numbers"
+ which_shard, num_shards = options.shard
+ assert (
+ which_shard <= num_shards
+ ), "Selected shard must be less than or equal to total number of shards"
if HAVE_TEST_SELECTION_TOOLS:
# Do sharding
- test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
- num_shards, selected_tests, test_file_times_config, must_serial=must_serial
+ num_shards,
+ selected_tests,
+ test_file_times,
+ must_serial=must_serial,
+ sort_by_time=sort_by_time,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
@@ -1449,9 +1463,14 @@
return selected_tests
+class TestFailure(NamedTuple):
+ test: str
+ message: str
+
+
def run_test_module(
test: Union[ShardedTest, str], test_directory: str, options
-) -> Optional[str]:
+) -> Optional[TestFailure]:
maybe_set_hip_visible_devies()
# Printing the date here can help diagnose which tests are slow
@@ -1472,39 +1491,24 @@
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
- return message
+ return TestFailure(test, message)
def run_tests(
- selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
+ selected_tests: List[ShardedTest],
+ test_directory: str,
+ options,
+ failures: List[TestFailure],
) -> None:
- failure_messages = []
-
if len(selected_tests) == 0:
- print_to_stderr(f"No tests in group `{group_name}`")
- return failure_messages
+ return
# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
- selected_tests_parallel = [
- x
- for x in selected_tests
- if not must_serial(x.name if isinstance(x, ShardedTest) else x)
- ]
+ selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
- print(f"TEST GROUP: {group_name}")
- print_to_stderr(
- "parallel (file granularity) tests :\n {}".format(
- "\n".join(str(x) for x in selected_tests_parallel)
- )
- )
- print_to_stderr(
- "serial (file granularity) tests:\n {}".format(
- "\n ".join(str(x) for x in selected_tests_serial)
- )
- )
# See Note [ROCm parallel CI testing]
pool = get_context("spawn").Pool(
@@ -1523,15 +1527,15 @@
# Take the conftest file from the test directory
shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
- def handle_error_messages(err_message):
- if err_message is None:
+ def handle_error_messages(failure: Optional[TestFailure]):
+ if failure is None:
return False
- failure_messages.append(err_message)
- print_to_stderr(err_message)
+ failures.append(failure)
+ print_to_stderr(failure.message)
return True
- def parallel_test_completion_callback(err_message):
- test_failed = handle_error_messages(err_message)
+ def parallel_test_completion_callback(failure):
+ test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
@@ -1557,10 +1561,10 @@
if (
not options.continue_through_error
and not RERUN_DISABLED_TESTS
- and len(failure_messages) != 0
+ and len(failures) != 0
):
raise RuntimeError(
- "\n".join(failure_messages)
+ "\n".join(x.message for x in failures)
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
@@ -1571,20 +1575,20 @@
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
- err_message = run_test_module(test, test_directory, options_clone)
- test_failed = handle_error_messages(err_message)
+ failure = run_test_module(test, test_directory, options_clone)
+ test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
- raise RuntimeError(err_message)
+ raise RuntimeError(failure.message)
finally:
pool.terminate()
pool.join()
- return failure_messages
+ return
def check_pip_packages() -> None:
@@ -1611,30 +1615,47 @@
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)
- if options.verbose:
- print_to_stderr(
- "Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
- )
-
- if options.dry_run:
- return
-
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])
prioritized_tests = []
- remaining_tests = selected_tests
+ general_tests = selected_tests
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
- (prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
- log_time_savings(
- selected_tests,
- prioritized_tests,
- is_serial_test_fn=must_serial,
- num_procs=NUM_PROCS,
- )
-
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
+ (prioritized_tests, general_tests) = get_reordered_tests(general_tests)
+
+ metrics_dict = {
+ "prioritized_tests": prioritized_tests,
+ "general_tests": general_tests,
+ "cpp": options.cpp,
+ }
+
+ test_times_dict = download_test_times(TEST_TIMES_FILE)
+ prioritized_tests = do_sharding(
+ options, prioritized_tests, test_times_dict, sort_by_time=False
+ )
+ general_tests = do_sharding(options, general_tests, test_times_dict)
+
+ if options.verbose:
+
+ def print_tests(category, tests):
+ tests_str = "\n ".join(str(x) for x in tests)
+ print_to_stderr(f"{category} tests:\n {tests_str}")
+
+ print_tests(
+ "Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
+ )
+ print_tests(
+ "Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
+ )
+ print_tests(
+ "General parallel", [x for x in general_tests if not must_serial(x)]
+ )
+ print_tests("General serial", [x for x in general_tests if must_serial(x)])
+
+ if options.dry_run:
+ return
if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
@@ -1646,17 +1667,17 @@
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
- failure_messages = []
-
+ prioritized_failures: List[TestFailure] = []
+ general_failures: List[TestFailure] = []
+ start_time = time.time()
# First run the prioritized tests, then the remaining tests.
try:
- failure_messages = run_tests(
- prioritized_tests, test_directory, options, "Prioritized tests"
- )
-
- failure_messages += run_tests(
- remaining_tests, test_directory, options, "General tests"
- )
+ run_tests(prioritized_tests, test_directory, options, prioritized_failures)
+ metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
+ metrics_dict["general_start_time"] = time.time() - start_time
+ run_tests(general_tests, test_directory, options, general_failures)
+ metrics_dict["general_end_time"] = time.time() - start_time
+ metrics_dict["all_failures"] = [x.test for x in general_failures]
finally:
if options.coverage:
@@ -1671,8 +1692,12 @@
if not PYTORCH_COLLECT_COVERAGE:
cov.html_report()
- if len(failure_messages) != 0:
- for err in failure_messages:
+ if IS_CI and HAVE_TEST_SELECTION_TOOLS:
+ emit_metric("td_experiment_1", metrics_dict)
+
+ all_failures = prioritized_failures + general_failures
+ if len(all_failures) != 0:
+ for _, err in all_failures:
print_to_stderr(err)
# A disabled test is expected to fail, so there is no need to report a failure here
diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py
index 4554f54..6e60158 100644
--- a/tools/stats/export_test_times.py
+++ b/tools/stats/export_test_times.py
@@ -3,14 +3,16 @@
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
sys.path.append(str(REPO_ROOT))
-from tools.stats.import_test_stats import get_test_times
+from tools.stats.import_test_stats import get_test_file_ratings, get_test_times
TEST_TIMES_FILE = ".pytorch-test-times.json"
+TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
def main() -> None:
print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}")
get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
+ get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE)
if __name__ == "__main__":
diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py
index 28d8ee0..a0c0190 100644
--- a/tools/stats/import_test_stats.py
+++ b/tools/stats/import_test_stats.py
@@ -20,6 +20,7 @@
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
+
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
@@ -116,3 +117,12 @@
except Exception:
print("Couldn't download test skip set, leaving all tests enabled...")
return {}
+
+
+def get_test_file_ratings(dirpath: str, filename: str) -> Optional[Dict[str, Any]]:
+ url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
+ try:
+ return fetch_and_cache(dirpath, filename, url, lambda x: x)
+ except Exception:
+ print("Couldn't download test file ratings file, not reordering...")
+ return {}
diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py
index dd48e78..ab8d873 100644
--- a/tools/stats/upload_stats_lib.py
+++ b/tools/stats/upload_stats_lib.py
@@ -263,7 +263,7 @@
value = os.environ.get(self.env_var)
if value is None and self.required:
raise ValueError(
- f"Missing {self.name}. Please set the {self.env_var}"
+ f"Missing {self.name}. Please set the {self.env_var} "
"environment variable to pass in this value."
)
if self.type_conversion_fn:
diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py
index 04f5f88..c0f646c 100644
--- a/tools/test/test_test_selections.py
+++ b/tools/test/test_test_selections.py
@@ -394,28 +394,24 @@
"tools.testing.test_selections._get_modified_tests",
return_value={"test2", "test4"},
)
+ @mock.patch(
+ "tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
+ )
def test_get_reordered_tests(
- self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
+ self,
+ mock_get_prev_failing_tests: Any,
+ mock_get_modified_tests: Any,
+ mock_get_file_rating_tests: Any,
) -> None:
- tests = [
- ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
- ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
- ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
- ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
- ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
- ]
+ tests = ["test1", "test2", "test3", "test4", "test5"]
- expected_prioritized_tests = {"test4", "test2"}
- expected_remaining_tests = {"test1", "test3", "test5"}
+ expected_prioritized_tests = ["test4", "test2", "test1"]
+ expected_remaining_tests = {"test3", "test5"}
prioritized_tests, remaining_tests = get_reordered_tests(tests)
- # Just want to check the names of the tests
- prioritized_tests_name = {test.name for test in prioritized_tests}
- remaining_tests_name = {test.name for test in remaining_tests}
-
- self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
- self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
+ self.assertListEqual(expected_prioritized_tests, prioritized_tests)
+ self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
tests = [
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index 76f841b..57ba611 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -3,16 +3,20 @@
import math
import os
import subprocess
+from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
+from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
from warnings import warn
from tools.shared.logging_utils import duration_to_str, pluralize
+from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.stats.upload_stats_lib import emit_metric
+REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@@ -81,8 +85,8 @@
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
- duration = test_file_times[test]
- if duration > THRESHOLD:
+ duration = test_file_times.get(test, None)
+ if duration and duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
@@ -98,20 +102,24 @@
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
+ sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
- known_tests = [x for x in tests if x in test_file_times]
- unknown_tests: List[str] = [x for x in tests if x not in known_tests]
+ known_tests = tests
+ unknown_tests = []
- sorted_tests = sorted(
- get_with_pytest_shard(known_tests, test_file_times),
- key=lambda j: j.get_time(),
- reverse=True,
- )
+ if sort_by_time:
+ known_tests = [x for x in tests if x in test_file_times]
+ unknown_tests = [x for x in tests if x not in known_tests]
+
+ known_tests = get_with_pytest_shard(known_tests, test_file_times)
+
+ if sort_by_time:
+ known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
- for test in sorted_tests:
+ for test in known_tests:
if must_serial(test.name):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
@@ -127,7 +135,7 @@
return [job.convert_to_tuple() for job in sharded_jobs]
-def _query_changed_test_files() -> List[str]:
+def _query_changed_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
merge_base = (
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@@ -186,7 +194,7 @@
def _get_modified_tests() -> Set[str]:
try:
- changed_files = _query_changed_test_files()
+ changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
# If unable to get changed files from git, quit without doing any sorting
@@ -271,76 +279,81 @@
return max_time_savings_sec
+def _get_file_rating_tests() -> List[str]:
+ path = REPO_ROOT / TEST_FILE_RATINGS_FILE
+ if not os.path.exists(path):
+ print(f"could not find path {path}")
+ return []
+ with open(path) as f:
+ test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
+ try:
+ changed_files = _query_changed_files()
+ except Exception as e:
+ warn(f"Can't query changed test files due to {e}")
+ return []
+ ratings: Dict[str, float] = defaultdict(float)
+ for file in changed_files:
+ for test_file, score in test_file_ratings.get(file, {}).items():
+ ratings[test_file] += score
+ prioritize = sorted(ratings, key=lambda x: ratings[x])
+ return prioritize
+
+
def get_reordered_tests(
- tests: List[ShardedTest],
-) -> Tuple[List[ShardedTest], List[ShardedTest]]:
+ tests: List[str],
+) -> Tuple[List[str], List[str]]:
"""
Get the reordered test filename list based on github PR history or git changed file.
We prioritize running test files that were changed.
"""
+ prioritized_tests: List[str] = []
- def print_tests(tests: Set[str], test_group_description: str) -> None:
- if not tests:
+ def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
+ if not tests_to_add:
return
print(f"{test_group_description}:")
- for test in tests:
- print(f" {test}")
+ for test in tests_to_add:
+ if test in tests:
+ print(f" {test}")
+ if test not in prioritized_tests:
+ prioritized_tests.append(test)
- prioritized_tests: Set[str] = set()
-
- pri_test = _get_previously_failing_tests()
- print_tests(
- pri_test, "If run, these tests will prioritized because they previously failed"
+ add_tests(
+ sorted(_get_previously_failing_tests()),
+ "If run, these tests will prioritized because they previously failed",
)
- prioritized_tests |= pri_test
- pri_test |= _get_modified_tests()
- print_tests(
- pri_test, "If run, these tests will be prioritized because they were modified"
+ add_tests(
+ sorted(_get_modified_tests()),
+ "If run, these tests will be prioritized because they were modified",
)
- prioritized_tests |= pri_test
- bring_to_front = []
- the_rest = []
+ add_tests(
+ _get_file_rating_tests(),
+ "If run, these tests will be preioritized for an experiment in TD",
+ )
- for test in tests:
- if test.name in prioritized_tests:
- bring_to_front.append(test)
- else:
- the_rest.append(test)
+ prioritized_tests = [x for x in prioritized_tests if x in tests]
+ the_rest = [x for x in tests if x not in prioritized_tests]
- if len(tests) != len(bring_to_front) + len(the_rest):
- print(
- f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
- f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
- )
- return ([], tests)
-
- prioritized_test_names = []
- remaining_test_names = []
- if bring_to_front:
+ if prioritized_tests:
test_cnt_str = pluralize(len(tests), "test")
- print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
-
- prioritized_test_names = [t.name for t in bring_to_front]
- print(f"Prioritized: {prioritized_test_names}")
- remaining_test_names = [t.name for t in the_rest]
- print(f"The Rest: {remaining_test_names}")
- else:
- print("Didn't find any tests to prioritize")
+ print(
+ f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
+ )
emit_metric(
"test_reordering_prioritized_tests",
{
- "prioritized_test_cnt": len(bring_to_front),
+ "prioritized_test_cnt": len(prioritized_tests),
"total_test_cnt": len(tests),
- "prioritized_tests": prioritized_test_names,
- "remaining_tests": remaining_test_names,
+ "prioritized_tests": prioritized_tests,
+ "remaining_tests": the_rest,
},
)
- return (bring_to_front, the_rest)
+ return (prioritized_tests, the_rest)
def get_test_case_configs(dirpath: str) -> None: