import math
import unittest
from copy import deepcopy
from common import make_jacobian, TestCase, iter_tensors, get_numerical_jacobian
from torch.autograd.functions import *
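# Maximum allowed elementwise difference between the analytical and numerical
# Jacobians in the generated function tests below.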
PRECISION = 1e-4
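# Recursively yield the .grad tensor of every Variable (that requires grad) in a
# possibly nested input.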
def iter_gradients(x):
if isinstance(x, Variable):
if x.requires_grad:
yield x.grad
else:
for elem in x:
for result in iter_gradients(elem):
yield result
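# Zero the accumulated gradient of every Variable found in the input.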
def zero_gradients(i):
for t in iter_gradients(i):
t.zero_()
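# Compute the Jacobian of output w.r.t. each input Variable by backpropagating a
# one-hot grad_output for every output element, retaining the graph between calls.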
def get_analytical_jacobian(input, output):
jacobian = make_jacobian(input, output.numel())
grad_output = output.data.clone().zero_()
flat_grad_output = grad_output.view(-1)
for i in range(flat_grad_output.numel()):
flat_grad_output.zero_()
flat_grad_output[i] = 1
zero_gradients(input)
output.backward(grad_output, retain_variables=True)
for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
jacobian_x[:,i] = d_x
return jacobian
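# Hand-written autograd tests. The parameterized tests built from function_tests
# and method_tests below are attached to this class via setattr at import time.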
class TestAutograd(TestCase):
def test_hooks(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
counter = [0]
def bw_hook(inc, grad):
self.assertTrue(torch.is_tensor(grad))
counter[0] += inc
z = x ** 2 + x * 2 + x * y + y
z.register_hook('test', lambda *args: bw_hook(1, *args))
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 1)
z.register_hook('test2', lambda *args: bw_hook(2, *args))
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 4)
z.remove_hook('test2')
z.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(counter[0], 5)
def test_backward(self):
v_t = torch.randn(5, 5)
x_t = torch.randn(5, 5)
y_t = torch.rand(5, 5) + 0.1
z_t = torch.randn(5, 5)
grad_output = torch.randn(5, 5)
v = Variable(v_t, requires_grad=True)
x = Variable(x_t, requires_grad=True)
y = Variable(y_t, requires_grad=True)
z = Variable(z_t, requires_grad=True)
v.backward(grad_output)
self.assertEqual(v.grad, grad_output)
a = x + (y * z) + 4 * z**2 * x / y
a.backward(grad_output)
x_grad = 4 * z_t.pow(2) / y_t + 1
y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2)
z_grad = 8 * x_t * z_t / y_t + y_t
self.assertEqual(x.grad, x_grad * grad_output)
self.assertEqual(y.grad, y_grad * grad_output)
self.assertEqual(z.grad, z_grad * grad_output)
def test_volatile(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, volatile=True)
z = x ** 2
self.assertFalse(z.volatile)
self.assertTrue(z.requires_grad)
self.assertIsNotNone(z.creator)
z.backward(torch.ones(5, 5))
self.assertEqual(x.grad, torch.ones(5, 5) * 2)
w = z + y
self.assertTrue(w.volatile)
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.creator)
def test_indexing(self):
x = torch.range(1, 16).resize_(4, 4)
y = Variable(x)
self.assertEqual(x[1], y[1].data)
self.assertEqual(x[1, 1], y[1, 1].data[0])
self.assertEqual(x[1:], y[1:].data)
self.assertEqual(x[:2], y[:2].data)
self.assertEqual(x[:2, 2], y[:2, 2].data)
self.assertEqual(x[1:2, 2], y[1:2, 2].data)
self.assertEqual(x[1, 2:], y[1, 2:].data)
def test_requires_grad(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y
self.assertFalse(a.requires_grad)
b = a + z
self.assertTrue(b.requires_grad)
def error():
raise RuntimeError
# Make sure backward isn't called on these
a.backward_hooks['test'] = error
x.backward_hooks['test'] = error
y.backward_hooks['test'] = error
def test_inplace(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
z = x * y
q = z + y
w = z * y
z.dirty = True
        # Add doesn't need its inputs to do backward, so it shouldn't raise
q.backward(torch.ones(5, 5))
# Mul saves both inputs in forward, so it should raise
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
z = x * y
q = z * y
r = z + y
w = z.add_(y)
        # w is the last expression, so this should succeed
w.backward(torch.ones(5, 5), retain_variables=True)
# r doesn't use the modified value in backward, so it should succeed
r.backward(torch.ones(5, 5), retain_variables=True)
# q uses dirty z, so it should raise
self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
x.grad.zero_()
m = x / 2
z = m + y / 8
q = z * y
r = z + y
w = z.exp_()
self.assertTrue(z.dirty)
r.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(x.grad, torch.ones(5, 5) / 2)
w.backward(torch.ones(5, 5), retain_variables=True)
self.assertEqual(x.grad, torch.Tensor(5, 5).fill_((1 + math.e) / 2))
self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
def test_type_conversions(self):
import torch.cuda
x = Variable(torch.randn(5, 5))
self.assertIs(type(x.float().data), torch.FloatTensor)
self.assertIs(type(x.int().data), torch.IntTensor)
if torch.cuda.is_available():
self.assertIs(type(x.float().cuda().data), torch.cuda.FloatTensor)
self.assertIs(type(x.int().cuda().data), torch.cuda.IntTensor)
self.assertIs(type(x.int().cuda().cpu().data), torch.IntTensor)
            if torch.cuda.device_count() >= 2:
x2 = x.float().cuda(1)
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)
x2 = x.float().cuda()
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 0)
x2 = x2.cuda(1)
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)
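# Build a Variable holding num_indices distinct random indices in [0, max_indices),
# used as an index argument in the test definitions below.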
def index_variable(num_indices, max_indices):
index = torch.randperm(max_indices)[:num_indices].long()
return Variable(index, requires_grad=False)
L = 20
M = 10
S = 5
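# Each entry: (Function class, constructor args, inputs[, test name suffix]).
# Plain size tuples among the inputs are expanded into random double Variables
# (requires_grad=True) by create_input below; prebuilt tensors are wrapped as Variables.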
function_tests = [
(Add, (), ((M, M), (M, M)) ),
(Sub, (), ((M, M), (M, M)) ),
(Mul, (), ((M, M), (M, M)) ),
(Div, (), ((M, M), torch.rand(M, M) + 5e-2) ),
(Pow, (), (torch.rand(M, M) + 1e-3, torch.rand(M, M) + 0.1)),
(AddConstant, (3.14,), ((L, L),) ),
(SubConstant, (3.14,), ((L, L),) ),
(SubConstant, (3.14, True), ((L, L),), 'from_tensor' ),
(MulConstant, (3.14,), ((L, L),) ),
(DivConstant, (3.14, True), (torch.rand(L, L) + 1e-1,), 'by_tensor' ),
(PowConstant, (3.14,), (torch.rand(L, L),) ),
(Transpose, (0, 1), (torch.rand(L, L),) ),
(Transpose, (2, 0), (torch.rand(S, S, S),), '3d' ),
(Permute, (0, 4, 3, 5, 1, 2), ((1, 2, 3, 4, 5, 6),) ),
(Index, ((1, 2),), (torch.rand(S, S, S),) ),
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice' ),
(Index, ((slice(0, 3), 1),),(torch.rand(S, S, S),), 'slice_index' ),
(View, (S*S, S), (torch.rand(S, S, S),) ),
(Expand, (S, 5, S, 5), ((S, 1, S, 1),) ),
(Exp, (), (torch.rand(S, S, S),) ),
(Log, (), (torch.rand(S, S, S) + 1e-2,) ),
(Log1p, (), (torch.rand(S, S, S),) ),
(Tanh, (), ((S, S, S),) ),
(Sigmoid, (), ((S, S, S),) ),
(Sinh, (), ((S, S, S),) ),
(Cosh, (), ((S, S, S),) ),
(Abs, (), ((S, S, S),) ),
(Clamp, (0, 1), ((S, S, S),) ),
(Sqrt, (), (torch.rand(S, S, S) + 1e-4,) ),
(Sin, (), ((S, S, S),) ),
(Cos, (), ((S, S, S),) ),
(Tan, (), (torch.randn(S, S, S).clamp(-1, 1),) ),
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
(Atan, (), ((S, S, S),) ),
(Cinv, (), (torch.rand(S, S, S) + 0.1,) ),
(Cmax, (), ((S, S, S), (S, S, S)) ),
(Cmin, (), ((S, S, S), (S, S, S)) ),
(Round, (), ((S, S, S),) ),
(Sign, (), ((S, S, S),) ),
(Trunc, (), ((S, S, S),) ),
(Floor, (), ((S, S, S),) ),
(Ceil, (), ((S, S, S),) ),
(Frac, (), ((S, S, S),) ),
(Fmod, (1.5,), ((S, S, S),) ),
(Lerp, (0.2,), ((S, S, S), (S, S, S)) ),
(Rsqrt, (), (torch.rand(S, S, S) + 1e-2,) ),
(Remainder, (1.5,), ((S, S, S),) ),
(CmaxConstant, (0.5,), ((S, S, S),) ),
(CminConstant, (0.5,), ((S, S, S),) ),
(Mean, (), ((S, S, S),) ),
(Mean, (1,), ((S, S, S),), 'dim' ),
(Sum, (), ((S, S, S),) ),
(Sum, (1,), ((S, S, S),), 'dim' ),
(Prod, (), ((S, S, S),) ),
(Prod, (1,), ((S, S, S),), 'dim' ),
(Addmm, (), ((S, M), (S, S), (S, M)), ),
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef' ),
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)), ),
(Addbmm, (0.1, 0.4), ((S, M), (S, S, S), (S, S, M)), 'coef' ),
(Baddbmm, (), ((S, S, M), (S, S, S), (S, S, M)), ),
(Baddbmm, (0.1, 0.4), ((S, S, M), (S, S, S), (S, S, M)), 'coef' ),
(Addmv, (), ((S,), (S, M), (M,)), ),
(Addmv, (0.1, 0.4), ((S,), (S, M), (M,)), 'coef' ),
(Addr, (), ((S, M), (S,), (M,)), ),
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef' ),
(Dot, (), ((L,), (L,)), ),
(Max, (), ((S, S, S),), ),
(Min, (), ((S, S, S),), ),
(Max, (0,), ((S, S, S),), 'dim' ),
(Min, (0,), ((S, S, S),), 'dim' ),
(Mode, (0,), ((S, S, S),), ),
(Kthvalue, (2, 0), ((S, S, S),), ),
(Median, (0,), ((S, S, S),), ),
(Norm, (1.5,), (torch.rand(S, S, S),), '1.5' ),
(Norm, (), ((S, S, S),), '2' ),
(Norm, (3,), ((S, S, S),), '3' ),
(Norm, (1.5, 0), (torch.rand(S, S, S),), '1.5_dim' ),
(Norm, (2, 0), ((S, S, S),), '2_dim' ),
(Norm, (3, 0), ((S, S, S),), '3_dim' ),
(Addcmul, (), ((S, S), (S, S), (S, S)) ),
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale' ),
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 1e-2) ),
(Addcdiv, (0.6,), ((S, S), (S, S), torch.rand(S, S) + 1e-2), 'scale'),
(IndexAdd, (0,), ((S, S), index_variable(2, S), (2, S)) ),
(IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ),
(IndexFill, (0, 2), ((S, S), index_variable(2, S)) ),
(IndexSelect, (0,), ((S, S), index_variable(2, S)) ),
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S)) ),
(Resize, (S*S, S), ((S, S, S),) ),
(Diag, (), ((S, S),), '2d' ),
(Diag, (), ((S,),), '1d' ),
(Tril, (), ((S, S),) ),
(Tril, (2,), ((S, S),), 'idx' ),
(Triu, (), ((S, S),) ),
(Triu, (2,), ((S, S),), 'idx' ),
(Clone, (), ((S, M, S),) ),
(Squeeze, (), ((S, 1, M, 1),) ),
(Squeeze, (1,), ((S, 1, M, 1),), 'dim' ),
(Unsqueeze, (0,), ((S, M, S),), '0' ),
(Unsqueeze, (1,), ((S, M, S),), '1' ),
# (MaskedCopy, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False), (S, S),)),
(MaskedFill, (10,), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(MaskedSelect, (), ((S, S), Variable(torch.randn(S, S).gt(0), requires_grad=False))),
(Sort, (), ((S, M, S),) ),
(Sort, (1,), ((S, M, S),), 'dim' ),
(Sort, (1, True), ((S, M, S),), 'dim_desc' ),
(Topk, (3,), ((S, M, S),) ),
(Topk, (3, 1), ((S, M, S),), 'dim' ),
(Topk, (3, 1, True), ((S, M, S),), 'dim_desc' ),
(Topk, (3, 1, True, True), ((S, M, S),), 'dim_desc_sort' ),
]
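# Each entry: (Variable method name, size of self, call args[, test name suffix]).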
method_tests = [
('add', (S, S, S), ((S, S, S),) ),
('add', (S, S, S), (3.14,), 'constant' ),
('sub', (S, S, S), ((S, S, S),) ),
('sub', (S, S, S), (3.14,), 'constant' ),
('mul', (S, S, S), ((S, S, S),) ),
('mul', (S, S, S), (3.14,), 'constant' ),
('div', (S, S, S), ((S, S, S),) ),
('div', (S, S, S), (3.14,), 'constant' ),
('pow', (S, S, S), ((S, S, S),) ),
('pow', (S, S, S), (3.14,), 'constant' ),
('transpose', (1, 2, 3), (1, 2) ),
('t', (1, 2), () ),
('view', (S, S, S), (S*S, S), ),
('view_as', (S, S, S), ((S*S, S),) ),
('expand', (S, 1, S), (S, S, S) ),
('exp', (S, S, S), () ),
('log', (S, S, S), () ),
('log1p', (S, S, S), () ),
('tanh', (S, S, S), () ),
('sigmoid', (S, S, S), () ),
('sinh', (S, S, S), () ),
('cosh', (S, S, S), () ),
('abs', (S, S, S), () ),
('clamp', (S, S, S), (0, 1) ),
('sqrt', (S, S, S), () ),
('sin', (S, S, S), () ),
('cos', (S, S, S), () ),
('tan', (S, S, S), () ),
('asin', (S, S, S), () ),
('acos', (S, S, S), () ),
('atan', (S, S, S), () ),
('cinv', (S, S, S), () ),
('round', (S, S, S), () ),
('sign', (S, S, S), () ),
('trunc', (S, S, S), () ),
('floor', (S, S, S), () ),
('ceil', (S, S, S), () ),
('rsqrt', (S, S, S), () ),
('fmod', (S, S, S), (1.5,) ),
('remainder', (S, S, S), (1.5,) ),
('lerp', (S, S, S), ((S, S, S), 0.4) ),
('cmax', (S, S, S), ((S, S, S),) ),
('cmax', (S, S, S), (0.5,), 'constant' ),
('cmin', (S, S, S), ((S, S, S),) ),
('cmin', (S, S, S), (0.5,), 'constant' ),
('mean', (S, S, S), () ),
('mean', (S, S, S), (1,), 'dim' ),
('sum', (S, S, S), () ),
('sum', (S, S, S), (1,), 'dim' ),
('prod', (S, S, S), () ),
('prod', (S, S, S), (1,), 'dim' ),
('addmm', (S, M), ((S, S), (S, M)), ),
('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef' ),
('addbmm', (S, M), ((S, S, S), (S, S, M)), ),
('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ),
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), ),
('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef' ),
('addmv', (S,), ((S, M), (M,)), ),
('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef' ),
('addr', (S, M), ((S,), (M,)), ),
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef' ),
('dot', (L,), ((L,),), ),
('max', (S, S, S), () ),
('min', (S, S, S), () ),
('addcmul', (S, S), ((S, S), (S, S)) ),
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale' ),
('addcdiv', (S, S), ((S, S), (S, S)) ),
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale' ),
('norm', (S, S, S), (2,) ),
('norm', (S, S, S), (2, 1), 'dim' ),
('dist', (S, S, S), ((S, S, S),) ),
('dist', (S, S, S), ((S, S, S), 4), '4' ),
('index_select', (S, S, S), (0, index_variable(2, S)) ),
('cat', (1, S, S), ((Variable(torch.randn(2, S, S)), Variable(torch.randn(3, S, S))), 0)),
('diag', (M, M), (), '2d' ),
('diag', (M,), (), '1d' ),
('tril', (M, M), () ),
('triu', (M, M), () ),
('clone', (S, M, S), () ),
('permute', (1, 2, 3, 4), (0, 2, 3, 1) ),
('select', (S, S, S), (1, 2) ),
('narrow', (S, S, S), (1, 2, 2) ),
('squeeze', (S, 1, S, 1), () ),
('squeeze', (S, 1, S, 1), (1,), '1_dim' ),
('squeeze', (S, 1, S, 1), (2,), 'not_1_dim' ),
('unsqueeze', (S, S, S), (0,), 'first' ),
('unsqueeze', (S, S, S), (1,), 'middle' ),
('unsqueeze', (S, S, S), (3,), 'last' ),
('masked_select', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False),) ),
('masked_fill_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10) ),
('masked_copy_', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M)) ),
]
# TODO: mm, bmm, mv, ger
# TODO: max, min with dim (problem with indices)
# TODO: mode, median, sort, kthvalue, topk (problem with indices)
# TODO: indexAdd, indexCopy, indexFill
# TODO: resize, resize_as (tensors only have resize_ and resize_as_)
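# Turn each size tuple into a random double Variable that requires grad, wrap bare
# tensors as Variables, and pass all other arguments through unchanged.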
def create_input(call_args):
if not isinstance(call_args, tuple):
call_args = (call_args,)
def map_arg(arg):
if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
return Variable(torch.randn(*arg).double(), requires_grad=True)
elif torch.is_tensor(arg):
if isinstance(arg, torch.FloatTensor):
return Variable(arg.double(), requires_grad=True)
else:
return Variable(arg, requires_grad=True)
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
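# Recursively replace Variables with their underlying tensors.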
def unpack_variables(args):
if isinstance(args, Variable):
return args.data
elif isinstance(args, tuple):
return tuple(unpack_variables(elem) for elem in args)
else:
return args
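# Generated function tests whose in-place variant is skipped.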
ignore_inplace = set((
'test_DivConstant_by_tensor',
))
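# For every function_tests entry, generate a test that checks the analytical
# Jacobian (computed via backward) against a numerical estimate and, for
# InplaceFunction subclasses, that the in-place variant produces the same
# outputs and gradients as the out-of-place one.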
for test in function_tests:
cls, constructor_args, call_args = test[:3]
test_name = 'test_' + cls.__name__ + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, cls=cls, constructor_args=constructor_args,
call_args=call_args, test_name=test_name):
input = create_input(call_args)
output = cls(*constructor_args)(*input)
if not isinstance(output, tuple):
output = (output,)
for i, o in enumerate(output):
analytical = get_analytical_jacobian(input, o)
def fn(input):
tmp = cls(*constructor_args)(*input)
if not isinstance(tmp, tuple):
tmp = (tmp,)
return tmp[i].data
numerical = get_numerical_jacobian(fn, input, input)
self.assertLessEqual(
max(a.add(-1, n).abs().max() for a, n in zip(analytical, numerical)),
PRECISION
)
if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
inplace_input = deepcopy(input)
inplace_input_copy = tuple(i + 0 for i in inplace_input)
fn = cls(*constructor_args, inplace=True)
inplace_output = fn(*inplace_input_copy)
if not isinstance(inplace_output, tuple):
inplace_output = (inplace_output,)
self.assertEqual(inplace_output, output)
            # Check that the gradients match
for inp_i, i in zip(inplace_input, input):
if inp_i.grad is not None:
inp_i.grad.zero_()
if i.grad is not None:
i.grad.zero_()
for io, o in zip(inplace_output, output):
grad = torch.randn(*io.size()).double()
io.backward(grad)
o.backward(grad)
for inp_i, i in zip(inplace_input, input):
self.assertEqual(inp_i.grad, i.grad)
assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
setattr(TestAutograd, test_name, do_test)
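# For every method_tests entry, generate a test that calls the Variable method and
# the corresponding tensor method on identical data and checks that the results
# match; the in-place variant (name + '_') is also tried when it exists.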
for test in method_tests:
name, self_size, args = test[:3]
test_name = 'test_' + name + ('_' + test[3] if len(test) == 4 else '')
def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
def check(name):
self_variable = create_input((self_size,))[0]
args_variable = create_input(args)
self_tensor = deepcopy(self_variable.data)
args_tensor = deepcopy(unpack_variables(args_variable))
output_variable = getattr(self_variable, name)(*args_variable)
output_tensor = getattr(self_tensor, name)(*args_tensor)
if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
output_tensor = torch.DoubleTensor((output_tensor,))
self.assertEqual(unpack_variables(output_variable), output_tensor)
# TODO: check that both have changed after adding all inplace ops
check(name)
inplace_name = name + '_'
if hasattr(Variable(torch.ones(1)), inplace_name):
try:
check(inplace_name)
except Exception as e:
                if 'only supports scalar' not in e.args[0]:
raise
assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
setattr(TestAutograd, test_name, do_test)
if __name__ == '__main__':
unittest.main()