| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import logging |
| |
| import torch |
| |
| # pyre-ignore |
| from torchvision.models import ( # @manual |
| resnet18, |
| ResNet18_Weights, |
| resnet50, |
| ResNet50_Weights, |
| ) |
| |
| from ..model_base import EagerModelBase |
| |
| |
class ResNet18Model(EagerModelBase):
    """Eager-model wrapper around torchvision's ``resnet18``.

    The network is built lazily in :meth:`get_eager_model`; constructing this
    wrapper performs no work.
    """

    def __init__(self):
        """Intentionally empty: no state is kept on the instance."""

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return the torchvision ResNet-18 module.

        The model is created with ``ResNet18_Weights.IMAGENET1K_V1`` passed as
        the ``weights`` argument. An info-level message is logged before and
        after construction.

        Returns:
            The ``torch.nn.Module`` produced by ``resnet18``.
        """
        # NOTE(review): the module is returned exactly as torchvision builds
        # it; this method does not call .eval() -- confirm callers handle that.
        logging.info("Loading torchvision resnet18 model")
        # pyre-ignore
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet18 model")
        return model

    def get_example_inputs(self):
        """Return a one-element tuple with a random (1, 3, 224, 224) tensor."""
        sample = torch.randn((1, 3, 224, 224))
        return (sample,)
| |
| |
class ResNet50Model(EagerModelBase):
    """Eager-model wrapper around torchvision's ``resnet50``.

    The network is built lazily in :meth:`get_eager_model`; constructing this
    wrapper performs no work.
    """

    def __init__(self):
        """Intentionally empty: no state is kept on the instance."""

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return the torchvision ResNet-50 module.

        The model is created with ``ResNet50_Weights.IMAGENET1K_V1`` passed as
        the ``weights`` argument. An info-level message is logged before and
        after construction.

        Returns:
            The ``torch.nn.Module`` produced by ``resnet50``.
        """
        # NOTE(review): the module is returned exactly as torchvision builds
        # it; this method does not call .eval() -- confirm callers handle that.
        logging.info("Loading torchvision resnet50 model")
        # pyre-ignore
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet50 model")
        return model

    def get_example_inputs(self):
        """Return a one-element tuple with a random (1, 3, 224, 224) tensor."""
        sample = torch.randn((1, 3, 224, 224))
        return (sample,)