blob: 2f0a00556ab8f0ceac399eee8f2325cde9c33d7f [file] [log] [blame]
import os
import subprocess
from pathlib import Path
from typing import List, Set
# Directory reached by resolving this file's path and stepping up five
# `.parent` levels (the file's own directory, then four directories above it).
# NOTE(review): presumably the repository checkout root, per the name — this
# depends on the file staying at its current depth in the tree.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
def python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
    """Convert test file paths into bare test names.

    Only paths that start with ``test<os.path.sep>`` and end with ``.py`` are
    kept; every other path is dropped. For the kept paths, the leading
    ``test<sep>`` prefix and the trailing ``.py`` suffix are removed.

    Args:
        tests: File paths to convert.

    Returns:
        The set of test names derived from the matching paths.
    """
    test_dir = f"test{os.path.sep}"
    suffix = ".py"

    names: Set[str] = set()
    for path in tests:
        # Skip anything that does not live under the test directory.
        if not path.startswith(test_dir):
            continue
        # Skip non-Python files.
        if not path.endswith(suffix):
            continue
        names.add(path[len(test_dir) : -len(suffix)])
    return names
def query_changed_files() -> List[str]:
    """Return the paths of files changed on the current branch, as reported by git.

    The comparison base is the merge-base of ``HEAD`` and
    ``origin/<GIT_DEFAULT_BRANCH>`` (the branch name defaults to ``main`` when
    the environment variable is unset). If that merge-base is ``HEAD`` itself,
    the diff is taken against ``HEAD^`` instead.

    Returns:
        The whitespace-stripped paths printed by ``git diff --name-only``.
        An empty list is returned when git reports no changed files.

    Raises:
        subprocess.CalledProcessError: If ``git merge-base`` or
            ``git rev-parse`` exits with a non-zero status.
        RuntimeError: If ``git diff`` exits with a non-zero status.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    merge_base = (
        subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
        .decode()
        .strip()
    )
    head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()

    base_commit = merge_base
    if base_commit == head:
        # We are on the default branch, so check for changes since the last commit
        base_commit = "HEAD^"

    # check=False: the exit status is inspected manually below so that a
    # failure surfaces as RuntimeError rather than CalledProcessError.
    proc = subprocess.run(
        ["git", "diff", "--name-only", base_commit, "HEAD"],
        capture_output=True,
        check=False,
    )
    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    # splitlines() on empty output yields [], whereas "".split("\n") yields
    # [""]; blank entries are dropped so "no changes" is an empty list rather
    # than a list containing one empty path.
    stripped = (line.strip() for line in proc.stdout.decode().splitlines())
    return [path for path in stripped if path]