|  | #!/usr/bin/env python3 | 
|  | import argparse | 
|  | import os | 
|  |  | 
|  | from typing import Set | 
|  |  | 
|  |  | 
|  | # Note - hf and timm have their own version of this, torchbench does not | 
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
    """Return the set of model names listed in the text file ``filename``.

    Each non-blank line names one model in its first field. Fields are split
    on single spaces; if that yields only one field, the line is split on
    commas instead. Only the first field is kept. Blank or whitespace-only
    lines are skipped so they never contribute an empty-string model name.

    Args:
        filename: path to a model-list text file.

    Returns:
        The set of first-field model names found in the file.
    """
    names: Set[str] = set()
    with open(filename, "r") as fh:
        for raw_line in fh:
            # strip() (not just rstrip()) so leading whitespace does not turn
            # the first space-separated field into "".
            line = raw_line.strip()
            if not line:
                # Without this, blank lines would record "" as a model name,
                # which could make the suite-disjointness checks fail.
                continue
            # Space-separated fields take precedence; fall back to commas.
            parts = line.split(" ")
            if len(parts) == 1:
                parts = line.split(",")
            names.add(parts[0])
    return names
|  |  | 
|  |  | 
# Directory of this script; the per-suite model list files live next to it.
_MODEL_LISTS_DIR = os.path.dirname(__file__)

# Model-name sets per benchmark suite. Despite the *_FILE_NAME suffix, the HF
# and torchbench constants hold the sets returned by model_names(), not paths.
TIMM_MODEL_NAMES = model_names(
    os.path.join(_MODEL_LISTS_DIR, "timm_models_list.txt")
)
HF_MODELS_FILE_NAME = model_names(
    os.path.join(_MODEL_LISTS_DIR, "huggingface_models_list.txt")
)
TORCHBENCH_MODELS_FILE_NAME = model_names(
    os.path.join(_MODEL_LISTS_DIR, "all_torchbench_models_list.txt")
)


def _check_disjoint(
    name_a: str, names_a: Set[str], name_b: str, names_b: Set[str]
) -> None:
    """Raise AssertionError if any model name is listed in both suites.

    An explicit raise (rather than ``assert``) keeps the check active under
    ``python -O``, and the message lists the offending names.
    """
    overlap = names_a & names_b
    if overlap:
        raise AssertionError(
            f"model names listed in both {name_a} and {name_b}: {sorted(overlap)}"
        )


# Every model must belong to exactly one suite; otherwise the `--only`
# dispatch in `__main__` would silently pick whichever suite is checked first.
# timm <> HF disjoint
_check_disjoint("timm", TIMM_MODEL_NAMES, "huggingface", HF_MODELS_FILE_NAME)
# timm <> torch disjoint
_check_disjoint("timm", TIMM_MODEL_NAMES, "torchbench", TORCHBENCH_MODELS_FILE_NAME)
# torch <> hf disjoint
_check_disjoint(
    "torchbench", TORCHBENCH_MODELS_FILE_NAME, "huggingface", HF_MODELS_FILE_NAME
)
|  |  | 
|  |  | 
def parse_args(args=None):
    """Parse the command line for this runner shim.

    Only ``--only`` is declared here. Every other argument is tolerated and
    returned separately rather than rejected (presumably so the selected
    suite runner can handle it — not visible from this file).

    Args:
        args: list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        ``(namespace, unknown_args)`` as produced by
        ``ArgumentParser.parse_known_args``.
    """
    only_help = """Run just one model from whichever model suite it belongs to. Or
specify the path and class name of the model in format like:
--only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

Due to the fact that dynamo changes current working directory,
the path should be an absolute path.

The class should have a method get_example_inputs to return the inputs
for the model. An example looks like
```
class LinearModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)

def forward(self, x):
return self.linear(x)

def get_example_inputs(self):
return (torch.randn(2, 10),)
```
"""
    cli = argparse.ArgumentParser()
    cli.add_argument("--only", help=only_help)
    known, extra = cli.parse_known_args(args)
    return known, extra
|  |  | 
|  |  | 
if __name__ == "__main__":
    # sys.exit instead of the `exit` builtin: `exit` is injected by the `site`
    # module for interactive use and is absent under `python -S`.
    import sys

    args, unknown = parse_args()
    # Suite modules are imported lazily, only right before the suite runs
    # (presumably to avoid loading suites, and their dependencies, that are
    # not needed).
    if args.only:
        name = args.only
        if name in TIMM_MODEL_NAMES:
            import timm_models

            timm_models.timm_main()
        elif name in HF_MODELS_FILE_NAME:
            import huggingface

            huggingface.huggingface_main()
        elif name in TORCHBENCH_MODELS_FILE_NAME:
            import torchbench

            torchbench.torchbench_main()
        else:
            # NOTE(review): the --only help text also advertises a
            # "path:<MODEL_FILE_PATH>,class:<CLASS_NAME>" form; unless such a
            # string is listed in one of the model files it lands here —
            # confirm whether it should be routed to a suite instead.
            print(f"Illegal model name? {name}", file=sys.stderr)
            sys.exit(-1)
    else:
        # No --only: run every suite in turn (torchbench, huggingface, timm).
        import torchbench

        torchbench.torchbench_main()

        import huggingface

        huggingface.huggingface_main()

        import timm_models

        timm_models.timm_main()