| # Owner(s): ["oncall: package/deploy"] |
| |
| import torch |
| from torch.fx import wrap |
| |
# Register ``a_non_torch_leaf`` (defined at the bottom of this file) as an FX
# leaf function. During symbolic tracing, calls to it are recorded as a single
# call_function node rather than traced through.
# ``wrap`` must be called at module top level because it inspects the calling
# frame. It accepts the function's name as a string, which is why it can run
# before the function is defined.
wrap("a_non_torch_leaf")
| |
| |
class ModWithSubmod(torch.nn.Module):
    """Module whose forward pass delegates entirely to one wrapped callable.

    Args:
        script_mod: Callable applied to the input. When it is an
            ``nn.Module``, assigning it as an attribute registers it as a
            child module.
    """

    def __init__(self, script_mod):
        super().__init__()
        self.script_mod = script_mod

    def forward(self, x):
        """Return ``script_mod(x)`` unchanged."""
        delegated = self.script_mod(x)
        return delegated
| |
| |
class ModWithTensor(torch.nn.Module):
    """Module that multiplies its input by a tensor held as an attribute.

    Args:
        tensor: Left-hand operand of the multiplication. It is assigned
            directly as an attribute; ``register_buffer`` is not used.
    """

    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def forward(self, x):
        """Return ``self.tensor * x`` (operand order preserved)."""
        product = self.tensor * x
        return product
| |
| |
class ModWithSubmodAndTensor(torch.nn.Module):
    """Module that adds a stored tensor to the output of a submodule.

    Args:
        tensor: Value added to the submodule's output.
        sub_mod: Callable applied to the input first.
    """

    def __init__(self, tensor, sub_mod):
        super().__init__()
        self.tensor = tensor
        self.sub_mod = sub_mod

    def forward(self, x):
        """Return ``sub_mod(x) + tensor``."""
        sub_out = self.sub_mod(x)
        return sub_out + self.tensor
| |
| |
class ModWithTwoSubmodsAndTensor(torch.nn.Module):
    """Module that sums two submodule outputs and a stored tensor.

    Args:
        tensor: Value added after the two submodule outputs are summed.
        sub_mod_0: First callable applied to the input.
        sub_mod_1: Second callable applied to the same input.
    """

    def __init__(self, tensor, sub_mod_0, sub_mod_1):
        super().__init__()
        self.tensor = tensor
        self.sub_mod_0 = sub_mod_0
        self.sub_mod_1 = sub_mod_1

    def forward(self, x):
        """Return ``(sub_mod_0(x) + sub_mod_1(x)) + tensor``."""
        # Same evaluation order as a single left-associative expression:
        # both submodules run before the stored tensor is added.
        first = self.sub_mod_0(x)
        second = self.sub_mod_1(x)
        pair_sum = first + second
        return pair_sum + self.tensor
| |
| |
class ModWithMultipleSubmods(torch.nn.Module):
    """Module that sums the outputs of two submodules on the same input.

    Args:
        mod1: First callable applied to the input.
        mod2: Second callable applied to the input.
    """

    def __init__(self, mod1, mod2):
        super().__init__()
        self.mod1 = mod1
        self.mod2 = mod2

    def forward(self, x):
        """Return ``mod1(x) + mod2(x)``."""
        left = self.mod1(x)
        right = self.mod2(x)
        return left + right
| |
| |
class SimpleTest(torch.nn.Module):
    """Small module that calls the module-level ``a_non_torch_leaf`` helper."""

    def forward(self, x):
        """Return ``relu(a_non_torch_leaf(x, x) + 3.0)``."""
        # Call through the global name: FX's ``wrap`` patches this name at
        # trace time, so a local alias would bypass the leaf registration.
        leaf_out = a_non_torch_leaf(x, x)
        shifted = leaf_out + 3.0
        return torch.relu(shifted)
| |
| |
def a_non_torch_leaf(a, b):
    """Return ``a + b``.

    Used as a non-torch function in FX tests; the name must not change
    because it is looked up by string elsewhere in this module.
    """
    total = a + b
    return total