# blob: a529dcb2249c140f1e66a0553812ea8156361b11 [file] [log] [blame]
import torch
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.autograd.function import traceable
from common import TestCase, run_tests, IS_WINDOWS
import io
import sys
import unittest
import inspect
import textwrap
import numpy as np
import tempfile
import shutil
from torch.jit.frontend import NotSupportedError
# torchvision is optional; tests that need it are skipped when it is missing.
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# RUN_CUDA starts as "CUDA is available" and is cleared when any visible
# device has a compute-capability major version that is too high for the
# compiled CUDA version (major >= 6 with CUDA < 8000, or major >= 7 with
# CUDA < 9000).
# NOTE(review): presumably these thresholds mark devices the compiled toolkit
# cannot target -- confirm.
RUN_CUDA = torch.cuda.is_available()
if RUN_CUDA:
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        needs_newer_toolkit = (
            (CUDA_VERSION < 8000 and major >= 6)
            or (CUDA_VERSION < 9000 and major >= 7)
        )
        if needs_newer_toolkit:
            RUN_CUDA = False

# Multi-GPU tests additionally need more than one device.
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1

PY2 = (sys.version_info[0] == 2)
WINDOWS = (sys.platform == 'win32')
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Functional LSTM cell step built from elementary ops.

    Args:
        input: input tensor for this step.
        hidden: pair ``(hx, cx)`` with the previous hidden and cell state.
        w_ih, w_hh: weights of the input and hidden linear projections.
        b_ih, b_hh: optional biases of those projections.

    Returns:
        The pair ``(hy, cy)``: next hidden state and next cell state.
    """
    prev_h, prev_c = hidden
    projected = F.linear(input, w_ih, b_ih) + F.linear(prev_h, w_hh, b_hh)

    # The projection is split along dim 1 into four equal chunks, in the
    # order: input gate, forget gate, cell candidate, output gate.
    raw_in, raw_forget, raw_cell, raw_out = projected.chunk(4, 1)

    # Activation order is kept stable on purpose: traces of this function
    # are compared against stored expectations elsewhere in the file.
    in_gate = F.sigmoid(raw_in)
    forget_gate = F.sigmoid(raw_forget)
    candidate = F.tanh(raw_cell)
    out_gate = F.sigmoid(raw_out)

    next_c = (forget_gate * prev_c) + (in_gate * candidate)
    next_h = out_gate * F.tanh(next_c)
    return next_h, next_c
def LSTMCellC(*args, **kwargs):
    """Run ``LSTMCell`` and return its two outputs concatenated by ``torch.cat``.

    All arguments are forwarded unchanged to ``LSTMCell``.
    """
    next_h, next_c = LSTMCell(*args, **kwargs)
    states = (next_h, next_c)
    return torch.cat(states)
class TestJit(TestCase):
maxDiff = None
@contextmanager
def assertCompiled(self, compiled_fn):
    """Context manager asserting that the wrapped code used the compiled path.

    On entry, checks that ``compiled_fn`` is a ``torch._C.CompiledFunction``
    and snapshots its ``hits`` and ``misses`` counters.  After the body
    finishes normally, checks that ``hits`` increased and ``misses`` did not
    change.
    """
    self.assertIsInstance(compiled_fn, torch._C.CompiledFunction)
    hits_before = compiled_fn.hits
    misses_before = compiled_fn.misses
    yield
    # No try/finally here: if the body raises, the counter checks are skipped
    # and the original exception propagates.
    self.assertLess(hits_before, compiled_fn.hits)
    self.assertEqual(misses_before, compiled_fn.misses)
def assertExpectedTrace(self, trace, *args, **kwargs):
    """Normalise ``trace`` and compare its string form with the expectation.

    The trace graph is linted, dead-code-eliminated, linted, replaced by its
    canonicalized form and linted again.  ``str(trace)`` is then passed to
    ``assertExpected`` along with any extra positional/keyword arguments.
    """
    lint = torch._C._jit_pass_lint

    lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    lint(trace.graph())

    canonical_graph = torch._C._jit_pass_canonicalize(trace.graph())
    trace.set_graph(canonical_graph)
    lint(trace.graph())

    self.assertExpected(str(trace), *args, **kwargs)
def test_simple(self):
    """Trace ``sigmoid(tanh(x * (x + y)))`` and check the expected trace."""
    a = Variable(torch.Tensor([0.4]), requires_grad=True)
    b = Variable(torch.Tensor([0.7]), requires_grad=True)

    def f(x, y):
        scaled = x * (x + y)
        return torch.sigmoid(torch.tanh(scaled))

    trace, _ = torch.jit.get_trace_graph(f, (a, b), nderivs=0)
    self.assertExpectedTrace(trace)
# matmul is currently implemented as a native function, which
# exercises different codepaths in the JIT. The following two
# tests ensure that (1) matmul indeed traces into an atomic,
# native operation, and (2) the JIT knows how to run it
def test_matmul_native(self):
    """Trace a 1x1 ``matmul`` and compare the result with the expected trace."""
    lhs = Variable(torch.Tensor([[0.4]]), requires_grad=True)
    rhs = Variable(torch.Tensor([[0.7]]), requires_grad=True)

    def matmul_fn(x, y):
        return x.matmul(y)

    trace, _ = torch.jit.get_trace_graph(matmul_fn, (lhs, rhs), nderivs=0)
    # Lint and DCE are run here before the helper normalises the trace again.
    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    self.assertExpectedTrace(trace)
def test_matmul_native_run(self):
    """Run a compiled ``matmul`` twice; the second call must hit the compiled path."""
    lhs = Variable(torch.Tensor([[0.4]]), requires_grad=True)
    rhs = Variable(torch.Tensor([[0.7]]), requires_grad=True)

    @torch.jit.compile(nderivs=0)
    def fn(x, y):
        return x.matmul(y)

    first = fn(lhs, rhs)
    with self.assertCompiled(fn):
        second = fn(lhs, rhs)
    self.assertEqual(first, second)
# index-2 is not implemented in interpreter
@unittest.expectedFailure
def test_index(self):
    """Run a compiled ``x[y]`` (indexing by a LongTensor) twice and compare.

    Marked as an expected failure.
    """
    data = Variable(torch.Tensor([0.4]), requires_grad=True)
    index = Variable(torch.LongTensor([0]), requires_grad=True)

    @torch.jit.compile(nderivs=0)
    def fn(x, y):
        return x[y]

    first = fn(data, index)
    with self.assertCompiled(fn):
        second = fn(data, index)
    self.assertEqual(first, second)
# Backwards tracing was broken for indexing by a constant,
# because it's internally implemented using as_strided,
# and we attempted to trace its derivative (which is not
# currently supported.) It currently works because
# slice() is now not marked as traceable.
def test_index_constant(self):
    """Indexing by a constant: forward value and gradient must agree between
    the tracing run and the compiled run (``nderivs=1``)."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)

    @torch.jit.compile(nderivs=1)
    def fn(x):
        return x[0]

    # First run records the trace.
    traced_out = fn(x)
    traced_out.backward()
    traced_grad = x.grad.clone()
    x.grad.zero_()

    # Second run must go through the compiled function.
    with self.assertCompiled(fn):
        compiled_out = fn(x)
        compiled_out.backward()
        compiled_grad = x.grad.clone()

    self.assertEqual(traced_out, compiled_out)
    self.assertEqual(traced_grad, compiled_grad)
def test_scopes(self):
    """Trace a function using nested ``torch.jit.scope`` blocks ('Foo' > 'Bar')."""
    a = Variable(torch.Tensor([0.4]), requires_grad=True)
    b = Variable(torch.Tensor([0.7]), requires_grad=True)

    def f(x, y):
        acc = x + y
        with torch.jit.scope('Foo', acc):
            acc = x * acc
            with torch.jit.scope('Bar', acc):
                acc = torch.tanh(acc)
            # Back in 'Foo' only.
            acc = torch.sigmoid(acc)
        return acc

    trace, _ = torch.jit.get_trace_graph(f, (a, b), nderivs=0)
    self.assertExpectedTrace(trace)
def test_scopes_intermediate_node(self):
    """Trace a module wrapping ``F.log_softmax``, run the ONNX trace
    optimisation and compare with the expected trace."""
    class Net(nn.Module):
        def forward(self, x):
            return F.log_softmax(x, dim=0)

    model = Net()
    sample = Variable(torch.ones(2), requires_grad=True)
    trace, _ = torch.jit.get_trace_graph(model, (sample, ))
    torch.onnx._optimize_trace(trace, False)
    self.assertExpectedTrace(trace)
def test_scopes_identity_node(self):
    """Trace a small Conv/ReLU/MaxPool ``nn.Sequential`` with training mode
    set to False, optimise the trace for ONNX and compare with the
    expected trace."""
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            layers = (
                nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
            )
            self.features = nn.Sequential(*layers)

        def forward(self, x):
            x = self.features(x)
            return x

    model = Net()
    sample = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)

    # Only the tracing itself happens under set_training(model, False).
    with torch.onnx.set_training(model, False):
        trace, _ = torch.jit.get_trace_graph(model, (sample, ))

    torch.onnx._optimize_trace(trace, False)
    self.assertExpectedTrace(trace)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_lstm_fusion(self):
    """Trace ``LSTMCell`` on CUDA inputs, run the fusion pass and compare
    with the expected trace."""
    inp = Variable(torch.randn(3, 10).float().cuda())
    hx = Variable(torch.randn(3, 20).float().cuda())
    cx = Variable(torch.randn(3, 20).float().cuda())
    # Just to allocate weights with correct sizes
    module = nn.LSTMCell(10, 20).float().cuda()

    trace_args = (inp, (hx, cx)) + tuple(module.parameters())
    trace, _ = torch.jit.get_trace_graph(LSTMCell, trace_args)

    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_fuse(trace.graph())
    self.assertExpectedTrace(trace)
def run_lstm_fusion(self, use_cuda):
    """Run a compiled ``LSTMCell`` twice and check both runs agree.

    Args:
        use_cuda: when true, inputs and weights are moved to CUDA.
    """
    def to_type(obj):
        # Cast to float and, if requested, move to the GPU.
        obj = obj.float()
        return obj.cuda() if use_cuda else obj

    def rand_v(rows, cols):
        return Variable(to_type(torch.randn(rows, cols)))

    inp = rand_v(3, 10)
    hx = rand_v(3, 20)
    cx = rand_v(3, 20)
    # Just to allocate weights with correct sizes
    module = to_type(nn.LSTMCell(10, 20))

    CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)

    first = CompiledLSTMCell(inp, (hx, cx), *module.parameters())
    with self.assertCompiled(CompiledLSTMCell):
        second = CompiledLSTMCell(inp, (hx, cx), *module.parameters())
    self.assertEqual(first, second)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_run_lstm_fusion_cuda(self):
    """Run the compiled-LSTM check with CUDA tensors."""
    self.run_lstm_fusion(use_cuda=True)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
def test_run_lstm_fusion_cpu(self):
    """Run the compiled-LSTM check with CPU tensors."""
    self.run_lstm_fusion(use_cuda=False)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_run_lstm_fusion_concat(self):
    """Run a compiled ``LSTMCellC`` (LSTM cell followed by ``cat``) twice on
    CUDA and check both runs agree."""
    inp = Variable(torch.randn(3, 10).float().cuda())
    hx = Variable(torch.randn(3, 20).float().cuda())
    cx = Variable(torch.randn(3, 20).float().cuda())
    # Just to allocate weights with correct sizes
    module = nn.LSTMCell(10, 20).float().cuda()

    CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCellC)

    first = CompiledLSTMCell(inp, (hx, cx), *module.parameters())
    with self.assertCompiled(CompiledLSTMCell):
        second = CompiledLSTMCell(inp, (hx, cx), *module.parameters())
    self.assertEqual(first, second)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_concat_fusion(self):
    """Trace ``cat((hx + cx, hx * cx))`` on CUDA, run the fusion pass and
    compare with the expected trace."""
    hx = Variable(torch.randn(3, 20).float().cuda())
    cx = Variable(torch.randn(3, 20).float().cuda())

    def Foo(hx, cx):
        pieces = (hx + cx, hx * cx)
        return torch.cat(pieces)

    trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx))
    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_fuse(trace.graph())
    self.assertExpectedTrace(trace)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_fusion_distribute(self):
    """Trace ``(x + y).chunk(2, dim=1)`` followed by a multiply; compare the
    trace before fusion (expectation 'raw') and after the fusion pass."""
    def f(x, y):
        left, right = (x + y).chunk(2, dim=1)
        return left * right

    a = Variable(torch.randn(4, 4).float().cuda())
    b = Variable(torch.randn(4, 4).float().cuda())
    trace, _ = torch.jit.get_trace_graph(f, (a, b), nderivs=0)

    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    # Snapshot of the unfused graph.
    self.assertExpectedTrace(trace, 'raw')

    torch._C._jit_pass_fuse(trace.graph())
    self.assertExpectedTrace(trace)
def test_arg_configurations(self):
    """Different arg configurations should trigger different traces"""
    x = Variable(torch.FloatTensor(4, 4).uniform_())
    x_double = Variable(x.data.double())
    x_grad = Variable(x.data.clone(), requires_grad=True)
    y = Variable(torch.randn(4))

    # Each entry is the argument tuple of one call to ``fn``.
    configurations = [
        (x,),
        (x_double,),
        (x_grad,),
        (y,),
        ([x, x],),
        ([x, y],),
    ]
    if torch.cuda.is_available():
        x_cuda = Variable(x.data.cuda())
        configurations.extend([
            (x_cuda,),
            ([x, x_cuda],),
            ([x_cuda, x],),
            ([[x_cuda, x]],),
        ])
        if torch.cuda.device_count() > 1:
            x_cuda_1 = Variable(x.data.cuda(1))
            configurations.extend([
                (x_cuda_1,),
                ([x_cuda, x_cuda_1],),
            ])

    @torch.jit.compile(nderivs=0)
    def fn(*args):
        in_vars, _ = torch._C._jit_flatten(args)
        return in_vars[0] + 1

    for idx, config in enumerate(configurations):
        # Not traced before the call, traced right after it.
        self.assertFalse(fn.has_trace_for(*config))
        fn(*config)
        self.assertTrue(fn.has_trace_for(*config))
        # Configurations not called yet must still have no trace.
        for pending in configurations[idx + 1:]:
            self.assertFalse(fn.has_trace_for(*pending))

    # Every call used a new configuration, so nothing was a hit.
    self.assertEqual(fn.hits, 0)
def test_cse(self):
    """Trace a function full of repeated sub-expressions, run the CSE pass
    and compare with the expected trace."""
    a = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
    b = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)

    trace, inputs = torch._C._tracer_enter((a, b), 0)

    def fn(x, y):
        # The duplicated expressions are deliberate: they are what the
        # CSE pass is run on.
        cubed = (x + y) * (x + y) * (x + y)
        tanh_sum = torch.tanh(cubed) + torch.tanh(cubed)
        return (x + y) * (x + y) * (x + y) + tanh_sum

    result = fn(*inputs)
    torch._C._tracer_exit((result,))

    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_cse(trace.graph())
    self.assertExpectedTrace(trace)
def test_compile_run_twice(self):
    """With ``optimize=False``, a second call must hit the compiled function
    and match both the first call and the eager result."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    @torch.jit.compile(nderivs=0, optimize=False)
    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    first = doit(x, y)
    with self.assertCompiled(doit):
        second = doit(x, y)

    eager = torch.sigmoid(torch.tanh(x * (x + y)))
    self.assertEqual(first, eager)
    self.assertEqual(first, second)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_compile_addc(self):
    """Compile a function that adds a constant, on CUDA; the second call
    must hit the compiled path and match the eager result."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda()
    y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda()

    @torch.jit.compile(nderivs=0)
    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y) + 1))

    first = doit(x, y)
    with self.assertCompiled(doit):
        second = doit(x, y)

    eager = torch.sigmoid(torch.tanh(x * (x + y) + 1))
    self.assertEqual(first, eager)
    self.assertEqual(first, second)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_compile_fuse_last_device(self):
    """Same as the constant-add compile test, but with the inputs placed on
    the highest-numbered CUDA device."""
    last_device = torch.cuda.device_count() - 1
    x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda(last_device)
    y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda(last_device)

    @torch.jit.compile(nderivs=0)
    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y) + 1))

    first = doit(x, y)
    with self.assertCompiled(doit):
        second = doit(x, y)

    eager = torch.sigmoid(torch.tanh(x * (x + y) + 1))
    self.assertEqual(first, eager)
    self.assertEqual(first, second)
def test_traced_function(self):
    """A compiled function hits the compiled path on its second call and
    matches the eager computation."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    @torch.jit.compile(nderivs=0)
    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    first = doit(x, y)
    with self.assertCompiled(doit):
        second = doit(x, y)

    eager = torch.sigmoid(torch.tanh(x * (x + y)))
    self.assertEqual(first, eager)
    self.assertEqual(first, second)
def test_disabled_traced_function(self):
    """With ``enabled=False`` the decorated function still returns the eager
    result on every call (no compiled-path check is made here)."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    @torch.jit.compile(enabled=False)
    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    first = doit(x, y)
    second = doit(x, y)

    eager = torch.sigmoid(torch.tanh(x * (x + y)))
    self.assertEqual(first, eager)
    self.assertEqual(first, second)
def test_assign_traces(self):
    """Check that output Variables are assigned traces before they are saved."""

    @traceable
    class MyFn(Function):
        @staticmethod
        def forward(ctx, a):
            doubled = a * 2
            # The *output* is what gets saved for backward.
            ctx.save_for_backward(doubled)
            return doubled

        @staticmethod
        def backward(ctx, grad_a):
            (saved,) = ctx.saved_tensors
            return saved * grad_a

    x = Variable(torch.randn(10, 10), requires_grad=True)
    trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1)
    out.sum().backward()
    torch._C._jit_pass_dce(trace.graph())
    self.assertExpectedTrace(trace)
def test_legacy_traced_module(self):
    """Compile an ``nn.LSTMCell`` subclass via the class decorator and check
    that the second call hits the compiled path with the same output."""
    inp = Variable(torch.randn(3, 10))
    hx = Variable(torch.randn(3, 20))
    cx = Variable(torch.randn(3, 20))

    @torch.jit.compile(nderivs=0)
    class MyLSTMCell(nn.LSTMCell):
        pass

    lstm = MyLSTMCell(10, 20)

    first = lstm(inp, (hx, cx))
    with self.assertCompiled(lstm):
        second = lstm(inp, (hx, cx))
    self.assertEqual(first, second)
def test_autograd_closure(self):
    """Record a trace with one backward stage, rebuild it through the
    interpreter factory, and check that outputs and ``x.grad`` match."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    trace, inputs = torch._C._tracer_enter((x, y), 1)

    def fn(x, y):
        z = torch.sigmoid(x * (x + y))
        w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
        return z, w

    z, w = fn(*inputs)
    torch._C._tracer_exit((z, w))
    torch._C._jit_pass_lint(trace.graph())

    # Backward through the traced outputs, then clean up the graph.
    (z * w).backward()
    torch._C._jit_pass_dce(trace.graph())
    torch._C._jit_pass_lint(trace.graph())

    # Remember the reference gradient and reset it for the second run.
    reference_grad = x.grad.data.clone()
    x.grad.data.zero_()

    function = torch._C._jit_createInterpreterFactory(trace)
    torch._C._jit_pass_lint(trace.graph())

    z2, w2 = function()(x, y)
    (z2 * w2).backward()

    self.assertEqual(z, z2)
    self.assertEqual(w, w2)
    self.assertEqual(x.grad.data, reference_grad)
def test_verify(self):
    """Run ``torch.jit.verify`` on a two-output compiled function, using the
    product of the outputs as the loss and an empty device list."""
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    @torch.jit.compile
    def f(x, y):
        z = torch.sigmoid(x * (x + y))
        w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
        return z, w

    def product_loss(z, w):
        return z * w

    torch.jit.verify(f, (x, y), loss_fn=product_loss, devices=[])
def test_constant(self):
    """A Variable created inside the traced region is captured as a constant:
    mutating it afterwards must not change what the trace computes."""
    x = Variable(torch.randn(2, 2), requires_grad=True)

    trace, (tx,) = torch._C._tracer_enter((x,), 0)
    const = Variable(torch.diag(torch.Tensor([2, 2])))
    z = tx.matmul(const)
    torch._C._tracer_exit((z,))

    function = torch._C._jit_createInterpreterFactory(trace)

    z2 = function()(x)
    self.assertEqual(z, z2)

    const.data.fill_(1000)  # make sure the data has been cloned

    x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
    z3 = function()(x2)
    # ones * 2 multiplied by diag(2, 2) gives 4 everywhere.
    self.assertEqual(z3.data, torch.ones(2, 2) * 4)
def test_c_function(self):
    """Trace an ``nn.Conv2d`` call through the low-level tracer API, with the
    module parameters passed as trace inputs."""
    x = Variable(torch.randn(1, 3, 10, 10))
    conv = nn.Conv2d(3, 8, 3, 1)

    trace_inputs = (x,) + tuple(conv.parameters())
    trace, inputs = torch._C._tracer_enter(trace_inputs, 0)
    out = conv(inputs[0])
    torch._C._tracer_exit((out,))

    self.assertExpectedTrace(trace)
def test_legacy_fail(self):
    """Calling a legacy-style (non-static) ``Function`` while tracing must
    raise a ``RuntimeError`` mentioning the function's class name."""
    class MyLegacyFn(Function):
        def forward(self, x):
            return x

        def backward(self, grad_output):
            return grad_output

    x = Variable(torch.Tensor([0]), requires_grad=True)
    trace, inputs = torch._C._tracer_enter((x,), 0)

    with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
        MyLegacyFn()(*inputs)

    # Leave tracing mode again.
    torch._C._tracer_exit(inputs)
def test_inplace_transplant(self):
    """Trace a clone followed by two in-place adds and compare with the
    expected trace."""
    x = Variable(torch.Tensor([0]), requires_grad=True)
    trace, inputs = torch._C._tracer_enter((x,), 0)

    def fn(x):
        copy = x.clone()
        copy.add_(2)
        copy.add_(3)
        return copy

    out = fn(*inputs)
    torch._C._tracer_exit((out,))
    self.assertExpectedTrace(trace)
def test_inplace_flags(self):
    """Every traced node carries an ``inplace`` attribute, and its value
    reflects whether the Function marked its input dirty."""
    class InplaceFn(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.mark_dirty(x)
            return x.add_(1)

        @staticmethod
        def backward(ctx, go):
            return go

    class RegularFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x.add(1)

        @staticmethod
        def backward(ctx, go):
            return go

    x = Variable(torch.Tensor([0]), requires_grad=True)
    trace, inputs = torch._C._tracer_enter((x,), 0)

    def fn(x):
        # regular -> inplace -> inplace -> regular
        out = RegularFn.apply(x)
        out = InplaceFn.apply(out)
        out = InplaceFn.apply(out)
        out = RegularFn.apply(out)
        return out

    result = fn(*inputs)
    torch._C._tracer_exit((result,))
    torch._C._jit_pass_dce(trace.graph())

    ops = list(trace.graph().nodes())
    for op in ops:
        self.assertTrue(op.hasAttribute('inplace'))

    expected_flags = [False, True, True, False]
    for op, is_inplace in zip(ops, expected_flags):
        self.assertEqual(op.i('inplace'), is_inplace)
def test_inplace_check(self):
    """After the tracing run, calling a compiled function built on an
    in-place Function must raise ``RuntimeError`` ('inplace MyInplaceFn')."""
    class MyInplaceFn(Function):
        @staticmethod
        def forward(self, x):
            x.add_(1)
            self.mark_dirty(x)
            return x

        @staticmethod
        def backward(self, grad):
            return grad

    @torch.jit.compile(nderivs=0)
    def fn(x):
        return MyInplaceFn.apply(x)

    x = Variable(torch.randn(5, 5))
    fn(x)  # trace
    with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
        fn(x)
def test_backward(self):
    """Trace a function with two backward stages, run both backwards with
    lint checks in between, then skip (the expected-trace comparison is
    disabled because the output is nondeterministic)."""
    a = Variable(torch.randn(2, 2), requires_grad=True)
    b = Variable(torch.randn(2, 2), requires_grad=True)

    x = a
    y = a * b

    trace, inputs = torch._C._tracer_enter((x, y), 2)

    def fn(x, y):
        return y * 2 * x

    z = fn(*inputs)
    torch._C._tracer_exit((z,))
    torch._C._jit_pass_lint(trace.graph())

    # Run first backward
    grad_seed = Variable(torch.ones(2, 2), requires_grad=True)
    grad, = torch.autograd.grad(z, x, grad_seed, create_graph=True)
    torch._C._jit_pass_lint(trace.graph())

    # Run second backward
    grad.sum().backward(create_graph=True)
    torch._C._jit_pass_lint(trace.graph())

    # Run dead code elimination to remove unused trace nodes
    torch._C._jit_pass_dce(trace.graph())

    # This is nondeterministic, see:
    #   https://github.com/ezyang/pytorch/issues/227
    # self.assertExpectedTrace(trace)
    self.skipTest("output is nondeterministic on Travis/Python 3.5")
def test_backward_opaque(self):
    """Trace ``x.cross(y)`` with two backward stages, run the first backward
    with lint checks, then skip (the expected-trace comparison is disabled
    because the output is nondeterministic)."""
    x = Variable(torch.randn(3, 3), requires_grad=True)
    y = Variable(torch.randn(3, 3), requires_grad=True)

    trace, inputs = torch._C._tracer_enter((x, y), 2)

    def fn(x, y):
        return x.cross(y)

    z = fn(*inputs)
    torch._C._tracer_exit((z,))
    torch._C._jit_pass_lint(trace.graph())

    # Run first backward
    grad_seed = Variable(torch.ones(3, 3), requires_grad=True)
    grad, = torch.autograd.grad(z, x, grad_seed, create_graph=True)
    torch._C._jit_pass_lint(trace.graph())

    # Run dead code elimination to remove unused trace nodes
    torch._C._jit_pass_dce(trace.graph())

    # This is nondeterministic, see:
    #   https://github.com/ezyang/pytorch/issues/227
    # self.assertExpectedTrace(trace)
    self.skipTest("output is nondeterministic on Travis/Python 3.5")
def test_backward_closure(self):
    """Check that autograd closures handle multiple stages correctly."""
    x = Variable(torch.randn(1), requires_grad=True)

    @torch.jit.compile(nderivs=2)
    def fn(x):
        return x * x

    # Generate trace: it only becomes available once the second-order
    # backward has run.
    first_grad, = torch.autograd.grad(fn(x), (x,), create_graph=True)
    self.assertFalse(fn.has_trace_for(x))
    first_grad.backward()
    self.assertTrue(fn.has_trace_for(x))

    reference_grad = x.grad.data.clone()
    x.grad.data.zero_()

    # Run the trace
    with self.assertCompiled(fn):
        output = fn(x)
    second_grad, = torch.autograd.grad(output, (x,), create_graph=True)
    second_grad.backward()

    self.assertEqual(x.grad.data, reference_grad)
def test_trace_expire(self):
    """Check ``trace.is_expired`` / ``trace.is_complete`` for traces recorded
    with 0 or 1 backward stages, before and after the output is deleted."""
    x = Variable(torch.randn(2, 2), requires_grad=True)
    y = Variable(torch.randn(2, 2), requires_grad=True)

    def record_trace(num_backwards):
        trace, inputs = torch._C._tracer_enter((x, y), num_backwards)

        def fn(x, y):
            return y * 2 * x

        z = fn(*inputs)
        torch._C._tracer_exit((z,))
        return z, trace

    def check(expired, complete):
        # Reads ``trace`` from the enclosing test scope at call time.
        self.assertEqual(trace.is_expired, expired)
        self.assertEqual(trace.is_complete, complete)

    # No backward stages: complete right away, and stays so after ``del z``.
    z, trace = record_trace(0)
    check(expired=False, complete=True)
    del z
    check(expired=False, complete=True)

    # One backward stage that never runs: dropping the output expires it.
    z, trace = record_trace(1)
    check(expired=False, complete=False)
    del z
    check(expired=True, complete=False)

    # One backward stage that does run: the trace completes and stays valid.
    z, trace = record_trace(1)
    check(expired=False, complete=False)
    z.sum().backward()
    check(expired=False, complete=True)
    del z
    check(expired=False, complete=True)
def test_multiuse_fn(self):
    """Apply one compiled function three times in a chain; a trace exists
    only after backward has run, and ``torch.jit.verify`` then passes."""
    x = Variable(torch.randn(2, 2), requires_grad=True)
    w = Variable(torch.randn(2, 2), requires_grad=True)

    @torch.jit.compile
    def cell(x, w):
        return x * w + 2

    first = cell(x, w)
    second = cell(first, w)
    out = cell(second, w)
    self.assertFalse(cell.has_trace_for(x, w))

    out.sum().backward()
    self.assertTrue(cell.has_trace_for(x, w))

    torch.jit.verify(cell, (x, w), devices=[])
def test_output_unflatten(self):
    """Check that outputs of traced functions retain the original structure and nesting"""
    x = Variable(torch.randn(2, 2), requires_grad=True)

    def fn(x):
        return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)

    # Eager reference, computed before the function is wrapped.
    expected_out = fn(x)
    fn = torch.jit.compile(fn)

    def recursive_sum(obj):
        # Variables are leaves; anything else is summed element-wise.
        if isinstance(obj, Variable):
            return obj.sum()
        return sum(recursive_sum(item) for item in obj)

    recursive_sum(fn(x)).backward()
    self.assertTrue(fn.has_trace_for(x))

    with self.assertCompiled(fn):
        self.assertEqual(fn(x), expected_out)
def test_input_flatten(self):
    """Check that inputs to traced functions are flattened"""
    def make_var():
        return Variable(torch.randn(1), requires_grad=True)

    # One plain Variable plus a nested pair of Variables.
    x = (make_var(), (make_var(), make_var()))

    def fn(x, t):
        y, z = t
        return x * y * z

    # Eager reference, computed before the function is wrapped.
    expected_out = fn(*x)
    fn = torch.jit.compile(fn)

    fn(*x).backward()
    self.assertTrue(fn.has_trace_for(*x))

    with self.assertCompiled(fn):
        self.assertEqual(fn(*x), expected_out)
def test_flags(self):
    """For every combination of ``requires_grad`` on the two inputs, a trace
    only exists after a gradient was computed, and the gradients agree
    across combinations."""
    x = Variable(torch.randn(2, 2))
    y = Variable(torch.randn(2, 2))

    @torch.jit.compile
    def fn(x, y):
        return (x * x + y * y + x * y).sum()

    # First gradient seen per input name; later ones are compared to it.
    grads = {}
    for rx, ry in product((True, False), repeat=2):
        x.requires_grad = rx
        y.requires_grad = ry

        self.assertFalse(fn.has_trace_for(x, y))
        out = fn(x, y)

        # A forward call alone does not produce a trace.
        self.assertFalse(fn.has_trace_for(x, y))
        for var, name, compute in ((x, 'x', rx), (y, 'y', ry)):
            if not compute:
                continue
            grad_v, = torch.autograd.grad(out, var, retain_graph=True)
            expected_grad = grads.setdefault(name, grad_v)
            self.assertEqual(grad_v, expected_grad)

        self.assertEqual(fn.has_trace_for(x, y), rx or ry)
def test_no_grad_fallback(self):
    """Check that Traceable falls back to num_backwards=0 if in no-backprop mode"""
    x = Variable(torch.randn(2, 2))
    y = Variable(torch.randn(2, 2), requires_grad=True)

    @torch.jit.compile
    def fn(x, y):
        return x * x + x * y

    # With grad enabled a forward call alone leaves no trace.
    out = fn(x, y)
    self.assertFalse(fn.has_trace_for(x, y))

    with torch.no_grad():
        # Under no_grad one forward call is enough to get a trace.
        out = fn(x, y)
        self.assertTrue(fn.has_trace_for(x, y))
        with self.assertCompiled(fn):
            out2 = fn(x, y)
        self.assertEqual(out, out2)
def test_backward_flag_checks(self):
    """Once traced, running backward with a grad input that requires grad
    must raise ``RuntimeError`` ('was compiled with')."""
    x = Variable(torch.randn(1), requires_grad=True)

    @torch.jit.compile(nderivs=2)
    def fn(x):
        return x * x

    # Record the trace (both backward stages have to run).
    grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
    self.assertFalse(fn.has_trace_for(x))
    grad_x.backward()
    self.assertTrue(fn.has_trace_for(x))

    # First-stage backward with a grad that requires grad.
    with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
        seed = Variable(torch.ones(1), requires_grad=True)
        fn(x).backward(seed)

    # Second-stage backward with a grad that requires grad.
    with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
        grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
        grad_x.backward(Variable(torch.ones(1), requires_grad=True))
# TODO: Test executing this
def test_python_ir(self):
    """Clone a traced graph node-by-node through the Python IR bindings.

    A function is traced, every input/node/output of its graph is copied
    into a fresh ``torch._C.Graph``, a ``prim::TensorTest`` node with a
    tensor attribute is appended, and the string form of the new graph is
    compared with the stored expectation.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    traced, _ = torch.jit.get_trace_graph(doit, (x, y))
    g = torch._C._jit_get_graph(traced)
    g2 = torch._C.Graph()

    # Maps each value of the source graph to its counterpart in the copy.
    g_to_g2 = {}
    for node in g.inputs():
        g_to_g2[node] = g2.addInput()
    for node in g.nodes():
        n_ = g2.createClone(node, lambda x: g_to_g2[x])
        g2.appendNode(n_)
        for o, no in zip(node.outputs(), n_.outputs()):
            g_to_g2[o] = no

    for node in g.outputs():
        g2.registerOutput(g_to_g2[node])

    t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
    # unittest assertions rather than bare ``assert``: assert statements are
    # removed under ``python -O``, which would silently disable these checks.
    self.assertTrue(t_node.attributeNames() == ["a"])
    g2.appendNode(t_node)
    self.assertTrue(torch.equal(torch.ones([2, 2]), t_node.t("a")))
    self.assertExpected(str(g2))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
def test_cpp(self):
    """Run the C++ JIT tests and compare their combined output with the
    stored expectation."""
    # rather than rebuild assertExpected in cpp,
    # just glob all the cpp outputs into one file for now
    cpp_output = torch._C._jit_run_cpp_tests()
    self.assertExpected(cpp_output)
def test_batchnorm(self):
    """Trace ``nn.BatchNorm2d(2)`` on an all-ones input and compare with the
    expected trace."""
    ones_input = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
    module = nn.BatchNorm2d(2)
    trace, _ = torch.jit.get_trace_graph(module, ones_input)
    self.assertExpectedTrace(trace)
def test_dropout(self):
    """Trace ``nn.Dropout(0.6)`` on an all-ones input and compare with the
    expected trace."""
    ones_input = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
    module = nn.Dropout(0.6)
    trace, _ = torch.jit.get_trace_graph(module, ones_input)
    self.assertExpectedTrace(trace)
def test_batchnorm_run_twice(self):
    """A compiled ``nn.BatchNorm2d`` subclass hits the compiled path on the
    second call and returns the same output."""
    @torch.jit.compile(nderivs=0)
    class MyBatchNorm2d(nn.BatchNorm2d):
        pass

    bn = MyBatchNorm2d(1)
    x = Variable(torch.randn(5, 1, 2, 1))

    first = bn(x)
    with self.assertCompiled(bn):
        second = bn(x)
    self.assertEqual(first, second)
def test_non_decorator_use_fails(self):
    """Instantiating a class wrapped by a plain ``torch.jit.compile(...)``
    call (not as a class decorator) must raise ``TypeError``."""
    MyLSTM = torch.jit.compile(nn.LSTM)
    with self.assertRaisesRegex(TypeError, "class decorator"):
        MyLSTM(2, 2)
def test_conv(self):
    """Trace a bias-free ``nn.Conv2d(16, 13, 3)`` on an all-ones input and
    compare with the expected trace."""
    ones_input = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
    module = nn.Conv2d(16, 13, 3, bias=False)
    trace, _ = torch.jit.get_trace_graph(module, ones_input)
    self.assertExpectedTrace(trace)
def test_reuse_function(self):
    """Feed the output of a compiled ``F.linear`` back into the same compiled
    function and compare with the eager two-step result."""
    @torch.jit.compile(nderivs=0)
    def clinear(*args):
        return F.linear(*args)

    def cast(tensor):
        # Identity hook: the place to change the tensor type under test.
        return tensor

    inp = Variable(cast(torch.randn(1, 1)))
    weights = Variable(cast(torch.randn(1, 1)))
    bias = Variable(cast(torch.randn(1, 1)))

    # linear AKA addmm without bias is of particular interest
    # because we allocate a zero-filled new variable when we execute,
    # and then *fill* it with the result

    intermediate = clinear(inp, weights)
    with self.assertCompiled(clinear):
        compiled_result = clinear(intermediate, weights)

    eager_result = F.linear(F.linear(inp, weights), weights)
    self.assertEqual(compiled_result, eager_result)
def test_unused_input(self):
    """A compiled function that ignores one of its inputs still hits the
    compiled path on the second forward/backward."""
    @torch.jit.compile(nderivs=1)
    def fn(a, b, c):
        return a + b

    def make_var():
        return Variable(torch.randn(2, 2), requires_grad=True)

    a, b, c = make_var(), make_var(), make_var()

    fn(a, b, c).sum().backward()
    with self.assertCompiled(fn):
        fn(a, b, c).sum().backward()
def test_repeated_input(self):
    """Trace with the same Variable passed twice; the compiled function is
    then hit both for the repeated and for distinct inputs."""
    @torch.jit.compile(nderivs=1)
    def fn(a, b):
        return a + b

    def make_var():
        return Variable(torch.randn(2, 2), requires_grad=True)

    a, b = make_var(), make_var()

    # Tracing run with a repeated argument.
    fn(a, a).sum().backward()

    with self.assertCompiled(fn):
        fn(a, a).sum().backward()
    with self.assertCompiled(fn):
        fn(a, b).sum().backward()

    self.assertExpected(str(fn.graph_for(a, a)))
def test_repeated_output(self):
    """A compiled function returning the same value twice works for forward
    and backward; its graph is compared with the stored expectation."""
    @torch.jit.compile(nderivs=1)
    def fn(a, b):
        total = a + b
        return total, total

    def make_var():
        return Variable(torch.randn(2, 2), requires_grad=True)

    a, b = make_var(), make_var()

    sum(fn(a, b)).sum().backward()
    with self.assertCompiled(fn):
        sum(fn(a, b)).sum().backward()

    self.assertExpected(str(fn.graph_for(a, b)))
def test_re_enter(self):
    """A compiled function may call another compiled function; both hit
    their compiled paths on the second run."""
    @torch.jit.compile(nderivs=1)
    def fn(a, b):
        return a + b

    @torch.jit.compile(nderivs=1)
    def fn2(a, b, c):
        return fn(a, b) + c

    def make_var():
        return Variable(torch.randn(2, 2), requires_grad=True)

    a, b, c = make_var(), make_var(), make_var()

    # Inner function on its own.
    fn(a, b).sum().backward()
    with self.assertCompiled(fn):
        fn(a, b).sum().backward()

    # Outer function, which re-enters the inner one.
    fn2(a, b, c).sum().backward()
    with self.assertCompiled(fn2):
        fn2(a, b, c).sum().backward()
def test_mini_wlm(self):
    """Exercise null-edge pruning in the tracer."""

    @torch.jit.compile
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.encoder = nn.Embedding(2, 2)

        def forward(self, input, hidden):
            emb = self.encoder(input)
            hidden = hidden.clone()  # simulate some RNN operation
            return emb, hidden

    model = MyModel()

    tokens = Variable(torch.LongTensor([[0, 1], [1, 0]]))
    hidden = Variable(torch.FloatTensor([0]))

    # Tracing run: backward only through the embedding output.
    z, _ = model(tokens, hidden)
    z.sum().backward()
    self.assertTrue(model.has_trace_for(tokens, hidden))

    with self.assertCompiled(model):
        z, _ = model(tokens, hidden)
    z.sum().backward()
def test_module_cast(self):
    """Compiled modules can be casted to other data types"""

    @torch.jit.compile(nderivs=0)
    class Adder(nn.Module):
        def __init__(self):
            super(Adder, self).__init__()
            self.y = nn.Parameter(torch.randn(2, 2))

        def forward(self, x):
            return x + self.y

    x = Variable(torch.randn(2, 2).float())
    # Wrap it in a sequential to make sure it works for submodules
    a = nn.Sequential(Adder()).float()

    def check_type(caster):
        # Cast the module, trace once with a matching input, then expect
        # the second call to hit the compiled submodule.
        caster(a)
        a(caster(x))
        with self.assertCompiled(a[0]):
            a(caster(x))

    cuda_available = torch.cuda.is_available()

    check_type(lambda x: x)
    check_type(lambda x: x.double())
    if cuda_available:
        check_type(lambda x: x.float().cuda())
        check_type(lambda x: x.double().cuda())

    # One hit per checked type.
    self.assertEqual(a[0].hits, 4 if cuda_available else 2)
# Tracer fails when it receives the same grad variable as multiple input to
# traced region. The problem is that it's not immediately obvious how to
# assign multiple inputs to this Variable. It might be possible to solve
# this using the view mechanism, but this requires some thought.
# In general, it should be supported, because the user has no control
# over this (and it's quite common, e.g. the sum call below will pass the same
# grad variable as both inputs to grad of fn).
@unittest.skip("Broken - repeated grads trigger an assertion failure.")
def test_repeated_grad(self):
    """Backward through the sum of two outputs of a compiled function
    (currently skipped)."""
    @torch.jit.compile
    def fn(x):
        return x * x, x + x

    x = Variable(torch.randn(5, 5), requires_grad=True)
    # This shouldn't raise!
    outputs = fn(x)
    sum(outputs).sum().backward()
def test_input_pruning(self):
    """Check that stage 1 will return only one value"""
    # One of the inputs doesn't require grad, so it should be pruned
    @torch.jit.compile
    def fn(x, y):
        return x * y, x + y

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5))

    product_out, sum_out = fn(x, y)
    (product_out * sum_out).sum().backward()

    with self.assertCompiled(fn):
        fn(x, y)

    self.assertExpected(str(fn.graph_for(x, y)))
def test_output_pruning(self):
    """Check that stage 1 will take one value as an argument"""
    # One of the outputs doesn't require grad, so it should be pruned
    @torch.jit.compile
    def fn(x, y):
        return x * y, y + y

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5))

    product_out, doubled_out = fn(x, y)
    (product_out * doubled_out).sum().backward()

    with self.assertCompiled(fn):
        fn(x, y)

    self.assertExpected(str(fn.graph_for(x, y)))
@unittest.skip("disabled: the test body was short-circuited by an unconditional return")
@skipIfNoTorchVision
def test_alexnet(self):
    """Trace torchvision's AlexNet on an all-ones batch and compare with the
    expected trace.

    The body used to start with a bare ``return``, so the test always
    reported success without running anything.  It is now skipped
    explicitly; remove the ``unittest.skip`` decorator to re-enable it.
    """
    x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
    trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
    self.assertExpectedTrace(trace)
    # NB: Purposely NOT testing protobuf export here
def test_debug_info(self):
    """Check that debug info doesn't crash and has some reasonable info"""

    @torch.jit.compile(nderivs=1)
    def fn(x, y):
        return x * y + x + y

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5), requires_grad=True)

    # Tracing run (forward + backward).
    out = fn(x, y)
    out.sum().backward()

    # 100 further forward calls, expected to show up as 100 hits.
    for _ in range(100):
        out = fn(x, y)

    info_str = fn.jit_debug_info()
    # assertIn reports both operands on failure, unlike assertTrue(a in b).
    self.assertIn("hits: 100", info_str)
    self.assertIn("stage 1", info_str)
# Inplace copies don't work with tracer yet.
# This is actually somewhat important to support correctly
# as all backwards functions of views are implemented
# as a zero filled tensor with a gradient fill on the
# viewed portion.
@unittest.expectedFailure
def test_inplace_copy(self):
    """Trace a function that ``copy_``-s its input into a fresh zero
    Variable.  Marked as an expected failure."""
    x = Variable(torch.randn(4, 4), requires_grad=True)

    def f(x):
        target = Variable(torch.zeros(x.size()))
        target.copy_(x)
        return target

    trace, _ = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    self.assertExpectedTrace(trace)
def test_index_trace(self):
    """Trace ``x[0]`` with one backward stage, run backward, and compare
    with the expected trace."""
    x = Variable(torch.randn(4, 4), requires_grad=True)

    def first_row(x):
        return x[0]

    trace, out = torch.jit.get_trace_graph(first_row, (x, ), nderivs=1)
    out.sum().backward()
    torch._C._jit_pass_lint(trace.graph())
    torch._C._jit_pass_dce(trace.graph())
    self.assertExpectedTrace(trace)
def test_saved_output(self):
    """Compile ``x.sigmoid()`` with one backward stage and compare the
    resulting graph with the stored expectation."""
    x = Variable(torch.randn(4, 4), requires_grad=True)

    @torch.jit.compile(nderivs=1)
    def fn(x):
        return x.sigmoid()

    out = fn(x)
    out.sum().backward()
    self.assertExpected(str(fn.graph_for(x)))
def test_shared_param(self):
    """A Parameter registered under two attribute names must appear once in
    the trace: the graph has exactly two inputs."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            shared = nn.Parameter(torch.randn(2, 2))
            # Same assignment order as a chained ``self.b = self.a = ...``.
            self.b = shared
            self.a = shared

        def forward(self, x):
            return x * self.a + self.b

    m = MyModule()
    sample = Variable(torch.randn(2, 2))
    trace, _ = torch.jit.get_trace_graph(m, (sample,), nderivs=0)

    graph_inputs = list(trace.graph().inputs())
    self.assertEqual(len(graph_inputs), 2)
    self.assertExpected(str(trace))
def test_nested_inplace(self):
    """Trace an ``F.threshold(..., inplace=True)`` call and check the trace text."""
    inp = Variable(torch.randn(2, 2))
    trace, _ = torch.jit.get_trace_graph(
        lambda x: F.threshold(x, 0, 0, inplace=True),
        (inp,),
        nderivs=0,
    )
    self.assertExpectedTrace(trace)
def checkGraphExecutor(self, func, reference_tensors, input_tensors=None,
                       optimize=True, drop=None, allow_unused=False):
    """Compare ``func`` with a ``torch._C.GraphExecutor`` built from it.

    Outputs are compared without gradients, then outputs and first-order
    gradients, then outputs, first-order and second-order gradients.

    Args:
        func: callable under test.
        reference_tensors: tensors that ``func`` and the executor are run on.
        input_tensors: tensors handed to the ``GraphExecutor`` constructor;
            defaults to ``reference_tensors``.
        optimize: forwarded to the ``GraphExecutor`` constructor.
        drop: when not ``None``, the last ``drop`` values are excluded from
            the scalar loss.
        allow_unused: forwarded to every ``torch.autograd.grad`` call.
    """
    def weighted_total(values):
        # drop allows us to remove some values from ever being used
        # to test unused outputs
        kept = values if drop is None else values[:-drop]
        # we don't want all the grad for all the outputs to be the same
        # so we multiply each by a constant
        return sum((idx + 1) * val.sum()
                   for idx, val in enumerate(kept) if val is not None)

    def fresh_leaves():
        return [Variable(t, requires_grad=True) for t in reference_tensors]

    def grad_of(loss, leaves, **extra):
        return torch.autograd.grad(loss, leaves, allow_unused=allow_unused, **extra)

    if input_tensors is None:
        input_tensors = reference_tensors

    plain_inputs = [Variable(t) for t in reference_tensors]
    leaves = fresh_leaves()
    ge = torch._C.GraphExecutor(func, [Variable(t) for t in input_tensors], optimize)

    # test no gradients case
    expected = func(*plain_inputs)
    actual = ge(*plain_inputs)
    self.assertEqual(expected, actual)

    # test single grad case (reference and executor share the same leaves)
    expected = func(*leaves)
    expected_grads = grad_of(weighted_total(expected), leaves)
    actual = ge(*leaves)
    actual_grads = grad_of(weighted_total(actual), leaves)
    self.assertEqual(expected, actual)
    self.assertEqual(expected_grads, actual_grads)

    # test the grad grad case
    expected = func(*leaves)
    loss = weighted_total(expected)
    expected_grads = grad_of(loss, leaves, create_graph=True)
    loss2 = (weighted_total(expected_grads) * loss)
    expected_grads2 = grad_of(loss2, leaves)

    # the executor side of the grad grad case gets its own set of leaves
    leaves = fresh_leaves()
    actual = ge(*leaves)
    loss_ge = weighted_total(actual)
    actual_grads = grad_of(loss_ge, leaves, create_graph=True)
    loss2_ge = (weighted_total(actual_grads) * loss_ge)
    actual_grads2 = grad_of(loss2_ge, leaves)

    self.assertEqual(expected, actual)
    self.assertEqual(expected_grads, actual_grads)
    self.assertEqual(expected_grads2, actual_grads2)
def run_ge_tests(self, optimize, use_cuda):
    """Run a fixed set of ``checkGraphExecutor`` scenarios.

    Args:
        optimize: forwarded to every ``checkGraphExecutor`` call.
        use_cuda: when true, the random test tensors are moved with ``.cuda()``.
    """
    def rand(*args):
        t = torch.rand(*args).float()
        return t.cuda() if use_cuda else t

    def foo(a):
        t = a * a
        return t * t, 4 * t

    self.checkGraphExecutor(
        lambda a, b: a * b + b,
        [rand(1), rand(1)],
        [rand(2, 3), rand(2, 3)],
        optimize=optimize,
    )

    # trivial identity
    self.checkGraphExecutor(lambda a, b: (b, a), [rand(1), rand(1)], optimize=optimize)

    self.checkGraphExecutor(foo, [rand(1)], optimize=optimize)

    # unused input
    self.checkGraphExecutor(
        lambda a, b: a * a,
        [rand(1), rand(1)],
        optimize=optimize,
        allow_unused=True,
    )

    # test outputs that do not get used in grad
    self.checkGraphExecutor(foo, [rand(1)], drop=1, optimize=optimize)

    # test autograd fallback
    self.checkGraphExecutor(
        lambda a, b: a * b / (a - 2 * b) + b,
        [rand(1), rand(1)],
        optimize=optimize,
    )
def test_ge_unoptimized(self):
    """Run the graph-executor scenarios with optimization off, without CUDA."""
    self.run_ge_tests(optimize=False, use_cuda=False)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
def test_ge_optimized(self):
    """Run the graph-executor scenarios with optimization on, without CUDA."""
    self.run_ge_tests(optimize=True, use_cuda=False)
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_ge_cuda(self):
    """Run the graph-executor scenarios with optimization on and CUDA tensors."""
    self.run_ge_tests(optimize=True, use_cuda=True)
# more manual test of graph executor that can be used as a scratchpad
def test_ge(self):
    """Compare first- and second-order gradients of a GraphExecutor with eager mode."""
    def foo(a, b):
        return a * b / (a - b) + b

    def second_order(d_a, d_b):
        # Gradient of (d_a * d_b + d_b * d_b) with respect to d_a and d_b.
        return torch.autograd.grad((d_a * d_b + d_b * d_b), [d_a, d_b])

    example = (Variable(torch.rand(1)), Variable(torch.rand(1)))
    ge = torch._C.GraphExecutor(foo, example)

    a = Variable(torch.rand(1), requires_grad=True)
    b = Variable(torch.rand(1), requires_grad=True)

    # Graph executor path.
    r, = ge(a, b)
    da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
    g2result = second_order(da, db)

    # Eager reference path.
    r = foo(a, b)
    da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
    self.assertEqual(da, da2)
    self.assertEqual(db, db2)
    g2result2 = second_order(da2, db2)
    self.assertEqual(g2result, g2result2)
def test_trace_annotation(self):
    """``torch.jit.trace`` used as a decorator must yield a callable function.

    The function is traced with a size-1 example and then called with a
    size-2 input; the result must equal the eager ``s + s + s``.
    """
    @torch.jit.trace(Variable(torch.rand(1)))
    def foo(a):
        return a + a + a

    s = Variable(torch.rand(2))
    # Same check as test_script_annotation: compare against the eager result.
    self.assertEqual(s + s + s, foo(s))
# Traces an nn.Module with torch.jit.trace and then exercises the traced
# wrapper: attribute access, attribute assignment and deletion, calling a
# submodule, device/dtype casts, and state_dict()/load_state_dict().
# NOTE(review): indentation is not visible in this file, so the exact extent
# of each `with` block below should be confirmed against the upstream source.
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
def test_traced_module(self):
class Model(nn.Module):
def __init__(self, num_features, num_layers):
super(Model, self).__init__()
self.num_layers = num_layers
# One [Linear, Sigmoid] pair per layer, flattened into a single Sequential.
layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
for _ in range(num_layers)]
self.submodule = nn.Sequential(*chain(*layers))
def forward(self, x):
for i in range(self.num_layers):
x = self.submodule[i](x) + x
return x
model = Model(5, 3)
x = torch.randn(2, 5)
traced_model = torch.jit.trace(x)(model)
# We're missing some attributes these modules had initially. Make sure we can
# still get the __repr__()
model.__repr__()
# XXX: indexing sequentials is broken
linear_submodule = next(iter(traced_model.submodule._modules.values()))
# All attributes that aren't parameters should raise
with self.assertRaises(AttributeError):
linear_submodule.in_features
linear_submodule.weight
with self.assertRaises(RuntimeError):
traced_model.asdf = 4
linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
with self.assertRaises(RuntimeError):
del linear_submodule.weight
# Submodules can't be called
with self.assertRaises(RuntimeError):
linear_submodule(x)
# Type casts
linear_submodule.cuda()
traced_model.float().cuda()
cuda_out = traced_model(x.float().cuda())
traced_model.cpu()
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
traced_model.double()
# state_dict + load_state_dict
# Snapshot the current state, load an all-ones state, then restore the snapshot.
state = {k: v.clone() for k, v in traced_model.state_dict().items()}
new_state = {k: v.clone().fill_(1) for k, v in state.items()}
out = traced_model(x)
traced_model.load_state_dict(new_state)
out_ones = traced_model(x)
traced_model.load_state_dict(state)
out_state = traced_model(x)
# Restored state reproduces the original output; the all-ones state does not.
self.assertEqual(out, out_state)
self.assertNotEqual(out, out_ones)
def test_shape_prop_mismatch_output(self):
# The script below rebinds `b` from a `slice` result to a `topk` result.
# A RuntimeError is expected from the code guarded by assertRaises.
# NOTE(review): indentation is not visible here, so which of the following
# statements sit inside the `with` block should be confirmed upstream.
with self.assertRaises(RuntimeError):
cu = torch.jit.CompilationUnit('''
def test_shape_prop_mismatch_output(a):
b = slice(a, dim=0, end=-2, start=2, step=1)
b = topk(a, dim=0, k=2, largest=True, sorted=True)
return b
''')
inputs = [torch.zeros(10)]
outputs = [torch.zeros(2), torch.from_numpy(np.array([1, 5])).long()]
real_outs = cu.test_shape_prop_mismatch_output(*inputs)
self.assertEqual(real_outs, outputs)
def test_view_shape_prop(self):
# Compiles a script that calls view(a, size=[-1]) and checks that a
# 10x10 zero tensor comes back equal to a 100-element zero tensor.
cu = torch.jit.CompilationUnit('''
def test_view_shape_prop(a):
return view(a, size=[-1])
''')
inputs = [torch.zeros(10, 10)]
outputs = torch.zeros(100)
real_outs = cu.test_view_shape_prop(*inputs)
self.assertEqual(real_outs, outputs)
def test_integral_shape_inference(self):
# Compiles `a / a` and runs it on a 10x10 LongTensor of ones; the result
# is compared with torch.ones(10, 10).
cu = torch.jit.CompilationUnit('''
def test_integral_shape_inference(a):
return a / a
''')
inputs = [torch.ones(10, 10).type(torch.LongTensor)]
outputs = torch.ones(10, 10)
self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
def test_shape_analysis_broadcast(self):
    """Run the shape-analysis pass on a scripted ``a + b`` and check the graph text.

    The example inputs have shapes ``(3, 1, 5)`` and ``(4, 1, 8, 5)``.
    """
    def broadcast(a, b):
        return a + b

    example_inputs = (
        torch.randn(3, 1, 5, requires_grad=True),
        torch.randn(4, 1, 8, 5, requires_grad=True),
    )
    graph = torch.jit._script_graph(broadcast)
    torch._C._jit_pass_shape_analysis(graph, example_inputs, False)
    self.assertExpected(str(graph))
def test_fuser_multiple_blocks(self):
# The script concatenates `meme` onto three accumulators inside a
# `while i < 20` loop and returns the three accumulators.
cu = torch.jit.CompilationUnit('''
def test_fuser_multiple_blocks(this, that, theother, meme):
i = 0
while i < 20:
this = cat(this, meme, dim=0)
that = cat(that, meme, dim=0)
theother = cat(theother, meme, dim=0)
i = i + 1
return this, that, theother
''')
# Three empty (0, 10, 10) accumulators plus one (1, 10, 10) tensor to append.
inputs = [torch.ones(0, 10, 10)] * 3
inputs += [torch.ones(1, 10, 10)]
# Expected result: three (20, 10, 10) tensors of ones.
outputs = [torch.ones(20, 10, 10)] * 3
self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
class TestScript(TestCase):
# Context manager that redirects file descriptor 1 into a pipe while the
# managed body runs. It yields a one-element list; after the body finishes,
# element 0 is replaced with everything that was read back from the pipe.
# When WINDOWS is true it yields [''] and captures nothing.
@contextmanager
def capture_stdout(self):
# No idea how to capture stdout from C++ on Windows
if WINDOWS:
yield ['']
return
import os
import fcntl
import errno
sys.stdout.flush()
# Keep a duplicate of the real stdout so it can be restored afterwards.
stdout_fd = os.dup(1)
r, w = os.pipe()
try:
# Override stdout with r - dup is guaranteed to return the lowest free fd
os.close(1)
os.dup(w)
captured_stdout = ['']
yield captured_stdout
sys.stdout.flush() # Make sure that Python hasn't buffered anything
# Do the ugly dance to read all the data that was written into the pipe
# The read end is made non-blocking so an empty pipe raises EAGAIN,
# which ends the loop below; any other OSError is re-raised.
fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
total_stdout = ''
while True:
try:
total_stdout += os.read(r, 1000).decode('ascii')
except OSError as e:
if e.errno != errno.EAGAIN:
raise
break
captured_stdout[0] = total_stdout
finally:
# Revert the change, and clean up all fds
os.close(1)
os.dup(stdout_fd)
os.close(stdout_fd)
os.close(r)
os.close(w)
def checkScript(self, script, inputs, optimize, outputs=None, name='func', capture_output=False):
    """Compile ``script`` and compare its result on ``inputs`` with ``outputs``.

    Args:
        script: either script source text, compiled through
            ``torch.jit.CompilationUnit``, or a Python function. A function
            is first run eagerly to produce the reference outputs, then its
            dedented source is checked through the string path, and finally
            it is compiled with ``torch.jit.script``.
        inputs: positional arguments for every call.
        optimize: forwarded to ``torch.jit.CompilationUnit``.
        outputs: expected result; overwritten when ``script`` is a function.
        name: name of the function to fetch from the compilation unit.
        capture_output: when true, calls run under ``capture_stdout`` and
            (unless ``WINDOWS``) the captured text of the compiled run is
            checked with ``assertExpected(..., subname='stdout')``.
    """
    def invoke(fn):
        # Returns (result, captured); captured is None when not capturing.
        if not capture_output:
            return fn(*inputs), None
        with self.capture_stdout() as captured:
            result = fn(*inputs)
        return result, captured

    if isinstance(script, str):
        ge = getattr(torch.jit.CompilationUnit(script, optimize), name)
    else:
        outputs, _ = invoke(script)
        # Check the string frontend first
        source = textwrap.dedent(inspect.getsource(script))
        self.checkScript(source, inputs, optimize, outputs, script.__name__, capture_output)
        # Continue checking the Python frontend
        ge = torch.jit.script(script)

    outputs_ge, captured = invoke(ge)
    if capture_output and not WINDOWS:
        self.assertExpected(captured[0], subname='stdout')
    self.assertEqual(outputs, outputs_ge)
def test_script_cu(self):
# Compiles an identity function (`b = a; return b`) from source text and
# checks that calling it returns a value equal to its argument.
cu = torch.jit.CompilationUnit('''
def foo(a):
b = a
return b
''')
a = Variable(torch.rand(1))
self.assertEqual(a, cu.foo(a))
def test_script_annotation(self):
    """``torch.jit.script`` used as a decorator must match the eager ``s + s + s``."""
    @torch.jit.script
    def foo(a):
        return a + a + a

    sample = Variable(torch.rand(2))
    expected = sample + sample + sample
    self.assertEqual(expected, foo(sample))
def test_add(self):
    """Script a function using ``+`` and ``+=`` and compare it with eager mode."""
    def func(a, b):
        c = a + b
        c += a
        return c

    lhs = torch.rand(1, requires_grad=True)
    rhs = torch.rand(1, requires_grad=True)
    self.checkScript(func, (lhs, rhs), optimize=True)
def test_mul(self):
    """Script ``a * b`` and compare it with eager mode."""
    def func(a, b):
        return a * b

    lhs = torch.rand(1, requires_grad=True)
    rhs = torch.rand(1, requires_grad=True)
    self.checkScript(func, (lhs, rhs), optimize=True)
def test_triple(self):
    """Script multiplication by the float literal ``3.`` and compare with eager mode."""
    def func(x):
        return 3. * x

    sample = torch.rand(1, dtype=torch.float, requires_grad=True)
    self.checkScript(func, [sample], optimize=True)
def test_slice(self):
    """Script the slice ``x[:5]`` and compare it with eager mode."""
    def func(x):
        return x[:5]

    sample = torch.rand(10, dtype=torch.float, requires_grad=True)
    self.checkScript(func, [sample], optimize=True)
def test_gather(self):
    """Script the integer index ``x[0]`` and compare it with eager mode."""
    def func(x):
        return x[0]

    sample = torch.rand(10, dtype=torch.float, requires_grad=True)
    self.checkScript(func, [sample], optimize=True)
def test_keyword(self):
    """A scripted ``torch.sum`` call with keyword arguments must match eager mode."""
    @torch.jit.script
    def func(x):
        return torch.sum(x, dim=0, keepdim=True)

    sample = torch.rand(10, dtype=torch.float, requires_grad=True)
    scripted = func(sample)
    eager = torch.sum(sample, dim=0, keepdim=True)
    self.assertEqual(scripted, eager)
def test_func_call(self):
# The script defines `add` and `mul` and a `func` that calls both;
# `func` is checked against the eager value alpha * x + beta * y.
script = '''
def add(a, b):
return a + b
def mul(a, x):
return a * x
def func(alpha, beta, x, y):
return add(mul(alpha, x), mul(beta, y))
'''
# alpha and beta have one element; x and y have three.
alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
beta = torch.rand(1, dtype=torch.float, requires_grad=True)
x = torch.rand(3, dtype=torch.float, requires_grad=True)
y = torch.rand(3, dtype=torch.float, requires_grad=True)
outputs = alpha * x + beta * y
# NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_cast(self):
    """Script ``int(x)`` on a float tensor and compare with an ``IntTensor``.

    Currently skipped; see the skip reason on the decorator.
    """
    script = '''
    def to_int(x):
        return int(x)
    '''
    x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
    out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
    # checkScript selects the compiled function through its `name` parameter.
    self.checkScript(script, [x], optimize=True, outputs=[out], name='to_int')
def test_python_frontend(self):
    """Convert a Python function to a JIT AST and compare its string form.

    ``fn`` is never called; only its source is parsed.
    """
    def fn(x, y, z):
        q = x + y - z.sigmoid()
        print(q)
        w = -z
        if not x and not y and z:
            m = x if not z else y
        while x < y > z:
            q = x
        return x

    jit_ast = torch.jit.frontend.get_jit_ast(fn)
    self.assertExpected(str(jit_ast))
def _make_scalar_vars(self, arr, dtype):
return [torch.tensor(val, dtype=dtype) for val in arr]
def test_while(self):
    """Script a ``while`` loop that updates two values and compare with eager mode."""
    def func(a, b, max):
        while a < max:
            a = a + 1
            b = b + 1
        c = a + b
        return c

    scalar_inputs = self._make_scalar_vars([1, 1, 10], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_fibb(self):
    """Script nested ``while`` loops with several carried values and compare with eager mode."""
    def func(lim):
        first = 1
        second = 1
        i = 1
        somenum = 5
        dontmutateme = 3
        third = 0
        while i < lim:
            third = first + second
            first = second
            second = third
            j = 0
            while j < 10:
                somenum = somenum * 2
                j = j + 1
            i = i + j
            i = i + dontmutateme

        st = second + third
        fs = first + second
        return third, st, fs

    scalar_inputs = self._make_scalar_vars([10], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_if(self):
    """Script an ``if``/``else`` that assigns different names per branch."""
    def func(a, b):
        d = 3
        if a > 10:
            a = 3 + d
        else:
            b = 3 + d
            d = 4
        c = a + b
        return c

    scalar_inputs = self._make_scalar_vars([1, -1], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_if_noelse(self):
    """Script an ``if`` without an ``else`` branch and compare with eager mode."""
    def func(a, b):
        if a > 10:
            a = 3 + b
        c = a + b
        return c

    scalar_inputs = self._make_scalar_vars([-1, 1], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_while_nonexistent_value(self):
# The loop body of the script reads `x`, which is never defined; compiling
# must raise RuntimeError matching "undefined value x".
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
torch.jit.CompilationUnit('''
def test_while(a, b):
while a < 10:
a = a + x
b = b + 1
return a + b
''')
def test_while_nonexistent_cond_value(self):
# The loop condition of the script reads `x`, which is never defined;
# compiling must raise RuntimeError matching "undefined value x".
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
torch.jit.CompilationUnit('''
def test_while(a, b):
while a < x:
a = a + 1
b = b + 1
return a + b
''')
def test_while_write_outer_then_read(self):
    """Script a loop that writes ``a`` and then reads it to compute ``b``."""
    def func(a, b):
        while a < 10:
            a = a + 1
            b = a + 1
        return a + b

    scalar_inputs = self._make_scalar_vars([42, 1337], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_while_nest_if(self):
    """Script an ``if``/``else`` nested inside a ``while`` loop."""
    def func(a, b):
        c = 0
        while a < 10:
            a = a + 1
            b = b + 1
            if a > b:
                c = -a
            else:
                c = -b
        return c + 1

    scalar_inputs = self._make_scalar_vars([-1234, 4321], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_if_nest_while(self):
    """Script a ``while`` loop nested inside an ``if``."""
    def func(a, b):
        c = 0
        if a > b:
            while a > b:
                b = b + 1
                c = -b
        return c

    scalar_inputs = self._make_scalar_vars([4321, 1234], torch.int)
    self.checkScript(func, scalar_inputs, optimize=True)
def test_script_for_in_range(self):
# The script sums i over range(100); the expected result is 4950.
script = '''
def test_for_in_range():
c = 0
for i in range(100):
c += i
return c
'''
self.checkScript(script, [], outputs=[4950], optimize=True, name='test_for_in_range')
def test_script_for_in_range_dynamic(self):
# The inner loop's bound is the outer loop variable (range(i)); the
# expected accumulated result is 161700.
script = '''
def test_script_for_in_range_dynamic():
c = 0
for i in range(100):
acc = 0
for j in range(i):
acc += j
c += acc
return c
'''
self.checkScript(script, [], outputs=[161700], optimize=True, name='test_script_for_in_range_dynamic')
def test_script_for_in_range_ast(self):
    """Nested ``for ... in range`` loops compiled from a decorated function.

    The function is called with an ``int64`` zero and must return 161700.
    """
    @torch.jit.script
    def test_script_for_in_range_ast(zero):
        c = zero
        for i in range(100):
            acc = zero
            for j in range(i):
                acc += j
            c += acc
        return c

    zero, = self._make_scalar_vars([0], torch.int64)
    self.assertEqual(test_script_for_in_range_ast(zero), 161700)
def test_script_bool_constant(self):
    """A script returning the literal ``True`` must produce a value equal to 1."""
    script = '''
    def test_script_bool_constant():
        a = True
        return a
    '''
    outputs = [1]
    # Keyword arguments keep `optimize` and `outputs` from being confused:
    # checkScript's positional order is (script, inputs, optimize, outputs, name).
    self.checkScript(script, [], optimize=True, outputs=outputs[0],
                     name='test_script_bool_constant')
def test_ternary(self):
    """Script a conditional expression and exercise both of its branches."""
    def func(a, b):
        c = 3
        c = a + b if a > 3 else b
        return c

    # First input set makes `a > 3` true, the second makes it false.
    input_sets = [
        self._make_scalar_vars([5, 2], torch.int),
        self._make_scalar_vars([1, 0], torch.int),
    ]
    for scalar_inputs in input_sets:
        self.checkScript(func, scalar_inputs, optimize=True)
def test_print(self):
    """Script a function containing ``print`` and check its captured stdout."""
    def func(x, y):
        q = (x + y).sigmoid()
        print(q)
        w = -q
        return w * w

    first = torch.arange(4, requires_grad=True)
    second = torch.arange(0, 8, 2, requires_grad=True)
    self.checkScript(func, [first, second], optimize=True, capture_output=True)
def test_multiple_assignment(self):
    """Unpack the two return values of a Python function inside a scripted function."""
    def outer_func(x):
        return x * 2, x + 2

    @torch.jit.script
    def func(x):
        y, z = outer_func(x)
        return y + z

    sample = torch.arange(4)
    expected = sample * 2 + sample + 2
    self.assertEqual(func(sample), expected)
def test_literals(self):
    """Script a ``view`` call whose ``size`` argument is a list literal."""
    def func(a):
        return a.view(size=[1, 2, 3])

    sample = torch.randn(6)
    self.checkScript(func, [sample], optimize=True)
def test_return(self):
    """Script functions with no return, a bare return, one value and three values."""
    def no_return(a):
        a + 1

    def void_return(a):
        return

    def one_return(a):
        return a + 1.

    def multiple_returns(a):
        return a * 1., a * 2., a * 3.

    sample = torch.randn(1, dtype=torch.float)
    for candidate in (no_return, void_return, one_return, multiple_returns):
        self.checkScript(candidate, [sample], optimize=True)
def test_error(self):
    """Invalid calls to scripted functions must raise "failed in interpreter".

    Covers ``t()`` on a 10-element vector and division of a 10-element
    tensor by a 9-element one.
    """
    @torch.jit.script
    def foo(a):
        return a.t()

    vector = Variable(torch.rand(10))
    # XXX: this should stay quiet in shape propagation and only fail in the interpreter
    with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
        foo(vector)

    @torch.jit.script
    def bar(c, b):
        return c / b

    with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
        bar(Variable(torch.rand(10), requires_grad=True),
            Variable(torch.rand(9), requires_grad=True))
def test_binop_unsupported_error(self):
    """Scripting a function that uses ``**`` must raise ``NotSupportedError``."""
    expected_message = "unsupported binary operator:"
    with self.assertRaisesRegex(NotSupportedError, expected_message):
        @torch.jit.script
        def binop(x, y):
            # Replace this with another unsupported op when/if it gets supported
            return x ** y
def test_python_call(self):
# `pyfunc` is a plain Python function that the script below calls by name,
# alongside `other_func`, which is defined inside the script itself.
def pyfunc(a):
return a * 3.0
cu = torch.jit.CompilationUnit('''
def other_func(a):
return a + a
def test_call_python(a):
b = pyfunc(a)
b = other_func(b)
i = 0
step = 1
while i < 10:
b = pyfunc(b)
if b > 3.0:
b = pyfunc(b)
i = 11
return b
''')
# Input is the float scalar 1; the expected result is the float scalar 54.
inputs = self._make_scalar_vars([1], torch.float)
outputs = self._make_scalar_vars([54], torch.float)
self.assertEqual(cu.test_call_python(*inputs), outputs[0])
def test_python_call_failure(self):
# The script calls `pyfunc2`, which is not defined anywhere in this test;
# a RuntimeError matching "undefined value pyfunc2" is expected.
# NOTE(review): indentation is not visible here. If the statements after the
# CompilationUnit call are inside the `with` block and compilation raises,
# they never run - confirm against the upstream source.
with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
def pyfunc(a):
return a * 3.0
cu = torch.jit.CompilationUnit('''
def other_func(a):
return a + a
def test_call_python(a):
b = pyfunc(a)
b = other_func(b)
i = 0
step = 1
while i < 10:
b = pyfunc2(b)
if b > 3.0:
b = pyfunc(b)
i = 11
return b
''')
inputs = self._make_scalar_vars([1], torch.float)
outputs = self._make_scalar_vars([54], torch.float)
self.assertEqual(cu.test_call_python(*inputs), outputs)
def test_python_call_annotation(self):
    """A decorated script function may call a Python function from the enclosing test."""
    def pyfunc(a):
        return a * 3.0

    @torch.jit.script
    def foo(a):
        return pyfunc(a) + pyfunc(a)

    one, = self._make_scalar_vars([1], torch.float)
    six, = self._make_scalar_vars([6], torch.float)
    self.assertEqual(foo(one), six)
def test_python_call_annoytation_failure(self):
    """Scripting a function that calls the undefined ``pyfunc2`` must raise.

    The expected error is a ``RuntimeError`` matching
    ``undefined value pyfunc2``.
    """
    with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
        def pyfunc(a):
            return a * 3.0

        @torch.jit.script
        def foo(a):
            return pyfunc2(a) + pyfunc(a)

        one, = self._make_scalar_vars([1], torch.float)
        six, = self._make_scalar_vars([6], torch.float)
        self.assertEqual(foo(one), six)
def test_desugar_module(self):
    """Script calls spelled as ``torch.abs``, ``torch.nn.functional.prelu`` and ``F.prelu``."""
    import torch.nn.functional as F

    def fn(x, slope):
        a = torch.abs(x)
        b = torch.nn.functional.prelu(x, slope)
        c = F.prelu(x, slope)
        return a, b, c

    values = torch.arange(-3, 4)
    negative_slope = torch.tensor([0.5])
    self.checkScript(fn, [values, negative_slope], optimize=True)
def test_script_module(self):
# Builds a ScriptModule (M2) that combines a script submodule (M1), a plain
# nn.Module submodule (PModule), parameters, a method defined from a string
# and several script methods, then compares forward() with an eager
# reference computed from the same parameters.
class M1(torch.jit.ScriptModule):
def __init__(self):
super(M1, self).__init__(False)
self.weight = nn.Parameter(torch.randn(2))
@torch.jit.script_method
def forward(self, thing):
return self.weight + thing
class PModule(nn.Module):
def __init__(self):
super(PModule, self).__init__()
self.a = nn.Parameter(torch.randn(2, 3))
def forward(self, a):
return self.a.mm(a)
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(False)
# test submodule
self.sub = M1()
self.sub2 = PModule()
# test parameters
self.weight = nn.Parameter(torch.randn(2, 3))
self.bias = nn.Parameter(torch.randn(2))
# test defining a method from a string
self.define("""
def hi(self, a):
return self.weight.mm(a)
""")
# test script methods
@torch.jit.script_method
def doit(self, input):
# test use of parameter
return self.weight.mm(input)
@torch.jit.script_method
def doit2(self, input):
return self.weight.mm(input)
@torch.jit.script_method
def forward(self, input):
a = self.doit(input)
b = self.doit2(input)
c = self.hi(input)
d = self.sub2(input)
return a + b + self.bias + self.sub(a) + c + d
m2 = M2()
input = torch.randn(3, 2)
# Eager reference: a, b and c mirror doit, doit2 and hi; d mirrors sub2.
a = m2.weight.mm(input)
b = m2.weight.mm(input)
c = m2.weight.mm(input)
d = m2.sub2.a.mm(input)
ref = a + b + m2.bias + m2.sub.weight + a + c + d
self.assertEqual(ref, m2.forward(input))
# Replace or zero every parameter; forward() must then return all zeros.
m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
def test_script_module_call_noscript(self):
    """A script method may call a plain Python method of the same module.

    Changing the Python attribute ``value`` afterwards must be reflected in
    later calls.
    """
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.value = 1

        def foo(self):
            return torch.ones(2, 2) + self.value

        @torch.jit.script_method
        def forward(self, input):
            return input + self.foo()

    module = M()
    sample = torch.randn(2, 2)
    self.assertEqual(module(sample), sample + torch.ones(2, 2) + 1)

    # check that we can change python attributes
    # and that those changes are picked up in script methods
    module.value = 2
    self.assertEqual(module(sample), sample + torch.ones(2, 2) + 2)
def test_script_module_nochange_submodule(self):
    """A submodule used from a script method cannot be re-assigned afterwards."""
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.sub = nn.Linear(5, 5)

        @torch.jit.script_method
        def forward(self, input):
            return self.sub(input)

    module = M()
    sample = torch.randn(1, 5, 5)
    scripted_out = module(sample)
    self.assertEqual(scripted_out, module.sub(sample))

    with self.assertRaisesRegex(RuntimeError, "cannot re-assign"):
        module.sub = nn.Linear(5, 5)
def test_script_inline_trace_multiple_args(self):
    """Call a two-argument traced module from inside a script method."""
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)

        def forward(self, input, input2):
            return input + input2

    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(False)
            self.m = torch.jit.trace(torch.zeros(4, 3), torch.zeros(4, 3))(M())

        @torch.jit.script_method
        def forward(self, inp):
            return self.m(inp, inp)

    wrapper = M2()
    # Only checks that the call completes; there is no assertion on the result.
    wrapper(torch.zeros(4, 3))
def test_script_module_const(self):
    """Attributes listed in ``__constants__`` are readable from a script method."""
    class M(torch.jit.ScriptModule):

        __constants__ = ['b', 'i', 'c']

        def __init__(self):
            super(M, self).__init__(False)
            self.b = False
            self.i = 1
            self.c = 3.5

        @torch.jit.script_method
        def forward(self):
            return self.b, self.i, self.c

    module = M()
    flag_out, int_out, float_out = module()
    self.assertEqual(flag_out, 0)
    self.assertEqual(int_out, 1)
    self.assertEqual(float_out, 3.5)
def test_script_module_fail_const(self):
    """Reading an attribute that is not in ``__constants__`` must fail.

    Constructing the module is expected to raise a ``RuntimeError``
    matching ``is not usable in a script method``.
    """
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.b = False

        @torch.jit.script_method
        def forward(self):
            return self.b

    expected_message = "is not usable in a script method"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        M()
def test_script_module_valid_consts(self):
    """Declare a ScriptModule whose ``__init__`` assigns several kinds of constants.

    NOTE(review): ``Foo`` is only defined here, never instantiated, so the
    assignments and assertions inside ``Foo.__init__`` do not execute.
    Inside ``__init__`` the name ``self`` is the module, not this test
    case - confirm the intended behaviour before enabling it.
    """
    class Foo(torch.jit.ScriptModule):

        __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

        def __init__(self):
            super(Foo, self).__init__(False)
            self.a = 1
            self.b = 1.2
            self.c = False
            self.d = [nn.Linear(3, 4)]
            self.e = lambda x: x
            self.f = [3, 4, 5]
            self.assertTrue(type(self.f) is tuple)
            self.g = [3, (3, 4), 5]
            with self.assertRaisesRegex(TypeError, "is not a valid constant"):
                self.h = type(1)
            with self.assertRaisesRegex(TypeError, "is not a valid constant"):
                self.i = (3, 4, {})
def test_script_module_for(self):
    """Iterate over a constant list attribute inside a script method.

    Summing ``[1, 2, 3, 4]`` must return 10.
    """
    class M(torch.jit.ScriptModule):

        __constants__ = ['b']

        def __init__(self):
            super(M, self).__init__(False)
            self.b = [1, 2, 3, 4]

        @torch.jit.script_method
        def forward(self):
            sum = 0
            for i in self.b:
                sum += i
            return sum

    module = M()
    self.assertEqual(module(), 10)
def test_script_module_for2(self):
    """Iterate over a constant ``nn.ModuleList`` inside a script method.

    The scripted result must equal applying the ten submodules one after
    another in eager mode.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):

        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = nn.ModuleList([Sub() for i in range(10)])

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                v = m(v)
            return v

    start = torch.Tensor(2)
    module = M()
    scripted_out = module(start)

    # Eager reference: chain the submodules by hand.
    running = start
    for sub in module.mods:
        running = sub(running)
    self.assertEqual(scripted_out, running)
def test_script_module_const_submodule_fail(self):
    """Iterating a list of modules that is not in ``__constants__`` must fail.

    Constructing the module is expected to raise a ``RuntimeError``
    matching ``did you forget to add it __constants__``.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.mods = [Sub() for _ in range(10)]

        @torch.jit.script_method
        def forward(self):
            for _ in self.mods:
                print(1)
            return 4

    expected_message = "did you forget to add it __constants__"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        M()
def test_script_module_not_tuple(self):
    """Iterating over an integer constant inside a script method must fail.

    Constructing the module is expected to raise a ``RuntimeError``
    matching ``cannot be used as a tuple``.
    """
    class M(torch.jit.ScriptModule):

        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = 1

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                print(m)
            return v

    expected_message = "cannot be used as a tuple"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        M()
class StarTestSumStarred(torch.nn.Module):
    """Module whose ``forward`` adds all positional inputs into the first one."""

    def __init__(self):
        super(TestScript.StarTestSumStarred, self).__init__()

    def forward(self, *inputs):
        # `+=` accumulates into inputs[0] itself, exactly as written upstream.
        total = inputs[0]
        for extra in inputs[1:]:
            total += extra
        return total
class StarTestReturnThree(torch.nn.Module):
    """Module whose ``forward`` returns its argument three times as a tuple."""

    def __init__(self):
        super(TestScript.StarTestReturnThree, self).__init__()

    def forward(self, rep):
        return (rep,) * 3
def test_script_star_expr(self):
    """Star-unpack a traced module's tuple result into another traced module's call."""
    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            self.m = torch.jit.trace(
                torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))(TestScript.StarTestSumStarred())
            self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestReturnThree())

        @torch.jit.script_method
        def forward(self, rep):
            tup = self.g(rep)
            return self.m(*tup)

    module = M2()
    expected = 3 * torch.zeros(4, 3)
    self.assertEqual(module(torch.zeros(4, 3)), expected)
def test_script_star_expr_string(self):
# Same scenario as a star-expression call (`self.m(*tup)`), but the
# forward method is supplied as source text through self.define().
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(True)
# m is traced with three example inputs, g with one.
self.m = torch.jit.trace(
torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))(TestScript.StarTestSumStarred())
self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestReturnThree())
self.define('''
def forward(self, rep):
tup = self.g(rep)
return self.m(*tup)
''')
m = M2()
self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
class StarTestSumAndReturnThree(torch.nn.Module):
    """Module whose ``forward`` sums its inputs into the first one and returns it three times."""

    def __init__(self):
        super(TestScript.StarTestSumAndReturnThree, self).__init__()

    def forward(self, *inputs):
        # `+=` accumulates into inputs[0] itself, exactly as written upstream.
        total = inputs[0]
        for extra in inputs[1:]:
            total += extra
        return total, total, total
def test_script_star_assign(self):
# The forward method, supplied as source text, unpacks the traced module's
# result with a starred target on the right (`head, *tail = ...`) and
# returns `head`.
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(True)
self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestSumAndReturnThree())
self.define('''
def forward(self, rep):
head, *tail = self.g(rep)
return head
''')
m = M2()
self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
def test_script_module_star_assign2(self):
# The forward method, supplied as source text, unpacks the traced module's
# result with a starred target on the left (`*head, tail = ...`) and
# returns `tail`.
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(True)
# g is traced with three example inputs and called with three arguments.
self.g = torch.jit.trace(
torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)
)(
TestScript.StarTestSumAndReturnThree()
)
self.define('''
def forward(self, rep):
*head, tail = self.g(rep, rep, rep)
return tail
''')
m = M2()
self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
def test_script_module_star_assign_fail_pythonop(self):
# Star-unpacking the result of a Python function (`myfunc`) inside a
# script is expected to raise RuntimeError matching
# "value cannot be used as a tuple".
# NOTE(review): indentation is not visible here, so whether `m = M2()` and
# the call below sit inside the `with` block should be confirmed upstream.
with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(True)
def myfunc():
return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
self.define('''
def forward(self, rep):
a, *b = myfunc()
return a
''')
m = M2()
m(torch.zeros(4, 3))
def test_script_module_star_assign_fail_builtin(self):
# Star-unpacking the result of `torch.neg(rep)` inside a script is expected
# to raise RuntimeError matching "value cannot be used as a tuple".
# NOTE(review): indentation is not visible here, so whether `m = M2()` and
# the call below sit inside the `with` block should be confirmed upstream.
with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
class M2(torch.jit.ScriptModule):
def __init__(self):
super(M2, self).__init__(True)
self.define('''
def forward(self, rep):
a, *b = torch.neg(rep)
return a
''')
m = M2()
m(torch.zeros(4, 3))
def test_pack_padded_pad_packed_trace(self):
    """Trace a pack/pad round trip and compare outputs and gradients with eager mode.

    The module is also passed to ``torch.onnx._export`` at the end.
    """
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    T, B, C = 3, 5, 7

    class PadPackedWrapper(torch.nn.Module):
        def __init__(self):
            super(PadPackedWrapper, self).__init__()

        def forward(self, x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, _ = pad_packed_sequence(x)
            return x

    data = np.ones((T, B, C))
    lengths = np.array([3, 3, 2, 2, 1], dtype=np.int32)
    # set padding value so we can test equivalence
    for b in range(B):
        if lengths[b] < T:
            data[lengths[b]:, b, :] = 0

    seq_lens = torch.from_numpy(lengths)
    x = torch.autograd.Variable(torch.from_numpy(data), requires_grad=True)

    m = PadPackedWrapper()
    m_traced = torch.jit.trace(x, seq_lens)(m)

    # Eager pass.
    y = m(x, seq_lens)
    torch.sum(y).backward()
    grad = x.grad.clone()
    x.grad.zero_()

    # Traced pass on the same input, after its gradient was zeroed.
    y_traced = m_traced(x, seq_lens)
    torch.sum(y_traced).backward()
    grad_traced = x.grad.clone()

    self.assertEqual(y_traced, x)
    self.assertEqual(y_traced, y)
    self.assertEqual(grad, grad_traced)

    buffer = io.BytesIO()
    torch.onnx._export(m, (x, seq_lens), buffer, verbose=False)
def test_script_outputs(self):
    """Invalid tuple unpacking inside scripted functions must be rejected.

    Unpacking ``a + a`` must raise ``value cannot be used as a tuple``, and
    binding two names to a three-value result must raise
    ``too many values to unpack``.
    """
    with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
        @torch.jit.script
        def foo(a):
            c, d = a + a
            return c + d

    @torch.jit.script
    def return3():
        return 1, 2, 3

    with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
        @torch.jit.script
        def bind2():
            a, b = return3()
            print(a)
            print(b)
def test_script_chunk(self):
    """Unpack ``torch.chunk`` results inside a scripted function.

    Two targets for ``chunks=2`` must work; two targets for ``chunks=3``
    must raise ``too many values to unpack``.
    """
    @torch.jit.script
    def foo(a):
        b, c = torch.chunk(a, dim=0, chunks=2)
        return b

    sample = torch.rand(10, 3)
    first_half = torch.chunk(sample, dim=0, chunks=2)[0]
    self.assertEqual(first_half, foo(sample))

    with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
        @torch.jit.script
        def foo(a):
            b, c = torch.chunk(a, dim=0, chunks=3)
            return b
def test_rnn_trace_override(self):
    """Trace packed-sequence RNN, LSTM and GRU wrappers and compare with eager mode.

    For each cell type the outputs and input gradients of the traced and
    eager modules are compared, and the module is passed to
    ``torch.onnx._export``.
    """
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    num_layers = 3
    T, B, C = 11, 5, 7

    class RNNTraceWrapper(torch.nn.Module):
        def __init__(self, cell_type):
            super(RNNTraceWrapper, self).__init__()
            if cell_type == 'RNN':
                self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
            elif cell_type == 'LSTM':
                self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
            elif cell_type == 'GRU':
                self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)

        def forward(self, x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, _ = self.rnn(x)
            x, _ = pad_packed_sequence(x)
            return x

    for cell_type in ('RNN', 'LSTM', 'GRU'):
        x = torch.ones(T, B, C, requires_grad=True)
        seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))

        m = RNNTraceWrapper(cell_type)
        m_traced = torch.jit.trace(x, seq_lens)(m)

        # Eager pass.
        y = m(x, seq_lens)
        torch.sum(y).backward()
        grad = x.grad.clone()
        x.grad.zero_()

        # Traced pass on the same input, after its gradient was zeroed.
        y_traced = m_traced(x, seq_lens)
        torch.sum(y_traced).backward()
        grad_traced = x.grad.clone()

        self.assertEqual(y_traced, y)
        self.assertEqual(grad, grad_traced)

        buffer = io.BytesIO()
        torch.onnx._export(m, (x, seq_lens), buffer, verbose=False)
def test_tuples(self):
# Part 1: a scripted function that rebinds and unpacks the tuple returned
# by torch.chunk around an `if` and a `while`; its result is compared with
# the first chunk computed eagerly.
# Part 2: rebinding a name from a chunk result to the integer 4 is expected
# to raise RuntimeError matching "variable 'a' previously has type".
# NOTE(review): indentation is not visible here, so the exact nesting of the
# statements under `if True:` and `while False:` should be confirmed upstream.
@torch.jit.script
def foo(i):
a = torch.chunk(i, dim=0, chunks=2)
c = a
# some nonsense with if-statements and loops to check
# that tuple lowering doesn't fail
if True:
c = torch.chunk(i, dim=0, chunks=2)
t0, t1 = c
while False:
t0, t1 = c
c = torch.chunk(i, dim=0, chunks=2)
return t0
v = torch.rand(10, 3)
self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
with self.assertRaisesRegex(RuntimeError, "variable 'a' previously has type"):
@torch.jit.script
def mixtypes():
a = torch.chunk(1, dim=0, chunks=2)
if True:
a = 4
# Smoke tests for export methods
class TestPytorchExportModes(unittest.TestCase):
    """Check that ``torch.onnx._export`` completes for every ``ExportTypes`` mode.

    Each test only verifies that the export call does not raise; the
    exported bytes or files are not inspected.
    """

    class MyModel(nn.Module):
        def __init__(self):
            super(TestPytorchExportModes.MyModel, self).__init__()

        def forward(self, x):
            return x.t()

    def _run_export(self, target, export_type):
        """Export a fresh ``MyModel`` to ``target`` using ``export_type``.

        Args:
            target: a binary file-like object or a directory path, handed
                to ``torch.onnx._export`` unchanged.
            export_type: a member of ``torch.onnx.ExportTypes``.
        """
        torch_model = TestPytorchExportModes.MyModel()
        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
        torch.onnx._export(torch_model, (fake_input,), target, verbose=False,
                           export_type=export_type)

    def test_protobuf(self):
        self._run_export(io.BytesIO(), torch.onnx.ExportTypes.PROTOBUF_FILE)

    def test_zipfile(self):
        self._run_export(io.BytesIO(), torch.onnx.ExportTypes.ZIP_ARCHIVE)

    def test_compressed_zipfile(self):
        self._run_export(io.BytesIO(), torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)

    def test_directory(self):
        d = tempfile.mkdtemp()
        try:
            self._run_export(d, torch.onnx.ExportTypes.DIRECTORY)
        finally:
            # Remove the temporary directory even when the export raises.
            shutil.rmtree(d)
# Run the tests through run_tests() when this file is executed as a script.
if __name__ == '__main__':
run_tests()