blob: cd5129696e4143b328e46c253d97193d33d9bf2f [file] [log] [blame]
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List
from warnings import warn
from tools.testing.target_determination.heuristics.interface import (
HeuristicInterface,
TestPrioritizations,
)
from tools.testing.target_determination.heuristics.utils import (
normalize_ratings,
query_changed_files,
)
from tools.testing.test_run import TestRun
# Directory taken as the repository root: Path.parents[0] is the directory
# containing this file, so parents[3] is three levels above that.
# NOTE(review): if this module lives at
# tools/testing/target_determination/heuristics/<name>.py (as its imports
# suggest), parents[3] resolves to the `tools` directory, not the repository
# root; parents[4] would be the root. TODO confirm the intended directory.
REPO_ROOT: Path = Path(__file__).parents[3]
# Canonical keyword -> alternative folder names that are folded into it by
# sanitize_folder_name and also probed by file_matches_keyword.
keyword_synonyms: Dict[str, List[str]] = dict(
    amp=["mixed_precision"],
    quant=["quantized", "quantization", "quantize"],
    decomp=["decomposition", "decompositions"],
    numpy=["torch_np", "numpy_tests"],
    ops=["opinfo"],
)

# Folder names too generic to be useful keywords; get_keywords drops them.
not_keyword: List[str] = (
    "torch test tests util utils func src c ns tools internal"
).split()


def _matches_nn(path: str) -> bool:
    # "onnx" contains the substring "nn"; mask it so ONNX paths don't match "nn".
    return "nn" in path.replace("onnx", "_")


def _matches_c10(path: str) -> bool:
    # "c10d" contains the substring "c10"; mask it so c10d paths don't match "c10".
    return "c10" in path.replace("c10d", "_")


# Keyword -> predicate that replaces the default "keyword is a substring of
# the path" test in file_matches_keyword.
custom_matchers: Dict[str, Callable[[str], bool]] = {
    "nn": _matches_nn,
    "c10": _matches_c10,
}
# A one-entry cache is enough for file_matches_keyword's call pattern: the
# heuristic checks one test path against every keyword in consecutive calls.
@lru_cache(maxsize=1)
def get_keywords(file: str) -> List[str]:
    """Return the keywords derived from the directory components of ``file``.

    Every path component except the last (the file name) is normalized with
    ``sanitize_folder_name``. Results that appear in ``not_keyword`` are
    discarded. Path order is preserved and duplicates are kept.

    The result is cached, so a cache hit returns the same list object; callers
    must not mutate it.
    """
    return [
        keyword
        for keyword in (sanitize_folder_name(part) for part in Path(file).parts[:-1])
        if keyword not in not_keyword
    ]
def sanitize_folder_name(folder_name: str) -> str:
    """Normalize a single path component into a keyword.

    One leading underscore is stripped. If the stripped name equals a key of
    ``keyword_synonyms`` or appears in that key's synonym list, the key is
    returned (the first matching entry in dict order wins). Otherwise the
    stripped name is returned as-is.
    """
    name = folder_name[1:] if folder_name.startswith("_") else folder_name
    return next(
        (
            canonical
            for canonical, synonyms in keyword_synonyms.items()
            if name == canonical or name in synonyms
        ),
        name,
    )
def file_matches_keyword(file: str, keyword: str) -> bool:
    """Return True if ``file`` is considered related to ``keyword``.

    Checks, stopping at the first success:
      1. ``keyword`` is among the folder keywords ``get_keywords`` derives
         from ``file``;
      2. some synonym listed for ``keyword`` in ``keyword_synonyms`` is one of
         those folder keywords or a substring of ``file``;
      3. the ``custom_matchers`` predicate for ``keyword`` accepts ``file``,
         or, when there is no such predicate, ``keyword`` is a substring of
         ``file``.
    """
    folder_keywords = get_keywords(file)
    if keyword in folder_keywords:
        return True

    for synonym in keyword_synonyms.get(keyword, []):
        if synonym in folder_keywords or synonym in file:
            return True

    if keyword not in custom_matchers:
        # Default rule: plain substring test against the whole path.
        return keyword in file
    return custom_matchers[keyword](file)
class Filepath(HeuristicInterface):
    """Heuristic based on the folders in changed file paths.

    Each folder component of every changed file (as produced by
    ``get_keywords``) becomes a keyword, weighted by how many changed files
    contributed it. A test's rating is the sum of the weights of every
    keyword it matches according to ``file_matches_keyword``.
    """

    def __init__(self, **kwargs: Dict[str, Any]) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations:
        """Rate ``tests`` by folder-keyword overlap with the changed files.

        Args:
            tests: candidate test names/paths.

        Returns:
            A ``TestPrioritizations`` over ``tests``. The ratings come from
            ``normalize_ratings(..., 0.25, min_value=0.125)`` and are keyed by
            ``TestRun``. Tests that match no keyword get no entry in the
            ratings mapping.
        """
        keyword_frequency: Dict[str, int] = defaultdict(int)
        try:
            changed_files = query_changed_files()
        except Exception as e:
            # Best effort: on failure, warn and continue with no changed files,
            # which leaves every test unrated.
            warn(f"Can't query changed test files due to {e}")
            changed_files = []

        for cf in changed_files:
            for keyword in get_keywords(cf):
                keyword_frequency[keyword] += 1

        test_ratings: Dict[str, float] = defaultdict(float)
        for test in tests:
            for keyword, frequency in keyword_frequency.items():
                if file_matches_keyword(test, keyword):
                    test_ratings[test] += frequency

        # Every key in test_ratings was taken from ``tests`` above, so no
        # membership filter against the ``tests`` list is needed (such a
        # filter would be a linear list scan per rated test).
        ratings_by_run = {TestRun(name): score for name, score in test_ratings.items()}
        return TestPrioritizations(
            tests, normalize_ratings(ratings_by_run, 0.25, min_value=0.125)
        )