import contextlib
import gc
import sys
import math
import torch
import unittest
from copy import deepcopy
from collections import OrderedDict
from common import make_jacobian, TestCase, iter_tensors, \
get_numerical_jacobian, run_tests
from torch.autograd._functions import *
from torch.autograd import Variable, Function
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
PRECISION = 1e-4
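# Walk a Variable or a (possibly nested) sequence of Variables and yield the
# .grad.data tensor of every Variable that requires grad; zero_gradients uses
# this to reset all input gradients before another backward pass.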
def iter_gradients(x):
if isinstance(x, Variable):
if x.requires_grad:
yield x.grad.data
else:
for elem in x:
for result in iter_gradients(elem):
yield result
def zero_gradients(i):
for t in iter_gradients(i):
t.zero_()
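# Build the analytical Jacobian of `output` w.r.t. each differentiable input by
# backpropagating one-hot grad_output vectors, filling in the Jacobian one
# column (i.e. one output element) at a time.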
def get_analytical_jacobian(input, output):
jacobian = make_jacobian(input, output.numel())
grad_output = output.data.clone().zero_()
flat_grad_output = grad_output.view(-1)
for i in range(flat_grad_output.numel()):
flat_grad_output.zero_()
flat_grad_output[i] = 1
zero_gradients(input)
output.backward(grad_output, retain_variables=True)
for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
jacobian_x[:, i] = d_x
return jacobian
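# Context manager that temporarily installs a different autograd execution
# engine on Variable and restores the previous one on exit.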
@contextlib.contextmanager
def backward_engine(engine):
_prev_engine = Variable._execution_engine
Variable._execution_engine = engine()
try:
yield
finally:
Variable._execution_engine = _prev_engine
class TestAutograd(TestCase):
def test_hooks(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
counter = [0]
def bw_hook(inc, grad):
self.assertIsInstance(grad, Variable)
counter[0] += inc
z = x ** 2 + x * 2 + x * y + y
test = z.register_hook(lambda *args: bw_hook(1, *args))
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 1)
test2 = z.register_hook(lambda *args: bw_hook(2, *args))
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 4)
test2.remove()
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 5)
def bw_hook_modify(grad):
return grad.mul(2)
test.remove()
z.register_hook(bw_hook_modify)
y.grad.data.zero_()
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(y.grad.data, (x.data + 1) * 2)
y.register_hook(bw_hook_modify)
y.grad.data.zero_()
z.backward(torch.ones(5, 5))
self.assertEqual(y.grad.data, (x.data + 1) * 4)
def test_hook_none(self):
# WARNING: this is a test for autograd internals.
# You should never have to use such things in your code.
class NoneGradientFunction(Function):
def forward(self, x, y):
assert self.needs_input_grad[0]
assert not self.needs_input_grad[1]
return x, y
def backward(self, grad_x, grad_y):
return grad_x, None
fn = NoneGradientFunction()
fn._backward_hooks = OrderedDict()
was_called = [False]
def hook(grad_input, grad_output):
self.assertIsInstance(grad_input, tuple)
self.assertIsInstance(grad_output, tuple)
self.assertIsNotNone(grad_input[0])
self.assertIsNone(grad_input[1])
self.assertIsNotNone(grad_output[0])
self.assertIsNotNone(grad_output[1])
was_called[0] = True
fn._backward_hooks[id(hook)] = hook
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5))
sum(fn(x, y)).sum().backward()
self.assertTrue(was_called[0])
def _test_backward(self):
v_t = torch.randn(5, 5)
x_t = torch.randn(5, 5)
y_t = torch.rand(5, 5) + 0.1
z_t = torch.randn(5, 5)
grad_output = torch.randn(5, 5)
v = Variable(v_t, requires_grad=True)
x = Variable(x_t, requires_grad=True)
y = Variable(y_t, requires_grad=True)
z = Variable(z_t, requires_grad=True)
v.backward(grad_output)
self.assertEqual(v.grad.data, grad_output)
a = x + (y * z) + 4 * z ** 2 * x / y
a.backward(grad_output)
x_grad = 4 * z_t.pow(2) / y_t + 1
y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2)
z_grad = 8 * x_t * z_t / y_t + y_t
self.assertEqual(x.grad.data, x_grad * grad_output)
self.assertEqual(y.grad.data, y_grad * grad_output)
self.assertEqual(z.grad.data, z_grad * grad_output)
def test_backward(self):
self._test_backward()
@unittest.skip("BasicEngine is out of date")
def test_backward_basic_engine(self):
with backward_engine(torch.autograd.engine.BasicEngine):
self._test_backward()
def test_multi_backward(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
q = Variable(torch.randn(5, 5), requires_grad=True)
a = Variable(torch.randn(5, 5), requires_grad=True)
b = Variable(torch.randn(5, 5), requires_grad=True)
q2 = q * 2
z = x + y + q2
c = a * b + q2
grad_z = torch.randn(5, 5)
grad_c = torch.randn(5, 5)
torch.autograd.backward([z, c], [grad_z, grad_c])
self.assertEqual(x.grad.data, grad_z)
self.assertEqual(y.grad.data, grad_z)
self.assertEqual(a.grad.data, grad_c * b.data)
self.assertEqual(b.grad.data, grad_c * a.data)
self.assertEqual(q.grad.data, (grad_c + grad_z) * 2)
def test_multi_backward_stochastic(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
z = x + y
q = torch.normal(x)
q.reinforce(torch.randn(5, 5))
torch.autograd.backward([z, q], [torch.ones(5, 5), None])
def test_multi_backward_no_grad(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=False)
z = x + y
q = y * 2
torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
self.assertEqual(x.grad.data, torch.ones(5, 5))
def test_volatile(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, volatile=True)
z = x ** 2
self.assertFalse(z.volatile)
self.assertTrue(z.requires_grad)
self.assertIsNotNone(z.creator)
z.backward(torch.ones(5, 5))
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
w = z + y
self.assertTrue(w.volatile)
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.creator)
def test_indexing(self):
x = torch.range(1, 16).resize_(4, 4)
y = Variable(x)
self.assertEqual(x[1], y[1].data)
self.assertEqual(x[1, 1], y[1, 1].data[0])
self.assertEqual(x[1:], y[1:].data)
self.assertEqual(x[:2], y[:2].data)
self.assertEqual(x[:2, 2], y[:2, 2].data)
self.assertEqual(x[1:2, 2], y[1:2, 2].data)
self.assertEqual(x[1, 2:], y[1, 2:].data)
def test_requires_grad(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y
self.assertFalse(a.requires_grad)
b = a + z
self.assertTrue(b.requires_grad)
def error():
raise RuntimeError
# Make sure backward isn't called on these
a._backward_hooks = OrderedDict()
x._backward_hooks = OrderedDict()
y._backward_hooks = OrderedDict()
a._backward_hooks['test'] = error
x._backward_hooks['test'] = error
y._backward_hooks['test'] = error
b.backward(torch.ones(5, 5))
def test_inplace(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
z = x * y
q = z + y
w = z * y
z.add_(2)
# Add doesn't need its inputs to do backward, so it shouldn't raise
q.backward(torch.ones(5, 5), retain_variables=True)
# Mul saves both inputs in forward, so it should raise
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
z = x * y
q = z * y
r = z + y
w = z.add_(y)
# w is the last expression, so this should succeed
w.backward(torch.ones(5, 5), retain_variables=True)
# r doesn't use the modified value in backward, so it should succeed
r.backward(torch.ones(5, 5), retain_variables=True)
# q uses dirty z, so it should raise
self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
x.grad.data.zero_()
m = x / 2
z = m + y / 8
q = z * y
r = z + y
prev_version = z._version
w = z.exp_()
self.assertNotEqual(z._version, prev_version)
r.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(x.grad.data, torch.ones(5, 5) / 2)
w.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(x.grad.data, torch.Tensor(5, 5).fill_((1 + math.e) / 2))
self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
leaf = Variable(torch.ones(5, 5), requires_grad=True)
x = leaf.clone()
x.add_(10)
self.assertEqual(x.data, torch.ones(5, 5) * 11)
# x should still be usable
y = x + 2
y.backward(torch.ones(5, 5))
self.assertEqual(leaf.grad.data, torch.ones(5, 5))
z = x * y
x.add_(2)
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
def test_shared_storage(self):
x = Variable(torch.ones(5, 5))
y = x.t()
z = x[1]
self.assertRaises(RuntimeError, lambda: x.add_(2))
self.assertRaises(RuntimeError, lambda: y.add_(2))
self.assertRaises(RuntimeError, lambda: z.add_(2))
def _test_setitem(self, size, index):
x = Variable(torch.ones(*size), requires_grad=True)
y = x + 2
y_version = y._version
y[index] = 2
self.assertNotEqual(y._version, y_version)
y.backward(torch.ones(*size))
expected_grad = torch.ones(*size)
if isinstance(index, Variable):
index = index.data
expected_grad[index] = 0
self.assertEqual(x.grad.data, expected_grad)
def _test_setitem_tensor(self, size, index):
x = Variable(torch.ones(*size), requires_grad=True)
y = x + 2
y_version = y._version
value = Variable(torch.Tensor(x[index].size()).fill_(7), requires_grad=True)
y[index] = value
self.assertNotEqual(y._version, y_version)
y.backward(torch.ones(*size))
expected_grad_input = torch.ones(*size)
if isinstance(index, Variable):
index = index.data
expected_grad_input[index] = 0
self.assertEqual(x.grad.data, expected_grad_input)
self.assertEqual(value.grad.data, torch.ones(value.size()))
def test_setitem(self):
self._test_setitem((5, 5), 1)
self._test_setitem((5,), 1)
self._test_setitem((1,), 0)
self._test_setitem_tensor((5, 5), 3)
self._test_setitem_tensor((5,), 3)
def test_setitem_mask(self):
mask = torch.ByteTensor(5, 5).bernoulli_()
self._test_setitem((5, 5), Variable(mask))
self._test_setitem((5,), Variable(mask[0]))
self._test_setitem((1,), Variable(mask[0, 0:1]))
self._test_setitem_tensor((5, 5), Variable(mask))
self._test_setitem_tensor((5,), Variable(mask[0]))
def test_unused_output(self):
x = Variable(torch.randn(10, 10), requires_grad=True)
outputs = x.chunk(5)
o = outputs[2]
o = o * 4 + 2
o.sum().backward()
expected_grad = torch.zeros(10, 10)
expected_grad[4:6] = 4
self.assertEqual(x.grad.data, expected_grad)
x.grad.data.zero_()
grad_output = torch.randn(2, 10)
outputs = x.chunk(5)
outputs[0].backward(grad_output)
expected_grad = torch.zeros(10, 10)
expected_grad[:2] = grad_output
self.assertEqual(x.grad.data, expected_grad)
def test_gc_in_destructor(self):
"""
Previously, if a Function destructor triggered a garbage collection,
the Variable's tp_dealloc handler would get called twice, leading to a
segfault.
"""
class CollectOnDelete(Function):
def __del__(self):
gc.collect()
for i in range(10):
Variable(torch.randn(10, 10), creator=CollectOnDelete())
@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
"CUDA not available or <2 GPUs detected")
def test_unused_output_gpu(self):
from torch.nn.parallel._functions import Broadcast
x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
outputs = Broadcast(list(range(torch.cuda.device_count())))(x)
y = outputs[-1] * 2
y.sum().backward()
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
def test_detach(self):
x = Variable(torch.randn(10, 10), requires_grad=True)
y = x + 2
y = y.detach()
z = y * 4 + 2
self.assertFalse(y.requires_grad)
self.assertFalse(z.requires_grad)
x = Variable(torch.randn(10, 10), requires_grad=True)
y = x * 2
y = y.detach()
self.assertFalse(y.requires_grad)
self.assertFalse(y.creator.requires_grad)
z = x + y
z.sum().backward()
# This is an incorrect gradient, but we assume that's what the user
# wanted. detach() is an advanced option.
self.assertEqual(x.grad.data, torch.ones(10, 10))
def test_type_conversions(self):
import torch.cuda
x = Variable(torch.randn(5, 5))
self.assertIs(type(x.float().data), torch.FloatTensor)
self.assertIs(type(x.int().data), torch.IntTensor)
if torch.cuda.is_available():
self.assertIs(type(x.float().cuda().data), torch.cuda.FloatTensor)
self.assertIs(type(x.int().cuda().data), torch.cuda.IntTensor)
self.assertIs(type(x.int().cuda().cpu().data), torch.IntTensor)
if torch.cuda.device_count() >= 2:
x2 = x.float().cuda(1)
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)
x2 = x.float().cuda()
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 0)
x2 = x2.cuda(1)
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)
def test_return_leaf(self):
class Identity(Function):
def forward(self, a, b):
return a, a + b
def backward(self, grad_a, grad_b):
return grad_a + grad_b, grad_b
class Inplace(InplaceFunction):
def forward(self, a, b):
self.mark_dirty(a)
return a.add_(b), b + 2
def backward(self, grad_a, grad_b):
return grad_a, grad_a + grad_b
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
q, p = Identity()(x, y)
# Make sure hooks only receive grad from usage of q, not x.
q.register_hook(
lambda grad: self.assertEqual(grad.data, torch.ones(5, 5)))
(q + p + x).sum().backward()
self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
self.assertEqual(y.grad.data, torch.ones(5, 5))
del q, p # these need to be freed, or the next part will raise an error
def test_return_leaf_inplace(self):
class Inplace(InplaceFunction):
def forward(self, a, b):
self.mark_dirty(a)
return a.add_(b), b + 2
def backward(self, grad_a, grad_b):
return grad_a, grad_a + grad_b
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5), requires_grad=True)
fn = Inplace(True)
q, p = fn(x, y)
self.assertIs(q, x)
self.assertIs(q.creator, fn)
self.assertTrue(q.requires_grad)
q.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5, 5))
def test_leaf_assignment(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5), requires_grad=True)
z = Variable(torch.randn(5), requires_grad=True)
x[0] = y
x[1] = 2 * z
self.assertTrue(x.requires_grad)
self.assertIsNot(x.creator, None)
x.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5))
self.assertEqual(z.grad.data, torch.ones(5) * 2)
def test_backward_copy(self):
# This test checks the backward engine for a very subtle bug that appeared
# in one of the initial versions of autograd. Gradient tensors were
# simply stored in lists while the function waited for all its gradients
# to be computed. However, sometimes an output was used multiple times,
# so the gradients needed to be summed. The engine used to keep a need_copy
# set of tensors that would need a clone upon the next addition, and removed
# them from the set as soon as the clone was performed. However, this
# could lead to incorrect results if the same gradient tensor was
# buffered in three places in the graph:
# 1. When accumulating gradients in the first of these places it was cloned
#    and removed from the need_copy set.
# 2. When accumulating in the second place, it wasn't in the need_copy set,
#    so the gradients were simply accumulated in-place (which already
#    modified the grad in the 3rd place)
# 3. When accumulating in the third place, it wasn't in the need_copy set
#    either, so the incoming gradient was summed in-place, yielding
#    incorrect results in all functions, except the first one.
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5), requires_grad=True)
# Simulate that we're in the middle of the graph
a = x + 2
b = y + 2
c = x + 2
# This op will just return grad_output two times in backward
add1 = a + b
add2 = add1 + c
# Simulate a long branch, so grad_output will get buffered.
for i in range(4):
a = a * 2
b = b * 2
c = c * 2
branch = a + b + c
out = add2 + branch
# expected gradients are:
# for x: 34 (16 from final a, 16 from final c, 2 from add2)
# for y: 17 (16 from final b, 1 from add2)
grad_output = torch.ones(5, 5)
out.backward(grad_output)
self.assertEqual(x.grad.data, torch.ones(5, 5) * 34)
self.assertEqual(y.grad.data, torch.ones(5, 5) * 17)
def test_functional_blas(self):
def compare(fn, *args):
unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
for arg in args)
self.assertEqual(fn(*args).data, fn(*unpacked_args))
def test_blas_add(fn, x, y, z):
# Checks all signatures
compare(fn, x, y, z)
compare(fn, 0.5, x, y, z)
compare(fn, 0.5, x, 0.25, y, z)
def test_blas(fn, x, y):
compare(fn, x, y)
test_blas(torch.mm, Variable(torch.randn(2, 10)),
Variable(torch.randn(10, 4)))
test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
Variable(torch.randn(4, 10, 4)))
test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
test_blas(torch.mv, Variable(torch.randn(2, 10)),
Variable(torch.randn(10)))
test_blas_add(torch.addmv, Variable(torch.randn(2)),
Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
test_blas(torch.ger, Variable(torch.randn(5)),
Variable(torch.randn(6)))
test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
Variable(torch.randn(5)), Variable(torch.randn(6)))
def test_save_none_for_backward(self):
test_case = self
class MyFn(Function):
def forward(self, input):
self.save_for_backward(None, input, None)
return input * input
def backward(self, grad_output):
n1, input, n2 = self.saved_tensors
test_case.assertIsNone(n1)
test_case.assertIsNone(n2)
return 2 * input * grad_output
x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
self.assertEqual(x.grad.data, 2 * x.data)
def test_too_many_grads(self):
class MyFn(Function):
def forward(self, input):
return input
def backward(self, grad_output):
return grad_output, None, None
x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
self.assertEqual(x.grad.data, x.data.clone().fill_(1))
def test_stochastic(self):
x = Variable(torch.rand(2, 10), requires_grad=True)
stddevs = Variable(torch.rand(2, 10) * 5, requires_grad=True)
y = (x * 2).clamp(0, 1)
y = y / y.sum(1).expand_as(y)
samples_multi = y.multinomial(5)
samples_multi_flat = y[0].multinomial(5)
samples_bernoulli = y.bernoulli()
samples_norm = torch.normal(y)
samples_norm_std = torch.normal(y, stddevs)
z = samples_multi * 2 + 4
z = z + samples_multi_flat.unsqueeze(0).expand_as(samples_multi)
z = torch.cat([z, z], 1)
z = z.double()
z = z + samples_bernoulli + samples_norm + samples_norm_std
last_sample = torch.normal(z, 4)
z = last_sample + 2
self.assertFalse(z.requires_grad)
self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
samples_multi.reinforce(torch.randn(2, 5))
self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
samples_multi_flat.reinforce(torch.randn(5))
self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
samples_bernoulli.reinforce(torch.randn(2, 10))
self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
samples_norm.reinforce(torch.randn(2, 10))
self.assertRaises(RuntimeError, lambda: z.backward(retain_variables=True))
samples_norm_std.reinforce(torch.randn(2, 10))
# We don't have to specify rewards w.r.t. last_sample - it doesn't
# require a gradient
last_sample.backward(retain_variables=True)
z.backward()
self.assertGreater(x.grad.data.abs().sum(), 0)
def test_stochastic_sequence(self):
x = Variable(torch.rand(10).clamp_(0, 1), requires_grad=True)
b = x.bernoulli()
n1 = torch.normal(b, x)
n2 = torch.normal(n1, 2)
b.reinforce(torch.randn(10))
n1.reinforce(torch.randn(10))
n2.reinforce(torch.randn(10))
n2.backward()
self.assertGreater(x.grad.data.abs().sum(), 0)
def test_stochastic_output(self):
x = Variable(torch.rand(10), requires_grad=True)
b = x.clone().clamp(0, 1).bernoulli()
b.reinforce(torch.randn(10))
b.backward()
self.assertGreater(x.grad.data.abs().sum(), 0)
def test_pickle(self):
x = Variable(torch.randn(10, 10), requires_grad=True)
y = Variable(torch.randn(10, 10), volatile=True)
z = Variable(torch.randn(10, 10), requires_grad=False)
def assert_strict_equal(var1, var2):
self.assertEqual(var1.data, var2.data)
self.assertEqual(var1.requires_grad, var2.requires_grad)
self.assertEqual(var1.volatile, var2.volatile)
serialized = [pickle.dumps([x, y, z], protocol=p) for p in range(3)]
for dump in serialized:
xc, yc, zc = pickle.loads(dump)
assert_strict_equal(xc, x)
assert_strict_equal(yc, y)
assert_strict_equal(zc, z)
def test_dep_nograd(self):
class F1(Function):
def forward(self, input):
out = torch.randn(input.size())
self.mark_non_differentiable(out)
return input, out
def backward(self, grad_output, ignored):
return grad_output
class F2(Function):
def forward(self, input, ignored):
return input
def backward(self, grad_output):
return grad_output, None
x = Variable(torch.randn(5), requires_grad=True)
a, b = F1()(x)
b = b + 1 # separate F1 from F2 by another op
self.assertTrue(a.requires_grad)
self.assertFalse(b.requires_grad)
c = F2()(a, b)
c.backward(torch.ones(c.size()))
self.assertEqual(x.grad.data, torch.ones(x.size()))
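# Helpers used by the test tables below: they build random index Variables
# (with requires_grad=False) suitable as arguments for the index_*, gather
# and scatter tests.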
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
index = torch.rand(*shape).mul_(max_indices).floor_().long()
return Variable(index, requires_grad=False)
def gather_variable(shape, index_dim, max_indices):
assert len(shape) == 2
assert index_dim < 2
batch_dim = 1 - index_dim
index = torch.LongTensor(*shape)
for i in range(shape[index_dim]):
index.select(index_dim, i).copy_(
torch.randperm(max_indices)[:shape[batch_dim]])
return Variable(index, requires_grad=False)
L = 20
M = 10
S = 5
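# Each entry is (Function class, constructor args, inputs[, test name suffix]),
# where an input may be a size tuple, a tensor or a Variable. Size tuples are
# turned into random double Variables with requires_grad=True by create_input().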
function_tests = [
(Add, (), ((M, M), (M, M))),
(Sub, (), ((M, M), (M, M))),
(Mul, (), ((M, M), (M, M))),
(Div, (), ((M, M), torch.rand(M, M) + 5e-2)),
(Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
(AddConstant, (3.14,), ((L, L),)),
(SubConstant, (3.14,), ((L, L),)),
(SubConstant, (3.14, True), ((L, L),), 'from_tensor'),
(MulConstant, (3.14,), ((L, L),)),
(DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor'),
(PowConstant, (3.14,), (torch.rand(L, L),)),
(PowConstant, (3.14, True), (torch.rand(L, L),), 'tensor_power'),
(Transpose, (0, 1), (torch.rand(L, L),)),
(Transpose, (2, 0), (torch.rand(S, S, S),), '3d'),
(Permute, ((0, 4, 3, 5, 1, 2),), ((1, 2, 3, 4, 5, 6),)),
(Index, ((1, 2),), (torch.rand(S, S, S),)),
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
(Index, ((slice(0, 3), 1),), (torch.rand(S, S, S),), 'slice_index'),
(View, (S * S, S), (torch.rand(S, S, S),)),
(Expand, ((S, 5, S, 5),), ((S, 1, S, 1),)),
(Exp, (), (torch.rand(S, S, S),)),
(Log, (), (torch.rand(S, S, S) + 1e-2,)),
(Log1p, (), (torch.rand(S, S, S),)),
(Tanh, (), ((S, S, S),)),
(Sigmoid, (), ((S, S, S),)),
(Sinh, (), ((S, S, S),)),
(Cosh, (), ((S, S, S),)),
(Abs, (), ((S, S, S),)),
(Clamp, (0, 1), ((S, S, S),)),
(Sqrt, (), (torch.rand(S, S, S) + 5e-4,)),
(Sin, (), ((S, S, S),)),
(Cos, (), ((S, S, S),)),
(Tan, (), (torch.randn(S, S, S).clamp(-1, 1),)),
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),)),
(Atan, (), ((S, S, S),)),
(Reciprocal, (), (torch.rand(S, S, S) + 0.1,)),
(Cmax, (), ((S, S, S), (S, S, S))),
(Cmin, (), ((S, S, S), (S, S, S))),
(Round, (), ((S, S, S),)),
(Sign, (), ((S, S, S),)),
(Trunc, (), ((S, S, S),)),
(Floor, (), ((S, S, S),)),
(Ceil, (), ((S, S, S),)),
(Frac, (), ((S, S, S),)),
(Fmod, (1.5,), ((S, S, S),)),
(Lerp, (0.2,), ((S, S, S), (S, S, S))),
(Rsqrt, (), (torch.rand(S, S, S) + 1e-2,)),
(Remainder, (1.5,), ((S, S, S),)),
(CmaxConstant, (0.5,), ((S, S, S),)),
(CminConstant, (0.5,), ((S, S, S),)),
(Mean, (), ((S, S, S),)),
(Mean, (1,), ((S, S, S),), 'dim'),
(Sum, (), ((S, S, S),)),
(Sum, (1,), ((S, S, S),), 'dim'),
(Prod, (), ((S, S, S),)),
(Prod, (1,), ((S, S, S),), 'dim'),
(Addmm, (), ((S, M), (S, S), (S, M)),),
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
(Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef'),
(Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)),),
(Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef'),
(Addmv, (), ((S,), (S, M), (M,)),),
(Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef'),
(Addr, (), ((S, M), (S,), (M,)),),
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef'),
(Dot, (), ((L,), (L,)),),
(Max, (), ((S, S, S),),),
(Repeat, (torch.Size([2, 3, 1, 4]),), ((S, S, S, S),)),
(Min, (), ((S, S, S),),),
(Max, (0,), ((S, S, S),), 'dim'),
(Min, (0,), ((S, S, S),), 'dim'),
(Mode, (0,), ((S, S, S),),),
(Kthvalue, (2, 0), ((S, S, S),),),
(Median, (0,), ((S, S, S),),),
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
(Norm, (), ((S, S, S),), '2'),
(Norm, (3,), ((S, S, S),), '3'),
(Norm, (1.5, 0), (torch.rand(S, S, S),), '1_5_dim'),
(Norm, (2, 0), ((S, S, S),), '2_dim'),
(Norm, (3, 0), ((S, S, S),), '3_dim'),
(Addcmul, (), ((S, S), (S, S), (S, S))),
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2)),
(Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S))),
# (IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ),
(IndexFill, (0, 2), ((S, S), index_variable(2, S))),
(IndexSelect, (0,), ((S, S), index_variable(2, S))),
(Gather, (0,), ((M, S), gather_variable((S, S), 1, M))),
(Gather, (1,), ((M, S), gather_variable((M, S // 2), 0, S)), 'dim1'),
(Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
(Scatter, (1,), ((M, S), gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1'),
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S))),
(Resize, (S * S, S), ((S, S, S),)),
(Diag, (), ((S, S),), '2d'),
(Diag, (), ((S,),), '1d'),
(Tril, (), ((S, S),)),
(Tril, (2,), ((S, S),), 'idx'),
(Triu, (), ((S, S),)),
(Triu, (2,), ((S, S),), 'idx'),
(Clone, (), ((S, M, S),)),
(Squeeze, (), ((S, 1, M, 1),)),
(Squeeze, (1,), ((S, 1, M, 1),), 'dim'),
(Unsqueeze, (0,), ((S, M, S),), '0'),
(Unsqueeze, (1,), ((S, M, S),), '1'),
# (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
(MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(Sort, (), ((S, M, S),)),
(Sort, (1,), ((S, M, S),), 'dim'),
(Sort, (1, True), ((S, M, S),), 'dim_desc'),
(Topk, (3,), ((S, M, S),)),
(Topk, (3, 1), ((S, M, S),), 'dim'),
(Topk, (3, 1, True), ((S, M, S),), 'dim_desc'),
(Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort'),
]
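# Each entry is (method name, size of the `self` tensor, method args[, test name suffix]);
# args follow the same conventions as the function_tests inputs above.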
method_tests = [
('add', (S, S, S), ((S, S, S),)),
('add', (S, S, S), (3.14,), 'constant'),
('sub', (S, S, S), ((S, S, S),)),
('sub', (S, S, S), (3.14,), 'constant'),
('mul', (S, S, S), ((S, S, S),)),
('mul', (S, S, S), (3.14,), 'constant'),
('div', (S, S, S), ((S, S, S),)),
('div', (S, S, S), (3.14,), 'constant'),
('pow', (S, S, S), ((S, S, S),)),
('pow', (S, S, S), (3.14,), 'constant'),
('transpose', (1, 2, 3), (1, 2)),
('t', (1, 2), ()),
('view', (S, S, S), (S * S, S),),
('view_as', (S, S, S), ((S * S, S),)),
('expand', (S, 1, S), (S, S, S)),
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
('exp', (S, S, S), ()),
('log', (S, S, S), ()),
('log1p', (S, S, S), ()),
('tanh', (S, S, S), ()),
('sigmoid', (S, S, S), ()),
('sinh', (S, S, S), ()),
('cosh', (S, S, S), ()),
('abs', (S, S, S), ()),
('clamp', (S, S, S), (0, 1)),
('sqrt', (S, S, S), ()),
('sin', (S, S, S), ()),
('cos', (S, S, S), ()),
('tan', (S, S, S), ()),
('asin', (S, S, S), ()),
('acos', (S, S, S), ()),
('atan', (S, S, S), ()),
('reciprocal', (S, S, S), ()),
('round', (S, S, S), ()),
('sign', (S, S, S), ()),
('trunc', (S, S, S), ()),
('floor', (S, S, S), ()),
('ceil', (S, S, S), ()),
('rsqrt', (S, S, S), ()),
('fmod', (S, S, S), (1.5,)),
('remainder', (S, S, S), (1.5,)),
('lerp', (S, S, S), ((S, S, S), 0.4)),
('max', (S, S, S), ()),
('max', (S, S, S), ((S, S, S),), 'elementwise'),
('min', (S, S, S), ()),
('min', (S, S, S), ((S, S, S),), 'elementwise'),
('mean', (S, S, S), ()),
('mean', (S, S, S), (1,), 'dim'),
('sum', (S, S, S), ()),
('sum', (S, S, S), (1,), 'dim'),
('prod', (S, S, S), ()),
('prod', (S, S, S), (1,), 'dim'),
('var', (S, S, S), ()),
('var', (S, S, S), (1,), 'dim'),
('std', (S, S, S), ()),
('std', (S, S, S), (1,), 'dim'),
('renorm', (S, S, S), (2, 1, 0.5)),
('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
('repeat', (S, S, S, S), (2, 3, 1, 4)),
('addmm', (S, M), ((S, S), (S, M)),),
('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef'),
('addbmm', (S, M), ((S, S, S), (S, S, M)),),
('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'),
('addmv', (S,), ((S, M), (M,)),),
('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef'),
('addr', (S, M), ((S,), (M,)),),
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'),
('dot', (L,), ((L,),),),
('addcmul', (S, S), ((S, S), (S, S))),
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'),
('addcdiv', (S, S), ((S, S), (S, S))),
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
('norm', (S, S, S), (2,)),
('norm', (S, S, S), (2, 1), 'dim'),
('dist', (S, S, S), ((S, S, S),)),
('dist', (S, S, S), ((S, S, S), 4), '4'),
('index_select', (S, S, S), (0, index_variable(2, S))),
('diag', (M, M), (), '2d'),
('diag', (M,), (), '1d'),
('tril', (M, M), ()),
('triu', (M, M), ()),
('clone', (S, M, S), ()),
('eq', (S, S, S), ((S, S, S),)),
('ne', (S, S, S), ((S, S, S),)),
('gt', (S, S, S), ((S, S, S),)),
('ge', (S, S, S), ((S, S, S),)),
('lt', (S, S, S), ((S, S, S),)),
('le', (S, S, S), ((S, S, S),)),
('eq', (S, S, S), (0,), 'scalar'),
('ne', (S, S, S), (0,), 'scalar'),
('gt', (S, S, S), (0,), 'scalar'),
('ge', (S, S, S), (0,), 'scalar'),
('lt', (S, S, S), (0,), 'scalar'),
('le', (S, S, S), (0,), 'scalar'),
('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
('select', (S, S, S), (1, 2)),
('narrow', (S, S, S), (1, 2, 2)),
('squeeze', (S, 1, S, 1), ()),
('squeeze', (S, 1, S, 1), (1,), '1_dim'),
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim'),
('unsqueeze', (S, S, S), (0,), 'first'),
('unsqueeze', (S, S, S), (1,), 'middle'),
('unsqueeze', (S, S, S), (3,), 'last'),
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),)),
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)),
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))),
]
# TODO: mm, bmm, mv, ger
# TODO: max, min with dim (problem with indices)
# TODO: mode, median, sort, kthvalue, topk (problem with indices)
# TODO: indexAdd, indexCopy, indexFill
# TODO: resize, resize_as (tensors only have resize_ and resize_as_)
# TODO: clamp with min/max
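# Turn a tuple of size tuples / tensors / scalars into test inputs: size tuples
# become random double Variables with requires_grad=True, tensors are wrapped
# in Variables that require grad (FloatTensors are promoted to double), and
# everything else (e.g. existing Variables and scalars) is passed through unchanged.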
def create_input(call_args):
if not isinstance(call_args, tuple):
call_args = (call_args,)
def map_arg(arg):
if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
return Variable(torch.randn(*arg).double(), requires_grad=True)
elif torch.is_tensor(arg):
if isinstance(arg, torch.FloatTensor):
return Variable(arg.double(), requires_grad=True)
else:
return Variable(arg, requires_grad=True)
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
def unpack_variables(args):
if isinstance(args, Variable):
return args.data
elif isinstance(args, tuple):
return tuple(unpack_variables(elem) for elem in args)
else:
return args
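# Generated tests listed here skip the in-place variant check below.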
ignore_inplace = set((
'test_DivConstant_by_tensor',
))
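# Dynamically generate one TestAutograd method per function_tests entry. Each
# generated test compares the analytical Jacobian against a numerical estimate
# and, for InplaceFunction subclasses, also checks that the in-place variant
# produces the same outputs and gradients.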
for test in function_tests:
cls, constructor_args, call_args = test[:3]
test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, cls=cls, constructor_args=constructor_args,
call_args=call_args, test_name=test_name):
input = create_input(call_args)
output = cls(*constructor_args)(*input)
if not isinstance(output, tuple):
output = (output,)
for i, o in enumerate(output):
if not o.requires_grad:
continue
analytical = get_analytical_jacobian(input, o)
def fn(input):
tmp = cls(*constructor_args)(*input)
if not isinstance(tmp, tuple):
tmp = (tmp,)
return tmp[i].data
numerical = get_numerical_jacobian(fn, input, input)
self.assertLessEqual(
max(a.add(-1, n).abs().max() for a, n in zip(analytical, numerical)),
PRECISION
)
if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
inplace_input = deepcopy(input)
inplace_input_copy = tuple(i + 0 for i in inplace_input)
fn = cls(*constructor_args, inplace=True)
inplace_output = fn(*inplace_input_copy)
if not isinstance(inplace_output, tuple):
inplace_output = (inplace_output,)
self.assertEqual(inplace_output, output)
# Check that gradient is the same
for inp_i, i in zip(inplace_input, input):
if inp_i.grad is not None:
inp_i.grad.data.zero_()
if i.grad is not None:
i.grad.data.zero_()
for io, o in zip(inplace_output, output):
grad = torch.randn(*io.size()).double()
io.backward(grad)
o.backward(grad)
for inp_i, i in zip(inplace_input, input):
self.assertEqual(inp_i.grad, i.grad)
assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
setattr(TestAutograd, test_name, do_test)
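# torch.* counterparts of these methods are skipped in the functional
# interface check below.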
EXCLUDE_FUNCTIONAL = {
'addmm',
'addbmm',
'baddbmm',
'addmv',
'addr',
}
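# Dynamically generate one TestAutograd method per method_tests entry. Each
# generated test checks that calling the method on Variables matches calling
# it on plain tensors, tries the torch.* functional form where available, and
# also exercises the in-place variant when one exists.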
for test in method_tests:
name, self_size, args = test[:3]
test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
def check(name):
self_variable = create_input((self_size,))[0]
args_variable = create_input(args)
self_tensor = deepcopy(self_variable.data)
args_tensor = deepcopy(unpack_variables(args_variable))
output_variable = getattr(self_variable, name)(*args_variable)
output_tensor = getattr(self_tensor, name)(*args_tensor)
if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
output_tensor = torch.DoubleTensor((output_tensor,))
self.assertEqual(unpack_variables(output_variable), output_tensor)
# TODO: check that both have changed after adding all inplace ops
# functional interface tests
if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
f_args_variable = (self_variable,) + args_variable
f_args_tensor = (self_tensor,) + args_tensor
output_variable = getattr(torch, name)(*f_args_variable)
output_tensor = getattr(torch, name)(*f_args_tensor)
if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
output_tensor = torch.DoubleTensor((output_tensor,))
self.assertEqual(unpack_variables(output_variable), output_tensor)
check(name)
inplace_name = name + '_'
if hasattr(Variable(torch.ones(1)), inplace_name):
try:
check(inplace_name)
except Exception as e:
if 'only supports scalar' not in e.args[0]:
raise
assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
setattr(TestAutograd, test_name, do_test)
if __name__ == '__main__':
run_tests()