| # Owner(s): ["oncall: package/deploy"] | |
| import torch | |
# torchvision is optional here: TorchVisionTest is only defined when the
# import below succeeds; on ImportError the name is simply left undefined.
try:
    from torchvision.models import resnet18
except ImportError:
    pass
else:

    class TorchVisionTest(torch.nn.Module):
        """Module that owns a torchvision ResNet-18 and returns relu(x + x + 3.0).

        ``forward`` never calls ``tvmod``. Presumably the submodule exists only
        so that the torchvision dependency is captured when this module is
        packaged/deployed (NOTE(review): confirm with the test that uses it).
        """

        def __init__(self):
            super().__init__()
            # Held as a submodule but unused by forward (see class docstring).
            self.tvmod = resnet18()

        def forward(self, x):
            # Route through the plain-Python helper a_non_torch_leaf (defined
            # later in this module) before applying any torch op.
            combined = a_non_torch_leaf(x, x)
            shifted = combined + 3.0
            return torch.relu(shifted)
def a_non_torch_leaf(a, b):
    """Return ``a + b`` using only the Python ``+`` operator.

    No torch API is called here. Judging by its name, the helper is presumably
    meant to be a non-torch "leaf" call inside ``TorchVisionTest.forward``
    (NOTE(review): confirm against the tracing/packaging test that relies on it).
    """
    summed = a + b
    return summed