import argparse
import sys
from pathlib import Path
from typing import Optional, Sequence

from pytest_caching_utils import (
    download_pytest_cache,
    GithubRepo,
    PRIdentifier,
    upload_pytest_cache,
)
| |
# Fallback directory for temp files, used when --temp_dir is not passed.
TEMP_DIR = "./tmp"
| |
| |
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments of the pytest cache tool.

    Exactly one of ``--upload`` / ``--download`` must be given. ``--sha``,
    ``--test_config`` and ``--shard`` are mandatory only in upload mode.

    Args:
        argv: Argument list to parse. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when the arguments are
            invalid, including a missing upload-only argument in upload mode.
    """
    parser = argparse.ArgumentParser(
        description="Upload or download this job's pytest cache to/from S3"
    )

    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--upload", action="store_true", help="Upload the pytest cache to S3"
    )
    mode.add_argument(
        "--download",
        action="store_true",
        help="Download the pytest cache from S3, merging it with any local cache",
    )

    parser.add_argument(
        "--cache_dir",
        required=True,
        help="Path to the folder pytest uses for its cache",
    )
    parser.add_argument("--pr_identifier", required=True, help="A unique PR identifier")
    parser.add_argument(
        "--job_identifier",
        required=True,
        help="A unique job identifier that should be the same for all runs of job",
    )

    # Upload-only arguments. Whether they are required is decided after
    # parsing (see below) rather than by scanning sys.argv for the literal
    # "--upload": that scan misses abbreviated flags such as "--up" and any
    # argument list that does not come from sys.argv.
    parser.add_argument("--sha", help="SHA of the commit")
    parser.add_argument("--test_config", help="The test config")
    parser.add_argument("--shard", help="The shard id")

    parser.add_argument(
        "--repo",
        required=False,
        help="The github repository we're running in, in the format 'owner/repo-name'",
    )
    parser.add_argument(
        "--temp_dir", required=False, help="Directory to store temp files"
    )
    parser.add_argument(
        "--bucket", required=False, help="The S3 bucket to upload the cache to"
    )

    args = parser.parse_args(argv)

    if args.upload:
        # Compare against None so an explicitly passed empty value is still
        # accepted, exactly as a `required=True` option would accept it.
        missing = [
            f"--{name}"
            for name in ("sha", "test_config", "shard")
            if getattr(args, name) is None
        ]
        if missing:
            parser.error(
                "the following arguments are required with --upload: "
                + ", ".join(missing)
            )

    return args
| |
| |
def main() -> None:
    """Run the upload or download action selected on the command line.

    Builds the PR identifier, repo, cache directory and temp directory from
    the parsed arguments, then calls ``upload_pytest_cache`` or
    ``download_pytest_cache``. An upload is skipped (with a message) when the
    cache directory does not exist.
    """
    args = parse_args()

    pr_id = PRIdentifier(args.pr_identifier)
    print(f"PR identifier for `{args.pr_identifier}` is `{pr_id}`")

    github_repo = GithubRepo.from_string(args.repo)
    pytest_cache_dir = Path(args.cache_dir)
    # Use the module-level default when no temp dir was supplied.
    scratch_dir = Path(args.temp_dir or TEMP_DIR)

    # Keyword arguments common to both the upload and the download call.
    shared_kwargs = {
        "pr_identifier": pr_id,
        "repo": github_repo,
        "job_identifier": args.job_identifier,
        "bucket": args.bucket,
        "temp_dir": scratch_dir,
    }

    if args.upload:
        print(f"Uploading cache with args {args}")

        if pytest_cache_dir.exists():
            upload_pytest_cache(
                sha=args.sha,
                test_config=args.test_config,
                shard=args.shard,
                cache_dir=pytest_cache_dir,
                **shared_kwargs,
            )
        else:
            # Nothing to upload without a cache directory.
            print(
                f"The pytest cache dir `{pytest_cache_dir}` does not exist. Skipping upload"
            )
            return

    if args.download:
        print(f"Downloading cache with args {args}")
        download_pytest_cache(dest_cache_dir=pytest_cache_dir, **shared_kwargs)
| |
| |
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()