| import torch |
| from torch.fx import symbolic_trace |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| |
| class TestFX(TestCase): |
| def test_graph_module(self): |
| class MySub(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.nn.Parameter(torch.rand(4, 3)) |
| |
| def forward(self, x): |
| return self.w + x |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin = torch.nn.Linear(4, 3) |
| self.sub_mod = MySub() |
| self.w = torch.nn.Parameter(torch.rand(3)) |
| |
| def forward(self, A, B, c): |
| t = torch.sigmoid(A) + self.lin(c) |
| return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3)) |
| |
| m = MyModule() |
| gm = symbolic_trace(m) |
| |
| ms = torch.jit.script(gm) |
| |
| class M2(torch.nn.Module): |
| def forward(self, A): |
| m, idx = torch.max(A, 0) |
| return m + 1, idx + 1 |
| |
| m2 = M2() |
| gm2 = symbolic_trace(m2) |
| |
| class T(torch.nn.Module): |
| |
| def forward(self, A, b=4, *args, c=5, **kwargs): |
| x = A + 1 + args[0] + kwargs['3'] |
| return x |
| |
| t = T() |
| symbolic_trace(t) |
| |
| def test_fx_shifts(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, x): |
| return x << 3, x >> 3 |
| |
| input = torch.LongTensor(10).random_(0, 1024) |
| |
| m = MyModule() |
| ref_outs = m(input) |
| gm = symbolic_trace(m) |
| test_outs = gm(input) |
| |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_dict(self): |
| class MyDictMod(torch.nn.Module): |
| def forward(self, d): |
| return d['3'].relu(), {'4' : d['3'].neg()} |
| |
| input_dict = {'3': torch.rand(3, 4)} |
| m = MyDictMod() |
| ref_out = m(input_dict) |
| gm = symbolic_trace(m) |
| out = gm(input_dict) |
| |
| self.assertEqual(out, ref_out) |
| |
# Script entry point: hand control to the PyTorch test runner only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_tests()