| import argparse |
| import random |
| import time |
| from abc import abstractmethod |
| from typing import Any, Tuple |
| |
| from tqdm import tqdm # type: ignore[import-untyped] |
| |
| import torch |
| |
| |
class BenchmarkRunner:
    """
    Base class for benchmark runners that collect data with AutoHeuristic.

    Subclasses provide ``create_input`` and ``run_benchmark``. ``run`` parses the command line, configures
    ``torch._inductor.config`` for either data collection or heuristic use, and then calls ``main``.

    NOTE: this class does not derive from ``abc.ABC``, so the ``@abstractmethod`` markers below are not
    enforced when a subclass is instantiated.
    """

    def __init__(self, name: str) -> None:
        """
        Create the runner and register the base command-line arguments.

        Args:
            name: Identifier written to ``autoheuristic_collect`` / ``autoheuristic_use`` by ``run``.
        """
        self.name = name
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        # Holds the parsed argparse namespace once run() has been called; None until then.
        self.args = None

    def add_base_arguments(self) -> None:
        """Register the command-line options shared by all benchmark runners on ``self.parser``."""
        self.parser.add_argument(
            "--device",
            type=int,
            default=None,
            help="torch.cuda.set_device(device) will be used",
        )
        self.parser.add_argument(
            "--use-heuristic",
            action="store_true",
            help="Use learned heuristic instead of collecting data.",
        )
        self.parser.add_argument(
            "-o",
            type=str,
            default="ah_data.txt",
            help="Path to file where AutoHeuristic will log results.",
        )
        self.parser.add_argument(
            "--num-samples",
            type=int,
            default=1000,
            help="Number of samples to collect.",
        )
        self.parser.add_argument(
            "--num-reps",
            type=int,
            default=3,
            help="Number of measurements to collect for each input.",
        )

    def run(self) -> None:
        """
        Parse the command line, configure torch for this run, and execute ``main``.

        The parsed namespace is stored on ``self.args`` so that subclass hooks can read the options.
        """
        # Parse first so that --help and malformed arguments exit before any torch state is modified.
        args = self.parser.parse_args()
        self.args = args

        torch.set_default_device("cuda")
        if args.use_heuristic:
            torch._inductor.config.autoheuristic_use = self.name
            torch._inductor.config.autoheuristic_collect = ""
        else:
            torch._inductor.config.autoheuristic_use = ""
            torch._inductor.config.autoheuristic_collect = self.name
        torch._inductor.config.autoheuristic_log_path = args.o
        if args.device is not None:
            torch.cuda.set_device(args.device)
        # Seed from the wall clock so that separate invocations draw different random values.
        random.seed(time.time())
        self.main(args.num_samples, args.num_reps)

    @abstractmethod
    def run_benchmark(self, *args: Any) -> None:
        """Run one measurement; called by ``main`` with the unpacked tuple returned by ``create_input``."""
        ...

    @abstractmethod
    def create_input(self) -> Tuple[Any, ...]:
        """Return the tuple of arguments that ``main`` unpacks into ``run_benchmark``."""
        ...

    def main(self, num_samples: int, num_reps: int) -> None:
        """
        Create ``num_samples`` inputs and call ``run_benchmark`` ``num_reps`` times on each.

        Args:
            num_samples: Number of inputs to create via ``create_input``.
            num_reps: Number of ``run_benchmark`` calls made for every created input.
        """
        for _sample in tqdm(range(num_samples)):
            # Named to avoid shadowing the builtin ``input``.
            benchmark_input = self.create_input()
            for _rep in range(num_reps):
                self.run_benchmark(*benchmark_input)