blob: 39f86a4821f303e04ef31fb49e2fdbc15362154d [file] [log] [blame]
import torch
from torch.fx import symbolic_trace
from torch.testing._internal.common_utils import TestCase, run_tests
class TestFX(TestCase):
def test_graph_module(self):
class MySub(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.nn.Parameter(torch.rand(4, 3))
def forward(self, x):
return self.w + x
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(4, 3)
self.sub_mod = MySub()
self.w = torch.nn.Parameter(torch.rand(3))
def forward(self, A, B, c):
t = torch.sigmoid(A) + self.lin(c)
return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))
m = MyModule()
gm = symbolic_trace(m)
ms = torch.jit.script(gm)
class M2(torch.nn.Module):
def forward(self, A):
m, idx = torch.max(A, 0)
return m + 1, idx + 1
m2 = M2()
gm2 = symbolic_trace(m2)
class T(torch.nn.Module):
def forward(self, A, b=4, *args, c=5, **kwargs):
x = A + 1 + args[0] + kwargs['3']
return x
t = T()
symbolic_trace(t)
def test_fx_shifts(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return x << 3, x >> 3
input = torch.LongTensor(10).random_(0, 1024)
m = MyModule()
ref_outs = m(input)
gm = symbolic_trace(m)
test_outs = gm(input)
self.assertEqual(ref_outs, test_outs)
def test_dict(self):
class MyDictMod(torch.nn.Module):
def forward(self, d):
return d['3'].relu(), {'4' : d['3'].neg()}
input_dict = {'3': torch.rand(3, 4)}
m = MyDictMod()
ref_out = m(input_dict)
gm = symbolic_trace(m)
out = gm(input_dict)
self.assertEqual(out, ref_out)
if __name__ == '__main__':
run_tests()