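"""Tests for the PyTorch JIT tracer: trace recording, IR passes (lint, init,
fuse, dce), autograd closures, and Toffee protobuf export."""
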
import torch
import torch.jit
import torch.nn as nn
import unittest
from torch.autograd import Variable, Function
from common import TestCase, run_tests

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

try:
    import toffee
    import google.protobuf.text_format
    HAS_TOFFEE = True
except ImportError:
    HAS_TOFFEE = False

toffee_only = unittest.skipIf(not HAS_TOFFEE, "no toffee support library")
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class TestJit(TestCase):
    maxDiff = None

    def test_simple(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        # Record the computation between _tracer_enter and _tracer_exit.
        trace = torch._C._tracer_enter((x, y))
        z = torch.sigmoid(torch.tanh(x * (x + y)))
        torch._C._tracer_exit((z,))

        # Lint after every pass to catch IR invariant violations early.
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_init(trace)
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_fuse(trace)
        torch._C._jit_pass_lint(trace)

        self.assertExpected(str(trace))

    def assertToffeeExpected(self, binary_pb, subname=None):
        # Round-trip the exported binary protobuf through text format so it
        # can be compared against a checked-in expect file.
        graph_def = toffee.GraphProto.FromString(binary_pb)
        self.assertExpected(
            google.protobuf.text_format.MessageToString(
                graph_def, float_format='.15g'),
            subname)

    @toffee_only
    def test_export(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)
        trace = torch._C._tracer_enter((x, y))
        z = -torch.sigmoid(torch.tanh(x * (x + y)))
        torch._C._tracer_exit((z,))
        torch._C._jit_pass_lint(trace)
        self.assertToffeeExpected(trace.export())

    @toffee_only
    def test_export_view(self):
        x = Variable(torch.Tensor([0]), requires_grad=True)
        trace, _ = torch.jit.record_trace(lambda x: x.view(1, 1), x)
        self.assertToffeeExpected(trace.export())

    @toffee_only
    def test_export_data(self):
        x = Variable(torch.Tensor([[1, 2], [3, 4]]), requires_grad=True)
        y = Variable(torch.Tensor([[1, 2], [3, 4]]), requires_grad=True)
        trace, _ = torch.jit.record_trace(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), x, y)
        initializers = [x.data]
        self.assertToffeeExpected(trace.export(initializers))

    @toffee_only
    def test_export_data_other_types(self):
        x = Variable(torch.LongTensor([[1, 2], [3, 4]]), requires_grad=True)
        y = Variable(torch.LongTensor([[1, 2], [3, 4]]), requires_grad=True)
        trace, _ = torch.jit.record_trace(lambda x, y: x * y + x, x, y)
        initializers = [x.data]
        self.assertToffeeExpected(trace.export(initializers))

    def test_lstm(self):
        # Careful: don't use fused backend (enabled with CUDA)
        # Pasted from test_LSTM_cell
        input = Variable(torch.randn(3, 10))
        hx = Variable(torch.randn(3, 20))
        cx = Variable(torch.randn(3, 20))
        trace, _ = torch.jit.record_trace(
            nn.LSTMCell(10, 20), input, (hx, cx))
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_init(trace)
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_fuse(trace)
        torch._C._jit_pass_lint(trace)
        self.assertExpected(str(trace))

    def test_function_as_argument(self):
        # Careful: don't use fused backend (enabled with CUDA)
        # Pasted from test_LSTM_cell
        input = Variable(torch.randn(3, 10))
        hx = Variable(torch.randn(3, 20))
        cx = Variable(torch.randn(3, 20))
        lstm = nn.LSTMCell(10, 20)

        def a_function(a, b):
            return lstm(a, b)

        trace, _ = torch.jit.record_trace(
            a_function, input, (hx, cx), parameters=lstm.parameters())
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_init(trace)
        torch._C._jit_pass_lint(trace)
        torch._C._jit_pass_fuse(trace)
        torch._C._jit_pass_lint(trace)
        self.assertExpected(str(trace))

    def test_verify(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        traced = torch.jit.traced(
            doit, enabled=True, verify=True, time=True, optimize=False)
        z = traced(x, y)
        z2 = traced(x, y)
        self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
        self.assertEqual(z, z2)

    def test_traced_function(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        traced = torch.jit.traced(doit)
        z = traced(x, y)
        z2 = traced(x, y)
        self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
        self.assertEqual(z, z2)

    def test_disabled_traced_function(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        traced = torch.jit.traced(doit, enabled=False)
        z = traced(x, y)
        z2 = traced(x, y)
        self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
        self.assertEqual(z, z2)

    def test_traced_module(self):
        input = Variable(torch.randn(3, 10))
        hx = Variable(torch.randn(3, 20))
        cx = Variable(torch.randn(3, 20))
        lstm = nn.LSTMCell(10, 20)
        lstm = torch.jit.traced(lstm, verify=True)
        out = lstm(input, (hx, cx))
        out2 = lstm(input, (hx, cx))
        self.assertEqual(out, out2)

    def test_autograd_closure(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)
        trace = torch._C._tracer_enter((x, y))
        z = torch.sigmoid(x * (x + y))
        w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
        torch._C._tracer_exit((z, w))
        torch._C._jit_pass_lint(trace)

        # Compute reference gradients with ordinary autograd.
        (z * w).backward()
        torch._C._jit_pass_dce(trace)
        torch._C._jit_pass_lint(trace)
        x_grad = x.grad.data.clone()
        x.grad.data.zero_()

        # Re-run the same computation through the traced autograd closure
        # and check that outputs and gradients match.
        function = torch._C._jit_createAutogradClosure(trace)
        torch._C._jit_pass_lint(trace)
        z2, w2 = function()(x, y)
        (z2 * w2).backward()
        self.assertEqual(z, z2)
        self.assertEqual(w, w2)
        self.assertEqual(x.grad.data, x_grad)

    def test_constant(self):
        x = Variable(torch.randn(2, 2), requires_grad=True)
        trace = torch._C._tracer_enter((x,))
        y = Variable(torch.diag(torch.Tensor([2, 2])))
        z = x.matmul(y)
        torch._C._tracer_exit((z,))
        function = torch._C._jit_createAutogradClosure(trace)
        z2 = function()(x)
        self.assertEqual(z, z2)

        y.data.fill_(1000)  # make sure the data has been cloned
        x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
        z3 = function()(x2)
        self.assertEqual(z3.data, torch.ones(2, 2) * 4)

    def test_c_function(self):
        x = Variable(torch.randn(1, 3, 10, 10))
        m = nn.Conv2d(3, 8, 3, 1)
        trace = torch._C._tracer_enter((x,) + tuple(m.parameters()))
        y = m(x)
        torch._C._tracer_exit((y,))
        self.assertExpected(str(trace))

    def test_legacy_fail(self):
        class Legacy(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

        x = Variable(torch.Tensor([0]), requires_grad=True)
        trace = torch._C._tracer_enter((x,))
        self.assertRaises(RuntimeError, lambda: Legacy()(x))
        torch._C._tracer_exit((x,))

    def test_inplace_transplant(self):
        x = Variable(torch.Tensor([0]), requires_grad=True)
        trace = torch._C._tracer_enter((x,))
        y = x.clone()
        y.add_(2)
        y.add_(3)
        torch._C._tracer_exit((y,))
        self.assertExpected(str(trace))

    def test_backward(self):
        a = Variable(torch.randn(2, 2), requires_grad=True)
        b = Variable(torch.randn(2, 2), requires_grad=True)
        x = a
        y = a * b
        trace = torch._C._tracer_enter((x, y))
        z = y * 2 * x
        torch._C._tracer_exit((z,))
        torch._C._jit_pass_lint(trace)

        # Run first backward
        grad, = torch.autograd.grad(
            z, x, Variable(torch.ones(2, 2), requires_grad=True),
            create_graph=True)
        torch._C._jit_pass_lint(trace)

        # Run second backward
        grad.sum().backward(create_graph=True)
        torch._C._jit_pass_lint(trace)

        # Run dead code elimination to remove unused trace nodes
        torch._C._jit_pass_dce(trace)
        self.assertExpected(str(trace))

    def test_python_ir(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        traced, _ = torch.jit.record_trace(doit, x, y)
        g = torch._C._jit_get_graph(traced)
        g2 = torch._C.Graph()
        g_to_g2 = {}
        for node in g.inputs():
            g_to_g2[node] = g2.addInput()
        for node in g.nodes():
            if node.kind() == "PythonOp":
                n_ = g2.create(node.pyname(),
                               [g_to_g2[i] for i in node.inputs()]) \
                       .setType(node.typeOption()) \
                       .s_("note", "from_pyop") \
                       .i_("some_value", len(node.scalar_args()))
                assert n_.i("some_value") == len(node.scalar_args())
            else:
                n_ = g2.createClone(node, lambda x: g_to_g2[x])
                assert n_.kindOf("Offset") == "i"
            g_to_g2[node] = g2.appendNode(n_)
        for node in g.outputs():
            g2.registerOutput(g_to_g2[node])

        t_node = g2.create("TensorTest").t_("a", torch.ones([2, 2]))
        assert t_node.attributeNames() == ["a"]
        g2.appendNode(t_node)
        assert torch.equal(torch.ones([2, 2]), t_node.t("a"))
        self.assertExpected(str(g2))

    def test_cpp(self):
        torch._C._jit_run_cpp_tests()

    def test_batchnorm(self):
        x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
        trace, _ = torch.jit.record_trace(nn.BatchNorm2d(2), x)
        self.assertExpected(str(trace))

    @toffee_only
    def test_batchnorm_export(self):
        x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
        trace, _ = torch.jit.record_trace(nn.BatchNorm2d(2), x)
        self.assertToffeeExpected(trace.export())

    def test_batchnorm_verify(self):
        bn = torch.jit.traced(nn.BatchNorm2d(1), enabled=True, verify=True)
        x = Variable(torch.randn(5, 1))
        z = bn(x)
        z2 = bn(x)
        self.assertEqual(z, z2)

    def test_conv(self):
        x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
        trace, _ = torch.jit.record_trace(nn.Conv2d(16, 13, 3, bias=False), x)
        self.assertExpected(str(trace))

    @toffee_only
    def test_conv_export(self):
        x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
        trace, _ = torch.jit.record_trace(nn.Conv2d(16, 13, 3, bias=False), x)
        self.assertToffeeExpected(trace.export())

    @toffee_only
    def test_maxpool_export(self):
        x = Variable(torch.randn(20, 16, 50))
        trace, _ = torch.jit.record_trace(nn.MaxPool1d(3, stride=2), x)
        self.assertToffeeExpected(trace.export())

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
        trace, _ = torch.jit.record_trace(torchvision.models.AlexNet(), x)
        self.assertExpected(str(trace))
        # NB: Purposely NOT testing protobuf export here


if __name__ == '__main__':
    run_tests()