import torch | |
import torch.jit | |
from torch.autograd import Variable | |
from common import TestCase, run_tests | |
class TestJit(TestCase):
    """Test case exercising ``torch._C._tracer_enter`` / ``_tracer_exit``."""

    def test_simple(self):
        """Trace ``sigmoid(tanh(x * (x + y)))`` on two inputs and print the trace.

        The comparison against a replayed forward pass sits after an early
        ``return`` and is therefore currently never executed.
        """
        lhs = Variable(torch.Tensor([0.4]), requires_grad=True)
        rhs = Variable(torch.Tensor([0.7]), requires_grad=True)
        inputs = (lhs, rhs)

        # The expression is evaluated between the enter/exit calls;
        # presumably that window is what the tracer records -- TODO confirm.
        torch._C._tracer_enter(inputs)
        traced_out = torch.sigmoid(torch.tanh(lhs * (lhs + rhs)))
        trace = torch._C._tracer_exit((traced_out,))

        # TODO: Do something more automated here
        print(trace)
        return

        # Re-enable this when the interpreter is back
        replayed = traced_out._execution_engine.run_forward(trace, inputs)
        self.assertEqual(traced_out, replayed)
        # TODO: test that backwards works correctly
# Only invoke the test runner when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_tests()