# blob: ad94ee0925444aa7e8b1919fc5641990db56f1f5 [file] [log] [blame]
import torch
import torch.jit
import torch.nn as nn
import unittest
from torch.autograd import Variable
from common import TestCase, run_tests
class TestJit(TestCase):
    """Tests for the torch JIT tracer, its passes and autograd closures."""

    @staticmethod
    def _lint_init(trace, fuse=False):
        """Lint ``trace``, run the init pass on it, then lint again.

        When ``fuse`` is true the fuse pass is applied afterwards, followed
        by one more lint.
        """
        lint = torch._C._jit_pass_lint
        lint(trace)
        torch._C._jit_pass_init(trace)
        lint(trace)
        if fuse:
            torch._C._jit_pass_fuse(trace)
            lint(trace)

    def test_simple(self):
        """Trace a small elementwise expression, run the passes and check
        the trace's string form with ``assertExpected``."""
        lhs = Variable(torch.Tensor([0.4]), requires_grad=True)
        rhs = Variable(torch.Tensor([0.7]), requires_grad=True)

        trace = torch._C._tracer_enter((lhs, rhs))
        out = torch.sigmoid(torch.tanh(lhs * (lhs + rhs)))
        torch._C._tracer_exit((out,))

        self._lint_init(trace, fuse=True)
        self.assertExpected(str(trace))

    def test_lstm(self):
        """Trace one ``nn.LSTMCell`` call, run the passes and check the
        trace's string form with ``assertExpected``."""
        # Careful: don't use fused backend (enabled with CUDA)
        # Pasted from test_LSTM_cell
        cell_input = Variable(torch.randn(3, 10))
        hidden = Variable(torch.randn(3, 20))
        cell_state = Variable(torch.randn(3, 20))
        traced_cell = torch.jit.trace_model(nn.LSTMCell(10, 20))

        # The traced model returns the trace first; the second item is unused.
        trace, _ = traced_cell(cell_input, (hidden, cell_state))

        self._lint_init(trace, fuse=True)
        self.assertExpected(str(trace))

    def test_autograd_closure(self):
        """Build an autograd closure from a two-output trace and check that
        running it on the same inputs gives the traced outputs."""
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)
        inputs = (x, y)

        trace = torch._C._tracer_enter(inputs)
        z, _ = torch.max(x * (x + y), 0)
        w = torch.abs(x * x * x + y)
        torch._C._tracer_exit((z, w))

        self._lint_init(trace)

        closure = torch._C._jit_createAutogradClosure(trace)
        z2, w2 = Variable._execution_engine.run_forward(closure, inputs)
        self.assertEqual(z, z2)
        self.assertEqual(w, w2)

    def test_constant(self):
        """Check that a Variable created while tracing keeps its traced
        values in the closure even after the original is overwritten."""
        x = Variable(torch.randn(2, 2), requires_grad=True)

        trace = torch._C._tracer_enter((x,))
        # Created after entering the tracer, so it is not one of the inputs.
        const = Variable(torch.diag(torch.Tensor([2, 2])))
        z = x.matmul(const)
        torch._C._tracer_exit((z,))

        closure = torch._C._jit_createAutogradClosure(trace)
        run_forward = Variable._execution_engine.run_forward

        (z2,) = run_forward(closure, (x,))
        self.assertEqual(z, z2)

        const.data.fill_(1000)  # make sure the data has been cloned
        x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
        (z3,) = run_forward(closure, (x2,))
        self.assertEqual(z3.data, torch.ones(2, 2) * 4)

    def test_c_function(self):
        """Trace an ``nn.Conv2d`` forward call, passing the module's
        parameters as trace inputs, and check the trace's string form."""
        x = Variable(torch.randn(1, 3, 10, 10))
        conv = nn.Conv2d(3, 8, 3, 1)
        trace_inputs = (x,) + tuple(conv.parameters())

        trace = torch._C._tracer_enter(trace_inputs)
        y = conv(x)
        torch._C._tracer_exit((y,))

        self.assertExpected(str(trace))

    def test_cpp(self):
        """Run the JIT tests exposed through ``torch._C``."""
        torch._C._jit_run_cpp_tests()
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()