blob: f6253da5d8cdc62ebad699cf00eec97684d65664 [file] [log] [blame]
import math
import os
import subprocess
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
# True when the PYTORCH_TEST_CUDA_MEM_LEAK_CHECK environment variable is "1".
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

# Number of processes assumed when estimating how long the parallel tests of a
# shard take; forced to 1 while the memory leak check is enabled.
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2

# Test files whose recorded duration (in seconds) exceeds this value are split
# into several pytest shards.
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        # Every output line mentioning a " gfx" target is counted as one GPU.
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        NUM_PROCS = count
    except (subprocess.CalledProcessError, OSError):
        # The safe default for ROCm GHA runners is to run tests serially.
        # OSError covers a missing or non-executable `rocminfo` binary, which
        # would otherwise abort the import of this module.
        NUM_PROCS = 1
class ShardedTest(NamedTuple):
    """A test file, or one slice of a test file, together with its expected duration."""

    # Name of the test file.
    name: str
    # Which slice of the file this entry represents.
    shard: int
    # Total number of slices the file was split into.
    num_shards: int
    # Expected duration of this slice, or None when no timing is known.
    time: Optional[float]

    def __str__(self) -> str:
        """Render as ``"<name> <shard>/<num_shards>"``."""
        return "{} {}/{}".format(self.name, self.shard, self.num_shards)

    def get_time(self) -> float:
        """Return the expected duration, treating a missing (or zero) time as 0."""
        if self.time:
            return self.time
        return 0
class ShardJob:
    """The tests assigned to one shard, split into serial and parallel groups."""

    def __init__(self) -> None:
        # Tests whose durations are summed one after another.
        self.serial: List[ShardedTest] = []
        # Tests whose durations are spread over NUM_PROCS simulated processes.
        self.parallel: List[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate the duration of this shard.

        Each parallel test, in list order, is given to the simulated process
        with the smallest accumulated time so far. The estimate is the time of
        the busiest process plus the summed time of all serial tests.
        """
        loads = [0.0] * NUM_PROCS
        for parallel_test in self.parallel:
            # The first least-loaded process wins ties.
            lightest = min(range(len(loads)), key=loads.__getitem__)
            loads[lightest] = loads[lightest] + parallel_test.get_time()
        parallel_time = max(loads)
        serial_time = sum(serial_test.get_time() for serial_test in self.serial)
        return parallel_time + serial_time

    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
        """Return ``(estimated total time, serial tests followed by parallel tests)``."""
        total = self.get_total_time()
        return (total, [*self.serial, *self.parallel])
def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
    """Turn test names into ShardedTest entries, splitting long-running files.

    A test whose recorded duration is above THRESHOLD is split into
    ``ceil(duration / THRESHOLD)`` equally timed entries numbered from 1.
    Any other test becomes a single ``1/1`` entry with its full duration.

    Args:
        tests: Test names; each one must be a key of ``test_file_times``.
        test_file_times: Recorded duration for each test name.

    Returns:
        The entries in input order, with the slices of a split test adjacent.

    Raises:
        KeyError: If a test name is missing from ``test_file_times``.
    """
    result: List[ShardedTest] = []
    for test_name in tests:
        total = test_file_times[test_name]
        if not (total > THRESHOLD):
            result.append(ShardedTest(test_name, 1, 1, total))
            continue
        pieces = math.ceil(total / THRESHOLD)
        per_piece = total / pieces
        result.extend(
            ShardedTest(test_name, piece_no, pieces, per_piece)
            for piece_no in range(1, pieces + 1)
        )
    return result
def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[ShardedTest]]]:
    """Distribute tests over ``num_shards`` shards, balancing estimated time.

    Tests with a recorded time are expanded by ``get_with_pytest_shard`` and
    placed longest-first, each on the shard with the smallest estimated total
    at that moment. Tests without a recorded time are then dealt out
    round-robin as serial tests, starting at the least-loaded shard.

    Args:
        num_shards: Number of shards to produce. Must be at least 1; otherwise
            ``min()`` over an empty sequence raises ``ValueError``.
        tests: Names of the tests to distribute.
        test_file_times: Recorded duration for each test with known timing.
        must_serial: Predicate on the test name deciding whether the test goes
            into a shard's serial list. When omitted, every test is serial.

    Returns:
        One ``(estimated total time, tests)`` tuple per shard, as produced by
        ``ShardJob.convert_to_tuple``.
    """
    must_serial = must_serial or (lambda x: True)

    # Membership is checked against the dict, not against a list of known
    # tests, so partitioning stays linear in the number of tests.
    known_tests = [x for x in tests if x in test_file_times]
    unknown_tests: List[str] = [x for x in tests if x not in test_file_times]

    sorted_tests = sorted(
        get_with_pytest_shard(known_tests, test_file_times),
        key=lambda j: j.get_time(),
        reverse=True,
    )

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in sorted_tests:
        run_serially = must_serial(test.name)
        # Longest tests are placed first, each on the currently lightest shard.
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        if run_serially:
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job.parallel.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards

    return [job.convert_to_tuple() for job in sharded_jobs]
def _query_changed_test_files() -> List[str]:
    """Return the files that differ between the default branch and HEAD.

    Runs ``git diff --name-only origin/<branch> HEAD``, where ``<branch>`` is
    taken from the ``GIT_DEFAULT_BRANCH`` environment variable and defaults to
    ``master``.

    Returns:
        The whitespace-stripped file names printed by git. The list is empty
        when git reports no changed files.

    Raises:
        RuntimeError: If the git command exits with a non-zero status.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, capture_output=True)

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    stripped = (line.strip() for line in proc.stdout.decode().strip().split("\n"))
    # Splitting empty output yields [""]; drop blank entries so that "no
    # changes" is reported as an empty list, not as one empty file name.
    return [line for line in stripped if line]
def get_reordered_tests(tests: List[str]) -> List[str]:
    """
    Get the reordered test filename list based on the files changed in git.
    We prioritize running test files that were changed.

    Changed files of the form ``test/<name>.py`` are mapped to the test name
    ``<name>``. Tests in ``tests`` matching such a name are moved to the
    front; relative order is otherwise preserved.

    Args:
        tests: Test names in their original order.

    Returns:
        The reordered list, or ``tests`` itself when the changed files cannot
        be determined.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception:
        # If unable to get changed files from git, quit without doing any sorting
        return tests

    # git prints repository-relative paths with forward slashes on every
    # platform, so the prefix must not depend on os.path.sep.
    prefix = "test/"
    suffix = ".py"
    prioritized_tests = {
        f[len(prefix) : -len(suffix)]
        for f in changed_files
        if f.startswith(prefix) and f.endswith(suffix)
    }
    print("Prioritized test from test file changes.")

    # Every test lands in exactly one of the two lists, so no test is lost.
    bring_to_front = [test for test in tests if test in prioritized_tests]
    the_rest = [test for test in tests if test not in prioritized_tests]
    print(
        f"reordering tests for PR:\n"
        f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
    )
    return bring_to_front + the_rest
def get_test_case_configs(dirpath: str) -> None:
    """Call ``get_slow_tests`` and then ``get_disabled_tests`` for ``dirpath``.

    Both helpers come from ``tools.stats.import_test_stats``; their return
    values are discarded here.
    NOTE(review): presumably they write their results under ``dirpath`` --
    confirm against that module.
    """
    for fetch_config in (get_slow_tests, get_disabled_tests):
        fetch_config(dirpath=dirpath)