| import contextlib |
| import gc |
| import sys |
| import math |
| import torch |
| import unittest |
| import random |
| import warnings |
| from copy import deepcopy |
| from collections import OrderedDict |
| from itertools import product |
| from operator import mul |
| from functools import reduce |
| from torch.autograd.gradcheck import gradgradcheck, gradcheck |
| from torch.autograd.function import once_differentiable |
| from torch.autograd.profiler import profile |
| |
| from common import TestCase, run_tests, skipIfNoLapack |
| from torch.autograd import Variable, Function |
| from torch.autograd.function import InplaceFunction |
| |
| if sys.version_info[0] == 2: |
| import cPickle as pickle |
| else: |
| import pickle |
| |
| PRECISION = 1e-4 |
| |
| |
| class NoArgsClass(object): |
| def __iter__(self): |
| return self |
| |
| def __next__(self): |
| raise StopIteration() |
| next = __next__ # Python 2 compatibility |
| |
| NO_ARGS = NoArgsClass() |
| |
| |
| @contextlib.contextmanager |
| def backward_engine(engine): |
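| """Temporarily replace the Variable execution engine; the previous engine is restored on exit.""" |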
| _prev_engine = Variable._execution_engine |
| Variable._execution_engine = engine() |
| try: |
| yield |
| finally: |
| Variable._execution_engine = _prev_engine |
| |
| |
| def graph_desc(fn): |
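| """Return a nested string description of the autograd graph rooted at fn.""" |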
| if fn is None: |
| return 'None' |
| result = type(fn).__name__ + '(' |
| next_functions = fn.next_functions |
| for next_fn, _ in next_functions: |
| result += graph_desc(next_fn) |
| result += ', ' |
| if next_functions: |
| result = result[:-2] |
| return result + ')' |
| |
| |
| class TestAutograd(TestCase): |
| |
| def _function_test(self, cls): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| result = cls.apply(x, 2, y) |
| go = Variable(torch.ones(1), requires_grad=True) |
| result.sum().backward(go, create_graph=True) |
| |
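| # Both tested functions compute tensor1 + pyscalar * tensor2 + tensor1 * tensor2, |
| # so d/dx = 1 + y and d/dy = pyscalar + x (with pyscalar == 2 here). |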
| self.assertEqual(x.grad.data, y.data + torch.ones(5, 5)) |
| self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2) |
| self.assertIsNotNone(x.grad.grad_fn) |
| self.assertIsNotNone(y.grad.grad_fn) |
| |
| return x, y |
| |
| def test_function(self): |
| class MyFunction(Function): |
| |
| @staticmethod |
| def forward(ctx, tensor1, pyscalar, tensor2): |
| ctx.pyscalar = pyscalar |
| ctx.save_for_backward(tensor1, tensor2) |
| return tensor1 + pyscalar * tensor2 + tensor1 * tensor2 |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| var1, var2 = ctx.saved_variables |
| # NOTE: self is the test case here |
| self.assertIsInstance(var1, Variable) |
| self.assertIsInstance(var2, Variable) |
| self.assertIsInstance(grad_output, Variable) |
| return (grad_output + grad_output * var2, None, |
| grad_output * ctx.pyscalar + grad_output * var1) |
| |
| x, y = self._function_test(MyFunction) |
| |
| x_grad_desc = graph_desc(x.grad.grad_fn) |
| y_grad_desc = graph_desc(y.grad.grad_fn) |
| self.assertEqual( |
| x_grad_desc, |
| 'CloneBackward(AddBackward1(ExpandBackward(AccumulateGrad()), ' |
| 'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))') |
| self.assertEqual( |
| y_grad_desc, |
| 'CloneBackward(AddBackward1(MulBackward0(ExpandBackward(AccumulateGrad())), ' |
| 'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))') |
| |
| def test_once_differentiable(self): |
| class MyFunction(Function): |
| |
| @staticmethod |
| def forward(ctx, tensor1, pyscalar, tensor2): |
| ctx.pyscalar = pyscalar |
| ctx.save_for_backward(tensor1, tensor2) |
| return tensor1 + pyscalar * tensor2 + tensor1 * tensor2 |
| |
| @staticmethod |
| @once_differentiable |
| def backward(ctx, grad_output): |
| t1, t2 = ctx.saved_tensors |
| # NOTE: self is the test case here |
| self.assertTrue(torch.is_tensor(t1)) |
| self.assertTrue(torch.is_tensor(t2)) |
| self.assertTrue(torch.is_tensor(grad_output)) |
| return (grad_output + grad_output * t2, None, |
| grad_output * ctx.pyscalar + grad_output * t1) |
| |
| x, y = self._function_test(MyFunction) |
| self.assertEqual(graph_desc(x.grad.grad_fn), |
| 'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))') |
| self.assertEqual(graph_desc(y.grad.grad_fn), |
| 'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))') |
| |
| def test_function_returns_input(self): |
| class MyFunction(Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x |
| |
| @staticmethod |
| def backward(ctx, grad): |
| return grad * 2 |
| |
| v = Variable(torch.ones(1), requires_grad=True) |
| MyFunction.apply(v).backward() |
| self.assertEqual(v.grad.data.tolist(), [2]) |
| |
| v.grad.data.zero_() |
| MyFunction.apply(v.clone()).backward() |
| self.assertEqual(v.grad.data.tolist(), [2]) |
| |
| def test_legacy_function_none_grad(self): |
| class MyFunction(Function): |
| def forward(self, x): |
| return torch.zeros(2, 2, 2) |
| |
| def backward(self, grad_output): |
| return None |
| |
| shape = (2, 3) |
| v = Variable(torch.ones(shape), requires_grad=True) |
| y = v[0, 0].expand(3, 5).t().sum() |
| MyFunction()(y).sum().backward() |
| self.assertEqual(v.grad.data, torch.zeros(shape)) |
| |
| def test_accumulate_grad(self): |
| grad_output = Variable(torch.ones(5, 5)) |
| |
| def compute_grad(create_graph): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = x + 2 |
| y.backward(grad_output, retain_graph=True) |
| x_grad = x.grad |
| x_grad_clone = x.grad.clone() |
| y.backward(grad_output, create_graph=create_graph) |
| return x_grad, x_grad_clone |
| |
| # Accumulate in-place when create_graph is False |
| x_grad, x_grad_clone = compute_grad(create_graph=False) |
| self.assertEqual(x_grad, x_grad_clone * 2) |
| |
| # Accumulate out-of-place when create_graph is True: x.grad is replaced |
| # with a new Variable, so the previously captured reference is unchanged |
| x_grad, x_grad_clone = compute_grad(create_graph=True) |
| self.assertEqual(x_grad, x_grad_clone) |
| |
| def test_hessian_vector(self): |
| x = Variable(torch.randn(2, 2), requires_grad=True) |
| y = Variable(torch.randn(2, 2), requires_grad=True) |
| |
| z = x ** 2 + y * x + y ** 2 |
| z.backward(torch.ones(2, 2), create_graph=True) |
| |
| x_grad = 2 * x.data + y.data |
| y_grad = x.data + 2 * y.data |
| self.assertEqual(x.grad.data, x_grad) |
| self.assertEqual(y.grad.data, y_grad) |
| |
| grad_sum = 2 * x.grad + y.grad |
| grad_sum.backward(torch.ones(2, 2)) |
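| # grad_sum = 2 * (2x + y) + (x + 2y) = 5x + 4y, so the Hessian-vector |
| # products with a ones vector are 5 for x and 4 for y. |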
| x_hv = torch.ones(2, 2) * 5 |
| y_hv = torch.ones(2, 2) * 4 |
| self.assertEqual(x.grad.data, x_grad + x_hv) |
| self.assertEqual(y.grad.data, y_grad + y_hv) |
| |
| def test_grad(self): |
| x = Variable(torch.randn(2, 2), requires_grad=True) |
| y = Variable(torch.randn(2, 2), requires_grad=True) |
| z = x ** 2 + y * x + y ** 2 |
| z.backward(torch.ones(2, 2), create_graph=True) |
| |
| x_grad = 2 * x.data + y.data |
| y_grad = x.data + 2 * y.data |
| self.assertEqual(x.grad.data, x_grad) |
| self.assertEqual(y.grad.data, y_grad) |
| |
| grad_sum = 2 * x.grad + y.grad |
| x_hv = torch.autograd.grad( |
| outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)], |
| inputs=[x], create_graph=True, only_inputs=True) |
| expected_x_hv = torch.ones(2, 2) * 5 |
| expected_y_hv = torch.ones(2, 2) * 4 |
| |
| self.assertEqual(x_hv[0].data, expected_x_hv) |
| self.assertEqual(x.grad.data, x_grad) |
| self.assertEqual(y.grad.data, y_grad) |
| |
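| # With only_inputs=False, gradients for leaves outside `inputs` (here y) |
| # are accumulated into their .grad attributes. |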
| grad_sum = 2 * x.grad + y.grad |
| x_hv = torch.autograd.grad( |
| outputs=grad_sum, inputs=x, |
| grad_outputs=torch.ones(2, 2), |
| only_inputs=False) |
| |
| self.assertEqual(x_hv[0].data, expected_x_hv) |
| self.assertEqual(x.grad.data, x_grad) |
| self.assertEqual(y.grad.data, y_grad + expected_y_hv) |
| |
| def test_grad_nonleaf(self): |
| x_init = Variable(torch.randn(2, 2), requires_grad=True) |
| x = x_init |
| y = Variable(torch.randn(2, 2), requires_grad=True) |
| grad_output = torch.ones(2, 2) |
| |
| def fn(x): |
| return x ** 2 + y * x + y ** 2 |
| |
| for i in range(5): |
| grad_x, = torch.autograd.grad( |
| fn(x), x, grad_outputs=grad_output, create_graph=True) |
| |
| grad_x_expected = 2 * x.data + y.data |
| self.assertIsNone(y.grad) |
| self.assertIsNone(x.grad) |
| self.assertEqual(grad_x.data, grad_x_expected) |
| |
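| # Gradient-ascent step; x becomes a non-leaf, but autograd.grad can still |
| # differentiate with respect to it on the next iteration. |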
| x = x + 0.05 * grad_x |
| |
| val_init = fn(x_init).data.sum() |
| val_final = fn(x).data.sum() |
| self.assertGreater(val_final, val_init) |
| |
| x.backward(grad_output) |
| self.assertIsNotNone(y.grad) |
| self.assertIsNotNone(x_init.grad) |
| |
| def test_grad_nonleaf_many_outputs(self): |
| # This checks an edge case for function callbacks |
| # We want to capture two grads of a function, but can only |
| # register a single callback. |
| x = Variable(torch.randn(4, 2), requires_grad=True) |
| a, b = x.chunk(2) |
| |
| hook_called = [False] |
| |
| def hook(*grads): |
| hook_called[0] = True |
| |
| x.register_hook(hook) |
| |
| go = torch.randn(2, 2) |
| grad_a, grad_b = torch.autograd.grad( |
| (a + 2 * b), [a, b], grad_outputs=go, create_graph=True) |
| |
| self.assertEqual(grad_a.data, go) |
| self.assertEqual(grad_b.data, go * 2) |
| self.assertFalse(hook_called[0]) |
| self.assertIsNone(x.grad) |
| |
| def test_backward_badcalls(self): |
| x = Variable(torch.ones(1)) |
| with self.assertRaisesRegex(RuntimeError, 'does not require grad'): |
| x.backward() |
| |
| def test_grad_badcalls(self): |
| x = Variable(torch.ones(1)) |
| y = x ** 2 |
| with self.assertRaisesRegex(RuntimeError, 'does not require grad'): |
| torch.autograd.grad(x, y) |
| with self.assertRaisesRegex(RuntimeError, 'does not require grad'): |
| torch.autograd.grad(y, x) |
| |
| x = Variable(torch.ones(1), requires_grad=True) |
| y = x ** 2 |
| torch.autograd.grad(y, x) # this should succeed now |
| with self.assertRaisesRegex(RuntimeError, 'unreachable'): |
| torch.autograd.grad(x, y) |
| |
| def test_grad_unreachable(self): |
| x = Variable(torch.ones(1), requires_grad=True) |
| y = Variable(torch.ones(1), requires_grad=True) |
| # Make sure x and y have grad accumulators allocated |
| z = x * 2 |
| w = y * 2 |
| with self.assertRaisesRegex(RuntimeError, 'unreachable'): |
| torch.autograd.grad(x * 2, [x, y]) |
| |
| grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True) |
| self.assertEqual(grad_x, x * 2) |
| self.assertIsNone(grad_y) |
| |
| # This is slightly different from the case above, because z doesn't even |
| # have a grad accumulator allocated. |
| z = Variable(torch.ones(1), requires_grad=True) |
| grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True) |
| self.assertEqual(grad_x, x * 2) |
| self.assertIsNone(grad_z) |
| |
| def test_hooks(self): |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| y = Variable(torch.ones(5, 5) * 4, requires_grad=True) |
| |
| counter = [0] |
| |
| def bw_hook(inc, grad): |
| self.assertIsInstance(grad, Variable) |
| counter[0] += inc |
| |
| z = x ** 2 + x * 2 + x * y + y |
| x.register_hook(lambda *args: bw_hook(0, *args)) |
| test = z.register_hook(lambda *args: bw_hook(1, *args)) |
| z.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(counter[0], 1) |
| |
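| # Register a second hook on z; the next backward fires both, adding 1 + 2 = 3. |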
| test2 = z.register_hook(lambda *args: bw_hook(2, *args)) |
| z.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(counter[0], 4) |
| |
| test2.remove() |
| z.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(counter[0], 5) |
| |
| def bw_hook_modify(grad): |
| return grad.mul(2) |
| |
| test.remove() |
| z.register_hook(bw_hook_modify) |
| y.grad.data.zero_() |
| z.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(y.grad.data, (x.data + 1) * 2) |
| |
| y.register_hook(bw_hook_modify) |
| y.grad.data.zero_() |
| z.backward(torch.ones(5, 5)) |
| self.assertEqual(y.grad.data, (x.data + 1) * 4) |
| |
| def test_hooks_cpp(self): |
| # Tests hooks for autograd function implemented in C++ |
| bn = torch.nn.BatchNorm1d(5, affine=False) |
| bn.eval() |
| |
| counter = [0] |
| |
| def bw_hook(grad): |
| counter[0] += 1 |
| return grad * 2 |
| |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| z = bn(x) |
| z.register_hook(bw_hook) |
| z.sum().backward() |
| |
| self.assertEqual(counter[0], 1, 'bw_hook not called') |
| self.assertEqual(x.grad.data, torch.ones(5, 5) * 2) |
| |
| def test_hook_none(self): |
| # WARNING: this is a test for autograd internals. |
| # You should never have to use such things in your code. |
| class NoneGradientFunction(Function): |
| |
| def forward(self, x, y): |
| assert self.needs_input_grad[0] |
| assert not self.needs_input_grad[1] |
| return x, y |
| |
| def backward(self, grad_x, grad_y): |
| return grad_x, None |
| |
| fn = NoneGradientFunction() |
| was_called = [False] |
| |
| def hook(grad_input, grad_output): |
| self.assertIsInstance(grad_input, tuple) |
| self.assertIsInstance(grad_output, tuple) |
| self.assertIsNotNone(grad_input[0]) |
| self.assertIsNotNone(grad_input[1]) |
| self.assertIsNotNone(grad_output[0]) |
| self.assertIsNotNone(grad_output[1]) |
| was_called[0] = True |
| fn.register_hook(hook) |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5)) |
| sum(fn(x, y)).sum().backward() |
| self.assertTrue(was_called[0]) |
| |
| def test_retain_grad(self): |
| input = Variable(torch.rand(1, 3), requires_grad=True) |
| h1 = input * 3 |
| out = (h1 * h1).sum() |
| |
| # It should be possible to call retain_grad() multiple times |
| h1.retain_grad() |
| h1.retain_grad() |
| |
| # Gradient should be accumulated |
| out.backward(retain_graph=True) |
| self.assertEqual(h1.data * 2, h1.grad.data) |
| out.backward(retain_graph=True) |
| self.assertEqual(h1.data * 4, h1.grad.data) |
| |
| input.grad.data.zero_() |
| # It should be a no-op for leaves |
| input.retain_grad() |
| input.retain_grad() |
| out.backward() |
| self.assertEqual(input.data * 18, input.grad.data) |
| |
| def test_retain_grad_cycle(self): |
| import gc |
| import weakref |
| counter = [0] |
| refs = [None] |
| |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| |
| def run_test(): |
| y = x * 2 |
| y.retain_grad() |
| |
| def inc(*args): |
| counter[0] += 1 |
| refs[0] = weakref.ref(y, inc) |
| return y / 2 |
| |
| z = run_test() |
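| # After run_test returns, retain_grad() must not keep y alive through a |
| # reference cycle; the weakref callback should fire once y is collected. |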
| gc.collect() |
| self.assertIsNone(refs[0]()) |
| self.assertEqual(counter[0], 1) |
| z.sum().backward() |
| |
| def test_backward(self): |
| v_t = torch.randn(5, 5) |
| x_t = torch.randn(5, 5) |
| y_t = torch.rand(5, 5) + 0.1 |
| z_t = torch.randn(5, 5) |
| grad_output = torch.randn(5, 5) |
| v = Variable(v_t, requires_grad=True) |
| x = Variable(x_t, requires_grad=True) |
| y = Variable(y_t, requires_grad=True) |
| z = Variable(z_t, requires_grad=True) |
| |
| v.backward(grad_output) |
| self.assertEqual(v.grad.data, grad_output) |
| |
| a = x + (y * z) + 4 * z ** 2 * x / y |
| a.backward(grad_output) |
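| # a = x + y*z + 4*z^2*x/y, so da/dx = 1 + 4*z^2/y, |
| # da/dy = z - 4*x*z^2/y^2, and da/dz = y + 8*x*z/y. |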
| x_grad = 4 * z_t.pow(2) / y_t + 1 |
| y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2) |
| z_grad = 8 * x_t * z_t / y_t + y_t |
| self.assertEqual(x.grad.data, x_grad * grad_output) |
| self.assertEqual(y.grad.data, y_grad * grad_output) |
| self.assertEqual(z.grad.data, z_grad * grad_output) |
| |
| def test_sparse_backward(self): |
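| # FixedGradientFunction ignores the incoming gradient and always returns |
| # the gradient it was constructed with, making it easy to mix sparse and |
| # dense gradients for the same input. |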
| class FixedGradientFunction(Function): |
| |
| def __init__(self, grad): |
| self.grad = grad |
| |
| def forward(self, x): |
| return x |
| |
| def backward(self, grad_x): |
| return self.grad |
| |
| size = torch.Size([6, 3, 2]) |
| i1 = torch.LongTensor([ |
| [0, 3, 4], |
| [0, 2, 2], |
| ]) |
| v1 = torch.DoubleTensor([[1, 2], [4, 5], [7, 8]]) |
| sparse_grad1 = torch.sparse.DoubleTensor(i1, v1, size) |
| i2 = torch.LongTensor([ |
| [0, 1, 3, 4], |
| [0, 1, 2, 2], |
| ]) |
| v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]]) |
| sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size) |
| dense_grad = torch.rand(size).double() |
| sparse_fn1 = FixedGradientFunction(sparse_grad1) |
| sparse_fn2 = FixedGradientFunction(sparse_grad2) |
| dense_fn = FixedGradientFunction(dense_grad) |
| |
| # sparse first |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward() |
| self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2) |
| # dense first |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward() |
| self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2) |
| # sparse only |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| (sparse_fn1(x) + sparse_fn2(x)).sum().backward() |
| self.assertEqual(x.grad.data, sparse_grad1 + sparse_grad2) |
| |
| def test_multi_backward(self): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| q = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| a = Variable(torch.randn(5, 5), requires_grad=True) |
| b = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| q2 = q * 2 |
| z = x + y + q2 |
| c = a * b + q2 |
| grad_z = torch.randn(5, 5) |
| grad_c = torch.randn(5, 5) |
| torch.autograd.backward([z, c], [grad_z, grad_c]) |
| |
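| # q2 = q * 2 feeds both z and c, so q accumulates 2 * (grad_z + grad_c). |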
| self.assertEqual(x.grad.data, grad_z) |
| self.assertEqual(y.grad.data, grad_z) |
| self.assertEqual(a.grad.data, grad_c * b.data) |
| self.assertEqual(b.grad.data, grad_c * a.data) |
| self.assertEqual(q.grad.data, (grad_c + grad_z) * 2) |
| |
| def test_multi_backward_no_grad(self): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=False) |
| |
| z = x + y |
| q = y * 2 |
| |
| # NB: we currently raise an exception if any arguments to backward() |
| # have requires_grad=False and don't have a grad_fn. We may want to |
| # relax that check to a warning. |
| def call_backwards(): |
| torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)]) |
| self.assertRaises(RuntimeError, call_backwards) |
| |
| def test_dependent_backward(self): |
| x = Variable(torch.randn(10), requires_grad=True) |
| y = x ** 2 |
| z = y ** 3 |
| |
| go_y = torch.randn(10) |
| go_z = torch.randn(10) |
| torch.autograd.backward([y, z], [go_y, go_z]) |
| |
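| # y = x^2 and z = x^6, so dL/dx = 2*x*go_y + 6*x^5*go_z. |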
| xd = x.data |
| self.assertEqual(x.grad.data, 2 * xd * go_y + 6 * xd.pow(5) * go_z) |
| |
| def test_save_output_nr(self): |
| x = Variable(torch.randn(10), requires_grad=True) |
| |
| class MultiOutputFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x[:5], x[5:] |
| |
| @staticmethod |
| def backward(ctx, *grad): |
| return torch.cat(grad) |
| |
| a, b = MultiOutputFn.apply(x) |
| self.assertEqual(b.output_nr, 1) |
| |
| class TestFn(Function): |
| @staticmethod |
| def forward(ctx, b): |
| ctx.save_for_backward(b) |
| return b * 2 |
| |
| @staticmethod |
| def backward(ctx, grad_b): |
| b, = ctx.saved_variables |
| self.assertEqual(b.output_nr, 1) |
| |
| TestFn.apply(b).sum().backward() |
| |
| def test_no_grad(self): |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| y = Variable(torch.ones(5, 5) * 4) |
| with torch.no_grad(): |
| w = x + y |
| self.assertFalse(w.requires_grad) |
| self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) |
| self.assertIsNone(w.grad_fn) |
| |
| def test_indexing(self): |
| x = torch.arange(1, 17).view(4, 4) |
| y = Variable(x, requires_grad=True) |
| |
| def compare(x, y, idx, indexed_tensor, indexed_var): |
| indexed_var_t = indexed_var.data |
| if not torch.is_tensor(indexed_tensor): |
| indexed_var_t = indexed_var_t[0] |
| self.assertEqual(indexed_tensor, indexed_var_t) |
| |
| indexed_var.sum().backward() |
| expected_grad = torch.Tensor(x.size()).fill_(0) |
| expected_grad[idx] = 1 |
| self.assertEqual(y.grad.data, expected_grad) |
| |
| def check_index(x, y, idx): |
| if y.grad is not None: |
| y.grad.data.zero_() |
| indexed_tensor = x[idx] |
| indexed_var = y[idx] |
| compare(x, y, idx, indexed_tensor, indexed_var) |
| |
| check_index(x, y, 1) |
| check_index(x, y, (1, 1)) |
| check_index(x, y, slice(1, None)) |
| check_index(x, y, slice(None, 2)) |
| check_index(x, y, (slice(None, 2), 2)) |
| check_index(x, y, (slice(1, 2), 2)) |
| check_index(x, y, (1, slice(2, None))) |
| check_index(x, y, (slice(None, None), slice(2, None))) |
| check_index(x, y, torch.LongTensor([0, 2])) |
| check_index(x, y, torch.rand(4, 4).bernoulli().byte()) |
| check_index(x, y, (Ellipsis, slice(2, None))) |
| check_index(x, y, ([0], [0])) |
| check_index(x, y, ([1, 2, 3], [0])) |
| check_index(x, y, ([1, 2], [2, 1])) |
| check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]])) |
| check_index(x, y, ([slice(None), [2, 3]])) |
| check_index(x, y, ([[2, 3], slice(None)])) |
| |
| # advanced indexing, with fewer dims, or with ellipsis |
| check_index(x, y, ([0])) |
| check_index(x, y, ([0], )) |
| |
| x = torch.arange(1, 49).view(4, 3, 4) |
| y = Variable(x, requires_grad=True) |
| |
| check_index(x, y, (slice(None), [0], [0])) |
| check_index(x, y, ([0], [0], slice(None))) |
| check_index(x, y, (slice(None), [0, 1, 2], [0])) |
| check_index(x, y, ([0, 1, 2], [0], slice(None))) |
| check_index(x, y, (slice(None), [1, 2], [2, 1])) |
| check_index(x, y, ([1, 2], [2, 1], slice(None))) |
| check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]])) |
| check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None))) |
| check_index(x, y, (slice(None), slice(None), [2, 1])) |
| check_index(x, y, (slice(None), [2, 1], slice(None))) |
| check_index(x, y, ([2, 1], slice(None), slice(None))) |
| |
| # advanced indexing, with fewer dims, or with ellipsis |
| check_index(x, y, ([0], )) |
| check_index(x, y, ([0], slice(None))) |
| check_index(x, y, ([0], Ellipsis)) |
| check_index(x, y, ([1, 2], [0, 1])) |
| check_index(x, y, ([1, 2], [0, 1], Ellipsis)) |
| check_index(x, y, (Ellipsis, [1, 2], [0, 1])) |
| |
| # advanced indexing, with a tensor wrapped in a variable |
| z = torch.LongTensor([0, 1]) |
| zv = Variable(z, requires_grad=False) |
| seq = [z, Ellipsis] |
| seqv = [zv, Ellipsis] |
| |
| if y.grad is not None: |
| y.grad.data.zero_() |
| indexed_tensor = x[seq] |
| indexed_var = y[seqv] |
| compare(x, y, seq, indexed_tensor, indexed_var) |
| |
| def test_indexing_duplicates(self): |
| x = torch.arange(1, 17).view(4, 4) |
| y = Variable(x, requires_grad=True) |
| |
| idx = torch.LongTensor([1, 1, 3, 2, 1, 2]) |
| y[idx].sum().backward() |
| expected_grad = torch.zeros(4, 4) |
| for i in idx: |
| expected_grad[i] += 1 |
| self.assertEqual(y.grad.data, expected_grad) |
| |
| # with advanced indexing |
| x = torch.arange(1, 17).view(4, 4) |
| y = Variable(x, requires_grad=True) |
| |
| idx = [[1, 1, 3, 2, 1, 2], [0]] |
| y[idx].sum().backward() |
| expected_grad = torch.zeros(4, 4) |
| for i in idx[0]: |
| for j in idx[1]: |
| expected_grad[i][j] += 1 |
| |
| self.assertEqual(y.grad.data, expected_grad) |
| |
| x = torch.arange(1, 17).view(4, 4) |
| y = Variable(x, requires_grad=True) |
| idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]] |
| y[idx].sum().backward() |
| expected_grad = torch.Tensor([[0, 2, 0, 0], |
| [1, 0, 0, 0], |
| [0, 1, 0, 0], |
| [0, 0, 0, 0]]) |
| self.assertEqual(y.grad.data, expected_grad) |
| |
| x = torch.arange(1, 65).view(4, 4, 4) |
| y = Variable(x, requires_grad=True) |
| |
| idx = [[1, 1, 1], slice(None), slice(None)] |
| y[idx].sum().backward() |
| expected_grad = torch.Tensor(4, 4, 4).zero_() |
| expected_grad[1].fill_(3) |
| self.assertEqual(y.grad.data, expected_grad) |
| |
| def test_volatile_deprecated(self): |
| v = torch.autograd.Variable(torch.randn(3, 3)) |
| with warnings.catch_warnings(record=True) as w: |
| self.assertFalse(v.volatile) |
| self.assertIn('volatile', str(w[0].message)) |
| |
| def test_requires_grad(self): |
| x = Variable(torch.randn(5, 5)) |
| y = Variable(torch.randn(5, 5)) |
| z = Variable(torch.randn(5, 5), requires_grad=True) |
| a = x + y |
| self.assertFalse(a.requires_grad) |
| b = a + z |
| self.assertTrue(b.requires_grad) |
| |
| def error(): |
| raise RuntimeError |
| # Make sure backward isn't called on these |
| a._backward_hooks = OrderedDict() |
| x._backward_hooks = OrderedDict() |
| y._backward_hooks = OrderedDict() |
| a._backward_hooks['test'] = error |
| x._backward_hooks['test'] = error |
| y._backward_hooks['test'] = error |
| b.backward(torch.ones(5, 5)) |
| |
| def test_requires_grad_inplace(self): |
| a = Variable(torch.randn(5, 5)) |
| b = Variable(torch.randn(5, 5), requires_grad=True) |
| a += b |
| self.assertTrue(a.requires_grad) |
| |
| # non-leaf Variable |
| a = Variable(torch.randn(5, 5)) + 0 |
| b = Variable(torch.randn(5, 5), requires_grad=True) |
| a += b |
| self.assertTrue(a.requires_grad) |
| |
| def test_no_requires_grad_inplace(self): |
| # basic case, should be able to modify inplace while requires_grad is False |
| a = Variable(torch.randn(2, 3)) |
| a.add_(5) |
| a.requires_grad = True |
| a.sum().backward() |
| self.assertEqual(a.grad.data, torch.ones(2, 3)) |
| |
| # same but with a view |
| a = Variable(torch.randn(2, 3)) |
| b = a[:] |
| b.add_(5) |
| a.requires_grad = True |
| a.sum().backward() |
| self.assertEqual(a.grad.data, torch.ones(2, 3)) |
| |
| # should fail if requires_grad = True when we modify inplace |
| a = Variable(torch.randn(2, 3)) |
| b = a[:] |
| a.requires_grad = True |
| with self.assertRaises(RuntimeError): |
| a.add_(5) |
| with self.assertRaises(RuntimeError): |
| b.add_(5) |
| |
| def test_grad_assignment(self): |
| x = Variable(torch.randn(5, 5)) |
| a = Variable(torch.randn(2, 2)) # size mismatch |
| b = Variable(torch.randn(5, 5).long()) # type mismatch |
| |
| with self.assertRaises(RuntimeError): |
| x.grad = Variable(torch.randn(2, 2)) |
| with self.assertRaises(RuntimeError): |
| x.grad = Variable(torch.randn(5, 5).long()) |
| with self.assertRaises(RuntimeError): |
| x.grad = x |
| |
| if not torch.cuda.is_available(): |
| raise unittest.SkipTest("CUDA not available") |
| with self.assertRaises(RuntimeError): |
| x.grad = Variable(torch.randn(5, 5).cuda()) |
| |
| if torch.cuda.device_count() < 2: |
| raise unittest.SkipTest("At least 2 CUDA devices needed") |
| x = Variable(torch.randn(5, 5).cuda(0)) |
| with self.assertRaises(RuntimeError): |
| x.grad = Variable(torch.randn(5, 5).cuda(1)) |
| |
| def test_duplicate_backward_root(self): |
| a = Variable(torch.randn(5, 5), requires_grad=True) |
| b = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| x = a * b |
| grad_output = x.data.clone().normal_() |
| torch.autograd.backward([x, x], [grad_output, grad_output]) |
| |
| self.assertEqual(a.grad.data, b.data * grad_output * 2) |
| self.assertEqual(b.grad.data, a.data * grad_output * 2) |
| |
| def test_backward_no_grad(self): |
| a = Variable(torch.randn(5, 5), requires_grad=True) |
| b = a + 2 |
| with self.assertRaises(RuntimeError): |
| torch.autograd.backward([b], [None]) |
| |
| def test_next_functions(self): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| a = x + y |
| self.assertIsNotNone(a.grad_fn) |
| next_functions = a.grad_fn.next_functions |
| self.assertEqual(len(next_functions), 2) |
| self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad) |
| self.assertEqual(next_functions[0][1], 0) |
| self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad) |
| self.assertEqual(next_functions[1][1], 0) |
| |
| b = a + 5 |
| next_functions = b.grad_fn.next_functions |
| self.assertEqual(len(next_functions), 1) |
| self.assertIs(next_functions[0][0], a.grad_fn) |
| |
| def test_inplace(self): |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| y = Variable(torch.ones(5, 5) * 4, requires_grad=True) |
| |
| z = x * y |
| q = z + y |
| w = z * y |
| z.add_(2) |
| # Add doesn't need its inputs to do backward, so it shouldn't raise |
| q.backward(torch.ones(5, 5), retain_graph=True) |
| # Mul saves both inputs in forward, so it should raise |
| self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) |
| |
| z = x * y |
| q = z * y |
| r = z + y |
| w = z.add_(y) |
| # w is the last expression, so this should succeed |
| w.backward(torch.ones(5, 5), retain_graph=True) |
| # r doesn't use the modified value in backward, so it should succeed |
| r.backward(torch.ones(5, 5), retain_graph=True) |
| # q uses dirty z, so it should raise |
| self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5))) |
| |
| x.grad.data.zero_() |
| m = x / 2 |
| z = m + y / 8 |
| q = z * y |
| r = z + y |
| prev_version = z._version |
| w = z.exp_() |
| self.assertNotEqual(z._version, prev_version) |
| r.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(x.grad.data, torch.ones(5, 5) / 2) |
| w.backward(torch.ones(5, 5), retain_graph=True) |
| self.assertEqual(x.grad.data, torch.Tensor(5, 5).fill_((1 + math.e) / 2)) |
| self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5))) |
| |
| leaf = Variable(torch.ones(5, 5), requires_grad=True) |
| x = leaf.clone() |
| x.add_(10) |
| self.assertEqual(x.data, torch.ones(5, 5) * 11) |
| # x should still be usable |
| y = x + 2 |
| y.backward(torch.ones(5, 5)) |
| self.assertEqual(leaf.grad.data, torch.ones(5, 5)) |
| z = x * y |
| x.add_(2) |
| self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5))) |
| |
| def test_mark_non_differentiable(self): |
| class MyFunction(Function): |
| @staticmethod |
| def forward(ctx, input): |
| output = input > 0 |
| ctx.mark_non_differentiable(output) |
| return output |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return (grad_output * 0).type(torch.DoubleTensor) |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| mask = MyFunction.apply(x) |
| self.assertFalse(mask.requires_grad) |
| y = x.masked_fill(mask, 0) |
| y.sum().backward() |
| |
| def test_mark_non_differentiable_mixed(self): |
| class MyFunction(Function): |
| @staticmethod |
| def forward(ctx, input): |
| a = input + 1 |
| b = input + 2 |
| ctx.mark_non_differentiable(a) |
| return a, b |
| |
| @staticmethod |
| def backward(ctx, grad_a, grad_b): |
| self.assertTrue((grad_a == 0).all()) |
| self.assertTrue((grad_b == 1).all()) |
| return grad_b |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| a, b = MyFunction.apply(x) |
| self.assertFalse(a.requires_grad) |
| self.assertTrue(b.requires_grad) |
| b.sum().backward() |
| self.assertEqual(x.grad.data, torch.ones(5, 5)) |
| |
| def test_mark_non_differentiable_none(self): |
| # This used to segfault because MyFunction would send back null |
| # gradients to MulBackward, which is implemented in C++. C++-implemented |
| # functions expect incoming grad_outputs to be non-null. |
| class MyFunction(Function): |
| @staticmethod |
| def forward(ctx, input): |
| output = input.clone() |
| ctx.mark_non_differentiable(output) |
| return output |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return None |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| r = MyFunction.apply(x * x) |
| (r * x).sum().backward() |
| |
| def test_resize(self): |
| x = Variable(torch.ones(2, 3)) |
| self.assertTrue(x.resize(3, 2).size() == (3, 2)) |
| |
| def _test_setitem(self, size, index): |
| x = Variable(torch.ones(*size), requires_grad=True) |
| y = x + 2 |
| y_version = y._version |
| y[index] = 2 |
| self.assertNotEqual(y._version, y_version) |
| y.backward(torch.ones(*size)) |
| expected_grad = torch.ones(*size) |
| if isinstance(index, Variable): |
| index = index.data |
| expected_grad[index] = 0 |
| self.assertEqual(x.grad.data, expected_grad) |
| |
| def _test_setitem_tensor(self, size, index): |
| x = Variable(torch.ones(*size), requires_grad=True) |
| y = x + 2 |
| y_version = y._version |
| value = Variable(torch.Tensor(x[index].size()).fill_(7), requires_grad=True) |
| y[index] = value |
| self.assertNotEqual(y._version, y_version) |
| y.backward(torch.ones(*size)) |
| expected_grad_input = torch.ones(*size) |
| |
| # unwrap Variables to plain Tensors before indexing for the comparison, |
| # whether the index is a top-level Variable or inside a sequence |
| if isinstance(index, Variable): |
| index = index.data |
| elif isinstance(index, list): |
| novars = [] |
| for i in index: |
| if isinstance(i, Variable): |
| novars.append(i.data) |
| else: |
| novars.append(i) |
| index = novars |
| |
| expected_grad_input[index] = 0 |
| self.assertEqual(x.grad.data, expected_grad_input) |
| self.assertEqual(value.grad.data, torch.ones(value.size())) |
| |
| # case when x broadcasts to the shape of y[1] |
| x = Variable(torch.randn(4), requires_grad=True) |
| y = Variable(torch.zeros(2, 3, 4)) |
| y[1] = x |
| y.backward(torch.randn(2, 3, 4)) |
| self.assertEqual(x.size(), x.grad.size()) |
| |
| def test_setitem(self): |
| self._test_setitem((5, 5), 1) |
| self._test_setitem((5,), 1) |
| self._test_setitem((1,), 0) |
| self._test_setitem((10,), [[0, 4, 2]]) |
| self._test_setitem((5, 5), [[0, 4], [2, 2]]) |
| self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]]) |
| self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)]) |
| self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)]) |
| self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]]) |
| self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)]) |
| self._test_setitem_tensor((5, 5), 3) |
| self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]]) |
| self._test_setitem_tensor((5,), 3) |
| self._test_setitem_tensor((5,), [[0, 1, 2, 3]]) |
| self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]]) |
| self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)]) |
| self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)]) |
| self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]]) |
| self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)]) |
| self._test_setitem_tensor((5, 5, 5), [Variable(torch.LongTensor([1, |
| 3]), requires_grad=False), [2, 4], slice(None)]) |
| |
| def test_setitem_mask(self): |
| mask = torch.ByteTensor(5, 5).bernoulli_() |
| self._test_setitem((5, 5), Variable(mask)) |
| self._test_setitem((5,), Variable(mask[0])) |
| self._test_setitem((1,), Variable(mask[0, 0:1])) |
| self._test_setitem_tensor((5, 5), Variable(mask)) |
| self._test_setitem_tensor((5,), Variable(mask[0])) |
| |
| def test_select_sum(self): |
| # both select and sum return Scalars in ATen; ensure they work together. |
| x = Variable(torch.randn(10), requires_grad=True) |
| |
| def func(x): |
| return x.select(0, 1).sum() |
| |
| gradcheck(func, [x]) |
| gradgradcheck(func, [x]) |
| |
| def test_stack(self): |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| y = Variable(torch.randn(10, 10), requires_grad=True) |
| z = Variable(torch.randn(10, 10), requires_grad=True) |
| stacked = torch.stack([x, y, z], 0) |
| grad = torch.randn(3, 10, 10) |
| stacked.backward(grad) |
| self.assertEqual(x.grad.data, grad[0]) |
| self.assertEqual(y.grad.data, grad[1]) |
| self.assertEqual(z.grad.data, grad[2]) |
| |
| def test_put(self): |
| root = Variable(torch.randn(4, 5), requires_grad=True) |
| values = Variable(torch.randn(6), requires_grad=True) |
| idx = Variable(torch.LongTensor([1, 2, 3, -1, -2, -3])) |
| |
| def func(root, values): |
| x = root.clone() |
| x.put_(idx, values) |
| return x |
| |
| gradcheck(func, [root, values]) |
| gradgradcheck(func, [root, values]) |
| |
| def test_put_accumulate(self): |
| root = Variable(torch.randn(4, 5), requires_grad=True) |
| values = Variable(torch.randn(6), requires_grad=True) |
| idx = Variable(torch.LongTensor([1, 2, 3, 1, 2, 3])) |
| |
| def func(root, values): |
| x = root.clone() |
| x.put_(idx, values, accumulate=True) |
| return x |
| |
| gradcheck(func, [root, values]) |
| gradgradcheck(func, [root, values]) |
| |
| def test_fill(self): |
| root = Variable(torch.randn(4, 5), requires_grad=True) |
| |
| def func(root): |
| x = root.clone() |
| x.fill_(2) |
| return x |
| |
| gradcheck(func, [root]) |
| gradgradcheck(func, [root]) |
| |
| def test_unused_output(self): |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| outputs = x.chunk(5) |
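| # chunk(5) along dim 0 splits the 10x10 tensor into 2-row chunks, so |
| # outputs[2] covers rows 4:6 and only those rows receive gradient. |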
| o = outputs[2] |
| o = o * 4 + 2 |
| o.sum().backward() |
| expected_grad = torch.zeros(10, 10) |
| expected_grad[4:6] = 4 |
| self.assertEqual(x.grad.data, expected_grad) |
| |
| x.grad.data.zero_() |
| grad_output = torch.randn(2, 10) |
| outputs = x.chunk(5) |
| outputs[0].backward(grad_output) |
| expected_grad = torch.zeros(10, 10) |
| expected_grad[:2] = grad_output |
| self.assertEqual(x.grad.data, expected_grad) |
| |
| def test_gc_in_destructor(self): |
| """ |
| Previously, if a Function destructor triggered a garbage collection, |
| the Variable's tp_dealloc handler would get called twice, leading to a |
| segfault. |
| """ |
| class CollectOnDelete(Function): |
| |
| def __del__(self): |
| gc.collect() |
| |
| for i in range(10): |
| Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete()) |
| |
| @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU") |
| def test_unused_output_gpu(self): |
| from torch.nn.parallel._functions import Broadcast |
| x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True) |
| outputs = Broadcast.apply(list(range(torch.cuda.device_count())), x) |
| y = outputs[-1] * 2 |
| y.sum().backward() |
| self.assertEqual(x.grad.data, torch.ones(5, 5) * 2) |
| |
| @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU") |
| def test_backward_device(self): |
| # check that current device matches the variable's device |
| device = [None] |
| |
| class Identity(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| device[0] = torch.cuda.current_device() |
| return grad_output.clone() |
| |
| v = Variable(torch.randn(1).cuda(1), requires_grad=True) |
| Identity.apply(v).backward() |
| self.assertEqual(device[0], 1) |
| |
| def test_detach(self): |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| y = x + 2 |
| y = y.detach() |
| z = y * 4 + 2 |
| self.assertFalse(y.requires_grad) |
| self.assertFalse(z.requires_grad) |
| |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| y = x * 2 |
| y = y.detach() |
| self.assertFalse(y.requires_grad) |
| self.assertIsNone(y.grad_fn) |
| z = x + y |
| z.sum().backward() |
| # This is an incorrect gradient, but we assume that's what the user |
| # wanted. detach() is an advanced option. |
| self.assertEqual(x.grad.data, torch.ones(10, 10)) |
| |
| # in-place detach |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| y = Variable(torch.randn(10, 10), requires_grad=True) |
| a = x * 2 |
| (y + a).sum().backward(retain_graph=True) |
| a.detach_() |
| self.assertFalse(a.requires_grad) |
| (y + a).sum().backward() # this won't backprop to x |
| self.assertEqual(x.grad.data, torch.ones(10, 10) * 2) |
| self.assertEqual(y.grad.data, torch.ones(10, 10) * 2) |
| |
| # in-place detach on a view raises an exception |
| view = x.narrow(0, 1, 4) |
| self.assertRaisesRegex(RuntimeError, 'view', lambda: view.detach_()) |
| |
| def test_detach_base(self): |
| "detaching base does not detach view" |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| view = x.narrow(0, 1, 4) |
| x.detach_() |
| self.assertFalse(x.requires_grad) |
| self.assertTrue(view.requires_grad) |
| self.assertIsNotNone(view.grad_fn) |
| self.assertIs(view._base, x) |
| |
| def _test_type_conversion_backward(self, t): |
| fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True) |
| fvar.double().sum().backward() |
| self.assertEqual(fvar.grad, torch.ones_like(fvar)) |
| self.assertEqual(type(fvar.grad.data), type(fvar.data)) |
| dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True) |
| dvar.float().sum().backward() |
| self.assertEqual(dvar.grad, torch.ones_like(dvar)) |
| self.assertEqual(type(dvar.grad.data), type(dvar.data)) |
| |
| def test_type_conversions(self): |
| x = Variable(torch.randn(5, 5)) |
| self.assertIs(type(x.float().data), torch.FloatTensor) |
| self.assertIs(type(x.int().data), torch.IntTensor) |
| if torch.cuda.is_available(): |
| self.assertIs(type(x.float().cuda().data), torch.cuda.FloatTensor) |
| self.assertIs(type(x.int().cuda().data), torch.cuda.IntTensor) |
| self.assertIs(type(x.int().cuda().cpu().data), torch.IntTensor) |
| if torch.cuda.device_count() >= 2: |
| x2 = x.float().cuda(1) |
| self.assertIs(type(x2.data), torch.cuda.FloatTensor) |
| self.assertIs(x2.get_device(), 1) |
| x2 = x.float().cuda() |
| self.assertIs(type(x2.data), torch.cuda.FloatTensor) |
| self.assertIs(x2.get_device(), 0) |
| x2 = x2.cuda(1) |
| self.assertIs(type(x2.data), torch.cuda.FloatTensor) |
| self.assertIs(x2.get_device(), 1) |
| y = Variable(torch.randn(5).cuda(1), requires_grad=True) |
| y.cpu().sum().backward() |
| self.assertIs(y.grad.get_device(), 1) |
| self.assertIs(y.long().data.get_device(), 1) |
| |
| for t in [torch.DoubleTensor, torch.FloatTensor, torch.IntTensor, torch.ByteTensor]: |
| for y_var in (True, False): |
| y = torch.randn(5, 5).type(t) |
| y = Variable(y) if y_var else y |
| self.assertIs(type(x.type(t).data), t) |
| self.assertIs(type(x.type_as(y).data), t) |
| if torch.cuda.is_available(): |
| for x_cuda in (True, False): |
| for y_cuda in (True, False): |
| x_c = x.cuda() if x_cuda else x |
| y_c = y.cuda() if y_cuda else y |
| y_type = type(y_c.data) if y_var else type(y_c) |
| y_typestr = ('torch.cuda.' if y_cuda else 'torch.') + y_type.__name__ |
| self.assertIs(y_type, type(x_c.type(y_typestr).data)) |
| self.assertIs(type(y_c.data) if y_var else type(y_c), type(x_c.type_as(y_c).data)) |
| |
| self._test_type_conversion_backward(lambda x: x) |
| if torch.cuda.is_available(): |
| self._test_type_conversion_backward(lambda x: x.cuda()) |
| if torch.cuda.device_count() >= 2: |
| # one of these has to be the non-default device |
| self._test_type_conversion_backward(lambda x: x.cuda(0)) |
| self._test_type_conversion_backward(lambda x: x.cuda(1)) |
| |
| def _test_pyscalar_conversions(self, t, integral_conv): |
| # integral -> integral |
| l = Variable(t(torch.zeros(1, 1, 1).long())) |
| pyscalar = -12345 |
| l[0] = pyscalar |
| self.assertEqual(integral_conv(l), pyscalar) |
| |
| # floating point -> floating point |
| f = Variable(t(torch.randn(1, 1))) |
| pyscalar = -12345.1 |
| f[0] = pyscalar |
| self.assertEqual(float(f), pyscalar) |
| f[0] = float('nan') |
| self.assertTrue(math.isnan(float(f))) |
| f[0] = float('inf') |
| self.assertEqual(float(f), float('inf'), allow_inf=True) |
| f[0] = float('-inf') |
| self.assertEqual(float(f), float('-inf'), allow_inf=True) |
| |
| # integral -> floating point |
| # check we can convert something that loses precision |
| pyscalar = 1234567890123456789 |
| self.assertNotEqual(pyscalar, integral_conv(float(pyscalar))) |
| l[0] = pyscalar |
| self.assertEqual(float(l), float(pyscalar)) |
| |
| # floating point -> integral |
| f[0] = float('nan') |
| self.assertRaises(ValueError, lambda: integral_conv(f[0])) |
| f[0] = float('inf') |
| self.assertRaises(OverflowError, lambda: integral_conv(f[0])) |
| f[0] = float('-inf') |
| self.assertRaises(OverflowError, lambda: integral_conv(f[0])) |
| f[0] = sys.float_info.max |
| self.assertEqual(integral_conv(f), sys.float_info.max) |
| |
| # bool, nonzero |
| def test_nonzero(tensor, value, expected): |
| tensor[0] = value |
| self.assertEqual(expected, bool(tensor)) |
| self.assertEqual(expected, True if tensor else False) |
| |
| test_nonzero(l, 0, False) |
| test_nonzero(l, -2, True) |
| test_nonzero(f, 0.0, False) |
| test_nonzero(f, sys.float_info.min, True) |
| test_nonzero(f, float('nan'), bool(float('nan'))) |
| test_nonzero(f, float('inf'), bool(float('inf'))) |
| test_nonzero(f, float('-inf'), bool(float('-inf'))) |
| |
| def test_pyscalar_conversions(self): |
| self._test_pyscalar_conversions(lambda x: x, lambda x: int(x)) |
| if sys.version_info[0] == 2: |
| self._test_pyscalar_conversions(lambda x: x, lambda x: long(x)) |
| if torch.cuda.is_available(): |
| self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: int(x)) |
| if sys.version_info[0] == 2: |
| self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: long(x)) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") |
| def test_pin_memory(self): |
| x = Variable(torch.randn(2, 2), requires_grad=True) |
| self.assertEqual(x, x.pin_memory()) |
| self.assertIsNot(x, x.pin_memory()) |
| self.assertTrue(x.pin_memory().requires_grad) |
| gradcheck(lambda x: x.pin_memory(), [x]) |
| gradgradcheck(lambda x: x.pin_memory(), [x]) |
| |
| def test_isolated_node(self): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| a = x + y |
| b = torch.max(a, 1, True)[1].repeat(1, 5).double() |
| o = (b + a).sum() |
| o.backward() |
| |
| def test_shape(self): |
| x = Variable(torch.randn(3, 4)) |
| self.assertEqual(2, len(x.shape)) |
| self.assertEqual(x.shape[0], 3) |
| self.assertEqual(x.shape[1], 4) |
| |
| def test_numpy_requires_grad(self): |
| x = Variable(torch.randn(2, 2), requires_grad=True) |
| self.assertRaisesRegex(RuntimeError, 'requires grad', lambda: x.numpy()) |
| |
| def test_return_leaf(self): |
| class Identity(Function): |
| |
| def forward(self, a, b): |
| return a, a + b |
| |
| def backward(self, grad_a, grad_b): |
| return grad_a + grad_b, grad_b |
| |
| class Inplace(InplaceFunction): |
| |
| def forward(self, a, b): |
| self.mark_dirty(a) |
| return a.add_(b), b + 2 |
| |
| def backward(self, grad_a, grad_b): |
| return grad_a, grad_a + grad_b |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| q, p = Identity()(x, y) |
| # Make sure hooks only receive grad from usage of q, not x. |
| q.register_hook( |
| lambda grad: self.assertEqual(grad.data, torch.ones(5, 5))) |
| (q + p + x).sum().backward() |
| self.assertEqual(x.grad.data, torch.ones(5, 5) * 3) |
| self.assertEqual(y.grad.data, torch.ones(5, 5)) |
| del q, p # these need to be freed, or next part will raise an error |
| |
| def test_return_leaf_inplace(self): |
| class Inplace(InplaceFunction): |
| |
| def forward(self, a, b): |
| self.mark_dirty(a) |
| return a.add_(b), b + 2 |
| |
| def backward(self, grad_a, grad_b): |
| return grad_a, grad_a + grad_b |
| |
| x = Variable(torch.randn(5, 5)) |
| y = Variable(torch.randn(5, 5), requires_grad=True) |
| |
| fn = Inplace(True) |
| q, p = fn(x, y) |
| self.assertIs(q, x) |
| self.assertIs(q.grad_fn, fn) |
| self.assertTrue(q.requires_grad) |
| q.sum().backward() |
| self.assertEqual(y.grad.data, torch.ones(5, 5)) |
| |
| def test_leaf_assignment(self): |
| x = Variable(torch.randn(5, 5)) |
| y = Variable(torch.randn(5), requires_grad=True) |
| z = Variable(torch.randn(5), requires_grad=True) |
| |
| x[0] = y |
| x[1] = 2 * z |
| self.assertTrue(x.requires_grad) |
| self.assertIsNot(x.grad_fn, None) |
| x.sum().backward() |
| self.assertEqual(y.grad.data, torch.ones(5)) |
| self.assertEqual(z.grad.data, torch.ones(5) * 2) |
| |
| def test_no_grad_assignment(self): |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = Variable(torch.randn(5)) |
| with torch.no_grad(): |
| x[0] = y |
| |
| self.assertTrue(x.requires_grad) |
| self.assertIsNone(x.grad_fn) |
| |
| def test_no_grad_modifies_version(self): |
| x = Variable(torch.randn(5), requires_grad=True) |
| y = Variable(torch.randn(5), requires_grad=True) |
| z = (x * y).sum() |
| with torch.no_grad(): |
| x *= 2 |
| self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation', |
| lambda: z.backward()) |
| |
| def test_backward_copy(self): |
| # This test checks the backward engine for a very subtle bug that appeared |
| # in one of the initial versions of autograd. Gradient tensors were |
| # simply stored in lists while the function waited for all its gradients |
| # to be computed. However, sometimes an output was used multiple times, |
| # so the gradients needed to be summed. The engine used to keep a need_copy |
| # set of tensors that would need a clone upon the next addition, and removed |
| # them from the set as soon as the clone was performed. However, this |
| # could lead to incorrect results if the same gradient tensor was |
| # buffered in three places in the graph: |
| # 1. When accumulating gradients in one of these places it was cloned |
| # and removed from the need_copy set. |
| # 2. When accumulating in the second place, it wasn't in the need_copy set, |
| # so the gradients were simply accumulated in-place (which already |
| # modified the grad in the 3rd place). |
| # 3. When accumulating in the third place, it wasn't in the need_copy set |
| # either, so the incoming gradient was summed in-place, yielding |
| # incorrect results in all functions except the first one. |
| x = Variable(torch.ones(5, 5), requires_grad=True) |
| y = Variable(torch.ones(5, 5), requires_grad=True) |
| # Simulate that we're in the middle of the graph |
| a = x + 2 |
| b = y + 2 |
| c = x + 2 |
| # This op will just return grad_output two times in backward |
| add1 = a + b |
| add2 = add1 + c |
| # Simulate a long branch, so grad_output will get buffered. |
| for i in range(4): |
| a = a * 2 |
| b = b * 2 |
| c = c * 2 |
| branch = a + b + c |
| out = add2 + branch |
| # expected gradients are: |
| # for x: 34 (16 from final a, 16 from final c, 2 from add2) |
| # for y: 17 (16 from final b, 1 from add2) |
| grad_output = torch.ones(5, 5) |
| out.backward(grad_output) |
| self.assertEqual(x.grad.data, torch.ones(5, 5) * 34) |
| self.assertEqual(y.grad.data, torch.ones(5, 5) * 17) |
| |
| def test_functional_blas(self): |
| def compare(fn, *args): |
| unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg |
| for arg in args) |
| unpacked_result = fn(*unpacked_args) |
| packed_result = fn(*args).data |
| # if non-Variable torch function returns a pyscalar, compare to pyscalar |
| if not torch.is_tensor(unpacked_result): |
| assert packed_result.dim() == 1 |
| assert packed_result.nelement() == 1 |
| packed_result = packed_result[0] |
| self.assertEqual(packed_result, unpacked_result) |
| |
| def test_blas_add(fn, x, y, z): |
| # Checks all signatures |
| compare(fn, x, y, z) |
| compare(fn, 0.5, x, y, z) |
| compare(fn, 0.5, x, 0.25, y, z) |
| |
| def test_blas(fn, x, y): |
| compare(fn, x, y) |
| |
| test_blas(torch.mm, Variable(torch.randn(2, 10)), |
| Variable(torch.randn(10, 4))) |
| test_blas_add(torch.addmm, Variable(torch.randn(2, 4)), |
| Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4))) |
| test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)), |
| Variable(torch.randn(4, 10, 4))) |
| test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)), |
| Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4))) |
| test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)), |
| Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4))) |
| test_blas(torch.mv, Variable(torch.randn(2, 10)), |
| Variable(torch.randn(10))) |
| test_blas_add(torch.addmv, Variable(torch.randn(2)), |
| Variable(torch.randn(2, 10)), Variable(torch.randn(10))) |
| test_blas(torch.ger, Variable(torch.randn(5)), |
| Variable(torch.randn(6))) |
| test_blas_add(torch.addr, Variable(torch.randn(5, 6)), |
| Variable(torch.randn(5)), Variable(torch.randn(6))) |
| test_blas(torch.matmul, Variable(torch.randn(6)), Variable(torch.randn(6))) |
| test_blas(torch.matmul, Variable(torch.randn(10, 4)), Variable(torch.randn(4))) |
| test_blas(torch.matmul, Variable(torch.randn(5)), Variable(torch.randn(5, 6))) |
| test_blas(torch.matmul, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4))) |
| test_blas(torch.matmul, Variable(torch.randn(5, 2, 10)), Variable(torch.randn(5, 10, 4))) |
| test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(3, 5, 10, 4))) |
| test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(10))) |
| test_blas(torch.matmul, Variable(torch.randn(10)), Variable(torch.randn(3, 5, 10, 4))) |
| |
| def test_save_none_for_backward(self): |
| test_case = self |
| |
| class MyFn(Function): |
| |
| def forward(self, input): |
| self.save_for_backward(None, input, None) |
| return input * input |
| |
| def backward(self, grad_output): |
| n1, input, n2 = self.saved_tensors |
| test_case.assertIsNone(n1) |
| test_case.assertIsNone(n2) |
| return 2 * input * grad_output |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = MyFn()(x) |
| y.sum().backward() |
| self.assertEqual(x.grad.data, 2 * x.data) |
| |
| def test_too_many_grads(self): |
| class MyFn(Function): |
| |
| def forward(self, input): |
| return input |
| |
| def backward(self, grad_output): |
| return grad_output, None, None |
| |
| x = Variable(torch.randn(5, 5), requires_grad=True) |
| y = MyFn()(x) |
| y.sum().backward() |
| self.assertEqual(x.grad.data, x.data.clone().fill_(1)) |
| |
| def test_pickle(self): |
| x = Variable(torch.randn(10, 10), requires_grad=True) |
| y = Variable(torch.randn(10, 10), requires_grad=False) |
| |
| def assert_strict_equal(var1, var2): |
| self.assertEqual(var1.data, var2.data) |
| self.assertEqual(var1.requires_grad, var2.requires_grad) |
| |
| serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)] |
| for dump in serialized: |
| xc, yc = pickle.loads(dump) |
| assert_strict_equal(xc, x) |
| assert_strict_equal(yc, y) |
| |
| def test_dep_nograd(self): |
| class F1(Function): |
| |
| def forward(self, input): |
| out = torch.randn(input.size()) |
| self.mark_non_differentiable(out) |
| return input, out |
| |
| def backward(self, grad_output, ignored): |
| return grad_output |
| |
| class F2(Function): |
| |
| def forward(self, input, ignored): |
| return input |
| |
| def backward(self, grad_output): |
| return grad_output, None |
| |
| x = Variable(torch.randn(5), requires_grad=True) |
| a, b = F1()(x) |
| b = b + 1 # separate F1 from F2 by another op |
| self.assertTrue(a.requires_grad) |
| self.assertFalse(b.requires_grad) |
| c = F2()(a, b) |
| c.backward(torch.ones(c.size())) |
| self.assertEqual(x.grad.data, torch.ones(x.size())) |
| |
| def test_reentrant(self): |
| y_data = torch.randn(2, 2) |
| |
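| # Reenter calls backward() from inside its own backward, exercising the |
| # engine's support for reentrant execution. |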
| class Reenter(Function): |
| @staticmethod |
| def forward(ctx, x_data): |
| ctx.x = Variable(x_data, requires_grad=True) |
| ctx.y = Variable(y_data, requires_grad=True) |
| ctx.output_var = ctx.x * ctx.y |
| return ctx.output_var.data |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| ctx.output_var.sum().backward() |
| return ctx.x.grad * grad_output |
| |
| x = Variable(torch.randn(2, 2), requires_grad=True) |
| out = Reenter.apply(x) |
| out.sum().backward(create_graph=True) |
| self.assertEqual(x.grad.data, y_data) |
| |
| def test_cat(self): |
| f_args_variable = (Variable(torch.randn(1, S, S), requires_grad=True), |
| Variable(torch.randn(2, S, S), requires_grad=True), |
| Variable(torch.randn(3, S, S), requires_grad=True), |
| 0) |
| f_args_tensor = deepcopy(unpack_variables(f_args_variable)) |
| run_functional_checks(self, "test_cat", "cat", |
| lambda a, b, c, dim: torch.cat((a, b, c), dim), |
| True, f_args_variable, f_args_tensor) |
| |
| def test_cat_negdim_1(self): |
| f_args_variable = (Variable(torch.randn(S, S, 1), requires_grad=True), |
| Variable(torch.randn(S, S, 2), requires_grad=True), |
| Variable(torch.randn(S, S, 3), requires_grad=True), |
| -1) |
| f_args_tensor = deepcopy(unpack_variables(f_args_variable)) |
| run_functional_checks(self, "test_cat_negdim_1", "cat", |
| lambda a, b, c, dim: torch.cat((a, b, c), dim), |
| True, f_args_variable, f_args_tensor) |
| |
| def test_cat_negdim_2(self): |
| f_args_variable = (Variable(torch.randn(S, 1, S), requires_grad=True), |
| Variable(torch.randn(S, 2, S), requires_grad=True), |
| Variable(torch.randn(S, 3, S), requires_grad=True), |
| -2) |
| f_args_tensor = deepcopy(unpack_variables(f_args_variable)) |
| run_functional_checks(self, "test_cat_negdim_2", "cat", |
| lambda a, b, c, dim: torch.cat((a, b, c), dim), |
| True, f_args_variable, f_args_tensor) |
| |
| def test_cat_empty(self): |
| f_args_variable = (Variable(torch.randn(0), requires_grad=True), |
| Variable(torch.randn(S, S), requires_grad=True)) |
| # gradgradcheck doesn't work here, presumably because gradcheck doesn't |
| # handle empty outputs; hence False is passed below, and gradcheck is |
| # run explicitly. |
| f_args_tensor = deepcopy(unpack_variables(f_args_variable)) |
| run_functional_checks(self, "test_cat_empty", "cat", |
| lambda a, b: torch.cat((a, b)), |
| False, f_args_variable, f_args_tensor) |
| self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION)) |
| |
| @skipIfNoLapack |
| def test_potrf(self): |
| root = Variable(torch.tril(torch.rand(S, S)), requires_grad=True) |
| |
| def run_test(upper): |
| def func(root): |
| x = torch.mm(root, root.t()) |
| return torch.potrf(x, upper) |
| |
| gradcheck(func, [root]) |
| gradgradcheck(func, [root]) |
| |
| run_test(upper=True) |
| run_test(upper=False) |
| |
| @skipIfNoLapack |
| def test_trtrs(self): |
| def _test_with_size(N, C): |
| A = Variable(torch.rand(N, N), requires_grad=True) |
| b = Variable(torch.rand(N, C), requires_grad=True) |
| |
| for upper, transpose, unitriangular in product((True, False), repeat=3): |
| def func(A, b): |
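|                     # trtrs solves the triangular system A x = b (A^T x = b |
|                     # when transpose is set), with A treated as upper or lower |
|                     # triangular and optionally unit-diagonal |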
| return torch.trtrs(b, A, upper, transpose, unitriangular) |
| |
| gradcheck(func, [A, b]) |
| gradgradcheck(func, [A, b]) |
| |
| _test_with_size(S, S + 1) |
| _test_with_size(S, S - 1) |
| |
| def test_variable_traverse(self): |
| def get_out_and_unrefed_cycle(): |
| inp = Variable(torch.randn(10), requires_grad=True) |
| tmp = inp.view(10, 1) |
| out = tmp.view(10) |
| |
| # Create a reference cycle that contains an |
| # intermediary Variable in the graph |
| my_list = [] |
| my_list.append(tmp) |
| my_list.append(my_list) |
| |
| return out |
| |
| out = get_out_and_unrefed_cycle() |
| gc.collect() |
| # This will segfault if things have been erroneously released |
| out.backward(torch.randn(out.size())) |
| |
| def test_norm_subgradient(self): |
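|         # the p-norm is not differentiable at zero; autograd is expected to |
|         # return the zero subgradient there, so an all-zero input should get |
|         # an all-zero gradient |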
| def run_test(input_size, norm_deg): |
| input = Variable(torch.zeros(*input_size), requires_grad=True) |
| input.norm(norm_deg).backward() |
| self.assertEqual(input.grad.data.abs().sum(), 0) |
| |
| run_test((10,), 2) |
| run_test((10, 10), 2) |
| run_test((10,), 3) |
| run_test((10,), 1) |
| run_test((10,), 1.5) |
| |
| @unittest.skipIf(sys.platform == "win32", "Profiler uses `c++filt`, which doesn't exist on Windows.") |
| def test_profiler(self): |
| x = Variable(torch.randn(10, 10)) |
| |
| with profile() as p: |
| y = x * 2 + 4 |
| |
| last_end = 0 |
| names = ['mul', 'add'] |
| self.assertEqual(len(p.function_events), len(names)) |
| for info, expected_name in zip(p.function_events, names): |
| self.assertGreater(info.cpu_interval.start, last_end) |
| self.assertEqual(info.name, expected_name) |
| last_end = info.cpu_interval.end |
| |
| def test_dir(self): |
| x = Variable(torch.randn(10, 10)) |
| keys = dir(x) |
| self.assertIn('shape', keys) |
| |
| for key in keys: |
| self.assertTrue(hasattr(x, key)) |
| |
| def test_as_strided(self): |
| x = Variable(torch.arange(0, 25).view(5, 5), requires_grad=True) |
| |
| def as_strided(x): |
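|             # view x as a 3x3 tensor with strides (6, 2) and storage offset 2 |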
| return x.as_strided([3, 3], [6, 2], 2) |
| |
| gradcheck(as_strided, [x], raise_exception=True) |
| gradgradcheck(as_strided, [x], [Variable(torch.randn(3, 3))]) |
| |
| def _test_where_functional(self, t): |
| x = Variable(t(torch.randn(5, 5)), requires_grad=True) |
| y = Variable(t(torch.randn(5, 5)), requires_grad=True) |
| cond = Variable(t(mask_not_all_zeros((5, 5))), requires_grad=False) |
| |
| def where(cond, x, y): |
| return torch.where(cond, x, y) |
| |
| gradcheck(where, [cond, x, y], raise_exception=True) |
| gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5)))]) |
| |
| x = Variable(t(torch.randn(5, 1, 5)), requires_grad=True) |
| y = Variable(t(torch.randn(5, 5, 1)), requires_grad=True) |
| gradcheck(where, [cond, x, y], raise_exception=True) |
| gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5, 5)))]) |
| |
| def test_where_functional(self): |
| self._test_where_functional(lambda t: t) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") |
| def test_where_functional_cuda(self): |
| self._test_where_functional(lambda t: t.cuda()) |
| |
| def test_inplace_view_backprop_base(self): |
| # modify view and back-prop through base |
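|         # the first row of x is scaled by 2 in place, so x.sum() gives those |
|         # entries of root a gradient of 2 while the untouched row gets 1 |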
| root = Variable(torch.randn(2, 2), requires_grad=True) |
| x = root.clone() |
| v1 = x.narrow(0, 0, 1) |
| v1.mul_(2) |
| x.sum().backward() |
| self.assertEqual(root.grad.data.tolist(), [[2, 2], [1, 1]]) |
| |
| def test_inplace_view_backprop_view_of_view(self): |
| # modify view and backprop through view-of-view |
| root = Variable(torch.randn(2, 2), requires_grad=True) |
| x = root.clone() |
| v1 = x.narrow(0, 0, 1) |
| v2 = x.narrow(0, 0, 1) |
| v1.mul_(2) |
| v2.sum().backward() |
| self.assertEqual(root.grad.data.tolist(), [[2, 2], [0, 0]]) |
| |
| def test_inplace_view_of_view(self): |
| # modify view-of-view and backprop through base |
| root = Variable(torch.randn(2, 2), requires_grad=True) |
| x = root.clone() |
| v1 = x.narrow(0, 0, 1) |
| v2 = v1.narrow(1, 1, 1) |
| v2.mul_(2) |
| x.sum().backward() |
| self.assertEqual(root.grad.data.tolist(), [[1, 2], [1, 1]]) |
| |
| def test_inplace_view_gradcheck(self): |
| # gradcheck modifications to views |
| a = Variable(torch.randn(4, 4), requires_grad=True) |
| b = Variable(torch.randn(2, 2), requires_grad=True) |
| |
| def func(root, b): |
| x = root.clone() |
| x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b) |
| x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b) |
| return x |
| |
| gradcheck(func, [a, b], raise_exception=True) |
| go = Variable(torch.randn(a.size()), requires_grad=True) |
| gradgradcheck(func, (a, b), (go,)) |
| |
| def test_inplace_view_makes_base_require_grad(self): |
| # in-place modification to view makes base require grad |
| a = Variable(torch.randn(4, 4), requires_grad=False) |
| b = Variable(torch.randn(4, 2), requires_grad=True) |
| |
| def func(root, b): |
| x = root.clone() |
| self.assertFalse(x.requires_grad) |
| x.narrow(1, 2, 2).mul_(b) |
| self.assertTrue(x.requires_grad) |
| return x |
| |
| gradcheck(func, [a, b], raise_exception=True) |
| go = Variable(torch.randn(a.size()), requires_grad=True) |
| gradgradcheck(func, (a, b), (go,)) |
| |
| def test_inplace_view_backprop_view(self): |
| # modify view and backprop through view |
| a = Variable(torch.Tensor([2, 5]), requires_grad=False) |
| b = Variable(torch.Tensor([3]), requires_grad=True) |
| res = a.narrow(0, 1, 1).mul_(b) |
| res.sum().backward() |
| self.assertEqual(b.grad.data.tolist(), [5]) |
| self.assertIsNone(a.grad) |
| |
| def test_inplace_view_modify_base(self): |
|         # Test that an in-place operation that makes its base require grad |
|         # also makes views created before the operation require grad and |
|         # backprop correctly |
| r = Variable(torch.ones(1), requires_grad=True) |
| |
| def fn(r): |
| x = Variable(torch.ones(5)) |
| v = x.select(0, 1) |
| self.assertFalse(v.requires_grad) |
| self.assertIsNone(v.grad_fn) |
| x.add_(r) # v is now dependent on r due to the in-place op on x |
| self.assertTrue(v.requires_grad) |
| return v |
| |
| gradcheck(fn, [r]) |
| gradgradcheck(fn, [r]) |
| |
| def test_inplace_view_python(self): |
|         # in-place modifications of a view created by a Python autograd Function |
| a = Variable(torch.randn(4, 4), requires_grad=True) |
| b = Variable(torch.randn(2, 2), requires_grad=True) |
| |
| class PyAdd(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x, y): |
| ctx.mark_dirty(x) |
| x.add_(y) |
| return x |
| |
| @staticmethod |
| def backward(ctx, grad): |
| return grad, grad |
| |
| def func(root, b): |
| x = root.clone() |
| PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b) |
| PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b) |
| return x |
| |
| gradcheck(func, [a, b], raise_exception=True) |
| go = Variable(torch.randn(a.size()), requires_grad=True) |
| gradgradcheck(func, (a, b), (go,)) |
| |
| def test_inplace_view_non_contig(self): |
| data = torch.ones(2, 3, 2).select(2, 1).t() |
| root = Variable(data, requires_grad=True) |
| x = root.clone() |
| v1 = x.narrow(0, 0, 1) |
| v2 = v1.narrow(1, 1, 1) |
| v2.mul_(2) |
| x.sum().backward() |
| self.assertEqual(root.grad.data.tolist(), [[1, 2], [1, 1], [1, 1]]) |
| |
| def test_inplace_view_saved_output(self): |
| # Test an in-place operation on a view in which the in-place op saves |
| # its output. Previously, this created a reference cycle. |
| dealloc = [0] |
| |
| class IncrementOnDelete(object): |
| def __del__(self): |
| dealloc[0] += 1 |
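|         # __del__ runs only once the autograd graph (which owns the hook) is |
|         # actually freed, so checking the counter right after test() verifies |
|         # that no reference cycle kept the graph alive |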
| |
| def test(): |
| root = Variable(torch.randn(3, 3), requires_grad=True) |
| copy = root.clone() |
| copy.grad_fn.register_hook(IncrementOnDelete()) |
| view = copy.view(9) |
| torch.nn.functional.relu(view, inplace=True) |
| |
| test() |
| self.assertEqual(dealloc[0], 1) |
| |
| |
| def index_variable(shape, max_indices): |
| if not isinstance(shape, tuple): |
| shape = (shape,) |
| index = torch.rand(*shape).mul_(max_indices).floor_().long() |
| return Variable(index, requires_grad=False) |
| |
| |
| def index_perm_variable(shape, max_indices): |
| if not isinstance(shape, tuple): |
| shape = (shape,) |
| |
| index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape) |
| return Variable(index, requires_grad=False) |
| |
| |
| def gather_variable(shape, index_dim, max_indices, duplicate=False): |
| assert len(shape) == 2 |
| assert index_dim < 2 |
| batch_dim = 1 - index_dim |
| index = torch.LongTensor(*shape) |
| for i in range(shape[index_dim]): |
| index.select(index_dim, i).copy_( |
| torch.randperm(max_indices)[:shape[batch_dim]]) |
| if duplicate: |
| index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) |
| return Variable(index, requires_grad=False) |
| |
| |
| def mask_not_all_zeros(shape): |
| assert len(shape) > 0 |
| while True: |
| result = torch.randn(shape).gt(0) |
| if result.sum() > 0: |
| return result |
| |
| |
| def prod_zeros(dim_size, dim_select): |
| assert len(dim_select) == 2 |
| result = torch.randn(dim_size, dim_size, dim_size) |
| result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() |
| result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() |
| result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() |
| return Variable(result, requires_grad=True) |
| |
| |
| def prod_single_zero(dim_size): |
| result = torch.randn(dim_size, dim_size) |
| result[0, 1] = 0 |
| return Variable(result, requires_grad=True) |
| |
| |
| def random_square_matrix_of_rank(l, rank): |
| assert rank <= l |
| A = torch.randn(l, l) |
| u, s, v = A.svd() |
| for i in range(l): |
| if i >= rank: |
| s[i] = 0 |
| elif s[i] == 0: |
| s[i] = 1 |
| return u.mm(torch.diag(s)).mm(v.transpose(0, 1)) |
| |
| |
| def random_symmetric_matrix(l): |
| A = torch.randn(l, l) |
| return A.mm(A.transpose(0, 1)) |
| |
| |
| def random_fullrank_matrix_distinct_singular_value(l): |
| A = torch.randn(l, l) |
| u, _, v = A.svd() |
| s = torch.arange(1, l + 1).mul_(1.0 / (l + 1)) |
| return u.mm(torch.diag(s)).mm(v.transpose(0, 1)) |
| |
| |
| class dont_convert(tuple): |
| pass |
| |
| |
| L = 20 |
| M = 10 |
| S = 5 |
| |
| # Each entry is a tuple: |
| # (name, self_size, args, variant_name, dim_args_idx, skip_decorators) |
| # where only the first three fields are required |
| method_tests = [ |
| ('add', (S, S, S), ((S, S, S),)), |
| ('add', (S, S, S), ((S, S),), 'broadcast_rhs'), |
| ('add', (S, S), ((S, S, S),), 'broadcast_lhs'), |
| ('add', (S, 1, S), ((M, S),), 'broadcast_all'), |
| ('add', (S, S, S), (3.14,), 'constant'), |
| ('__radd__', (S, S, S), (3.14,), 'constant'), |
| ('sub', (S, S, S), ((S, S, S),)), |
| ('sub', (S, S, S), ((S, S),), 'broadcast_rhs'), |
| ('sub', (S, S), ((S, S, S),), 'broadcast_lhs'), |
| ('sub', (S, 1, S), ((M, S),), 'broadcast_all'), |
| ('sub', (S, S, S), (3.14,), 'constant'), |
| ('__rsub__', (S, S, S), (3.14,), 'constant'), |
| ('mul', (S, S, S), ((S, S, S),)), |
| ('mul', (S, S, S), ((S, S),), 'broadcast_rhs'), |
| ('mul', (S, S), ((S, S, S),), 'broadcast_lhs'), |
| ('mul', (S, 1, S), ((M, S),), 'broadcast_all'), |
| ('mul', (S, S, S), (3.14,), 'constant'), |
| ('__rmul__', (S, S, S), (3.14,), 'constant'), |
| ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,)), |
| ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs'), |
| ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'), |
| ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all'), |
| ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'), |
| ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'), |
| ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)), |
| ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'), |
| ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'), |
| ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'), |
| ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'), |
| ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'), |
| ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]), |
| ('transpose', torch.rand(L, L), (0, 1), '2d'), |
| ('transpose', torch.rand(S, S, S), (2, 0), '3d'), |
| ('t', (1, 2), NO_ARGS), |
| ('view', (S, S, S), (S * S, S),), |
| ('view', (S, S, S), (torch.Size([S * S, S]),), 'size'), |
| ('view', (S,), (S,), '1d'), |
| ('view_as', (S, S, S), (Variable(torch.rand(S * S, S), requires_grad=False),)), |
| ('expand', (S, 1, 1), (S, S, S)), |
| ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'), |
| ('expand', (S, 1), (S, S, S), 'new_dim'), |
| ('expand', (1,), (S, S, S), '1_element'), |
| ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'), |
| ('exp', (S, S, S), NO_ARGS), |
| ('expm1', (S, S, S), NO_ARGS), |
| ('erf', torch.rand(S, S, S), NO_ARGS), |
| ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS), |
| ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS), |
| ('log1p', torch.rand(S, S, S), NO_ARGS), |
| ('tanh', (S, S, S), NO_ARGS), |
| ('sigmoid', (S, S, S), NO_ARGS), |
| ('sinh', (S, S, S), NO_ARGS), |
| ('cosh', (S, S, S), NO_ARGS), |
| ('abs', (S, S, S), NO_ARGS), |
| ('clamp', (S, S, S), (0, 1)), |
| ('clamp', (S, S, S), (None, 0.5), 'min'), |
| ('clamp', (S, S, S), (0.5, None), 'max'), |
| ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS), |
| ('sin', (S, S, S), NO_ARGS), |
| ('cos', (S, S, S), NO_ARGS), |
| ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS), |
| ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS), |
| ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS), |
| ('atan', (S, S, S), NO_ARGS), |
| ('atan2', (S, S, S), ((S, S, S),)), |
| ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS), |
| ('round', (S, S, S), NO_ARGS), |
| ('sign', (S, S, S), NO_ARGS), |
| ('trunc', (S, S, S), NO_ARGS), |
| ('floor', (S, S, S), NO_ARGS), |
| ('ceil', (S, S, S), NO_ARGS), |
| ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS), |
| ('frac', (S, S, S), NO_ARGS), |
| ('fmod', (S, S, S), (1.5,)), |
| ('fmod', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), |
| ('fmod', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), |
| ('fmod', (S, S, S), (Variable(torch.rand(S) + 1.5, requires_grad=False),), 'tensor_broadcast_rhs'), |
| ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), |
| ('remainder', (S, S, S), (1.5,)), |
| ('remainder', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), |
| ('remainder', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), |
| ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), |
| ('lerp', (S, S, S), ((S, S, S), 0.4)), |
| ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'), |
| ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'), |
| ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all'), |
| ('max', (S, S, S), NO_ARGS), |
| ('max', (S, S, S), (1,), 'dim', [0]), |
| ('max', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('max', (S,), (0,), 'dim_1d', [0]), |
| ('max', (S,), (0, True,), 'keepdim_dim_1d', [0]), |
| ('max', (S, S, S), ((S, S, S),), 'elementwise'), |
| ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'), |
| ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'), |
| ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'), |
| ('min', (S, S, S), NO_ARGS), |
| ('min', (S, S, S), (1,), 'dim', [0]), |
| ('min', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('min', (S,), (0,), 'dim_1d', [0]), |
| ('min', (S,), (0, True,), 'keepdim_dim_1d', [0]), |
| ('min', (S, S, S), ((S, S, S),), 'elementwise'), |
| ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'), |
| ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'), |
| ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'), |
| ('mean', (S, S, S), NO_ARGS), |
| ('mean', (S, S, S), (1,), 'dim', [0]), |
| ('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('mean', (S,), (0,), 'dim_1d', [0]), |
|     ('mean', (S,), (0, True), 'keepdim_dim_1d', [0]), |
| ('kthvalue', (S, S, S), (2,)), |
| ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]), |
| ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', [1]), |
| ('kthvalue', (S,), (2, 0,), 'dim_1d', [1]), |
| ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', [1]), |
| ('median', (S, S, S), NO_ARGS), |
| ('median', (S, S, S), (1,), 'dim', [0]), |
| ('median', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('median', (S,), (0,), 'dim_1d', [0]), |
| ('median', (S,), (0, True,), 'keepdim_dim_1d', [0]), |
| ('mode', (S, S, S), NO_ARGS), |
| ('mode', (S, S, S), (1,), 'dim', [0]), |
| ('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('mode', (S,), (0,), 'dim_1d', [0]), |
| ('mode', (S,), (0, True,), 'keepdim_dim_1d', [0]), |
| ('sum', (S, S, S), NO_ARGS), |
| ('sum', (S, S, S), (1,), 'dim', [0]), |
| ('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('sum', (S,), (0,), 'dim_1d', [0]), |
| ('sum', (S,), (0, True), 'keepdim_1d', [0]), |
| ('prod', (S, S, S), NO_ARGS), |
| ('prod', (S, S, S), (1,), 'dim', [0]), |
| ('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]), |
| ('prod', (S,), (0,), 'dim_1d', [0]), |
| ('prod', (S,), (0, True), 'keepdim_1d', [0]), |
| ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'), |
| ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'), |
| ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'), |
| ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', [0]), |
| ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', [0]), |
| ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', [0]), |
| ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', [0]), |
| ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', [0]), |
| ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', [0]), |
| ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'), |
| ('var', (S, S, S), NO_ARGS), |
| ('var', (S, S, S), (1,), 'dim', [0]), |
| ('var', (S, S, S), (1, True, True), 'keepdim_dim', [0]), |
| ('var', (S,), (0,), 'dim_1d', [0]), |
| ('var', (S,), (0, True, True), 'keepdim_dim_1d', [0]), |
| ('std', (S, S, S), NO_ARGS), |
| ('std', (S, S, S), (1,), 'dim', [0]), |
| ('std', (S, S, S), (1, True, True), 'keepdim_dim', [0]), |
| ('std', (S,), (0,), 'dim_1d', [0]), |
| ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]), |
| ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]), |
| ('renorm', (S, S, S), (1, 2, 3), 'norm_1'), |
| ('repeat', (S, S, S, S), (2, 3, 1, 4)), |
| ('repeat', (S, S, S, S), (2, 2, 1, 3, 1, 2), 'unsqueeze'), |
|     ('cumsum', (S, S, S), (0,), 'dim0', [0]), |
| ('cumsum', (S, S, S), (1,), 'dim1', [0]), |
| ('cumsum', (S,), (0,), '1d', [0]), |
| ('cumprod', (S, S, S), (0,)), |
| ('cumprod', (S, S, S), (1,), 'dim1', [0]), |
| ('cumprod', (S,), (0,), '1d'), |
| ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', [0]), |
| ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', [0]), |
| ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', [0]), |
| ('unfold', (S, S, S, S), (1, 3, 1)), |
| ('unfold', (S, S, S), (2, 3, 2), 'lastdim'), |
| ('addmm', (S, M), ((S, S), (S, M)),), |
| ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs'), |
| ('addmm', (S, M), (0.2, 0.6, (S, S), (S, M)), 'coef'), |
| ('addmm', (1,), (0.2, 0.6, (S, S), (S, M)), 'broadcast_lhs_coef'), |
| ('addbmm', (S, M), ((S, S, S), (S, S, M)),), |
| ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'), |
| ('addbmm', (S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'), |
| ('addbmm', (1,), (0.2, 0.6, (S, S, S), (S, S, M)), 'broadcast_lhs_coef'), |
| ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),), |
| ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'), |
| ('baddbmm', (S, S, M), (0.2, 0.6, (S, S, S), (S, S, M)), 'coef'), |
| ('baddbmm', (1,), (0.2, 0.6, (S, S, S), (S, S, M)), 'broadcast_lhs_coef'), |
| ('addmv', (S,), ((S, M), (M,)),), |
| ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs'), |
| ('addmv', (S,), (0.2, 0.6, (S, M), (M,)), 'coef'), |
| ('addmv', (1,), (0.2, 0.6, (S, M), (M,)), 'broadcast_lhs_coef'), |
| ('addr', (S, M), ((S,), (M,)),), |
| ('addr', (1,), ((S,), (M,)), 'broadcast_lhs'), |
| ('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef'), |
| ('addr', (1,), (0.2, 0.6, (S,), (M,)), 'broadcast_lhs_coef'), |
| ('dot', (L,), ((L,),),), |
| ('mm', (S, M), ((M, S),)), |
| ('bmm', (M, S, M), ((M, M, S),)), |
| ('mv', (S, M), ((M,),)), |
| ('ger', (S,), ((M,),)), |
| ('matmul', (L,), ((L,),),), |
|     ('matmul', (S, M), ((M,),), '2d_1d'), |
|     ('matmul', (M,), ((M, S),), '1d_2d'), |
|     ('matmul', (S, M), ((M, S),), '2d_2d'), |
|     ('matmul', (S, S, M, M), ((S, S, M, S),), '4d_4d'), |
|     ('matmul', (S, S, M, M), ((M,),), '4d_1d'), |
|     ('matmul', (M,), ((S, S, M, S),), '1d_4d'), |
| ('addcmul', (S, S), ((S, S), (S, S))), |
| ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'), |
| ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'), |
| ('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale'), |
| ('addcmul', (S, S), (0.5, (S, 1), (1, S)), 'scale_broadcast_rhs'), |
| ('addcmul', (1,), (0.5, (S, S, 1), (1, S)), 'scale_broadcast_all'), |
| ('addcdiv', (S, S), ((S, S), (S, S))), |
| ('addcdiv', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'), |
| ('addcdiv', (1,), ((S, S, 1), (1, S)), 'broadcast_all'), |
| ('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'), |
| ('addcdiv', (S, S), (0.5, (S, 1), (1, S)), 'scale_broadcast_rhs'), |
| ('addcdiv', (1,), (0.5, (S, S, 1), (1, S)), 'scale_broadcast_all'), |
| ('zero_', (S, S, S), NO_ARGS), |
| ('norm', (S, S), (2,)), |
| ('norm', (S, S), (0,), '0'), |
| ('norm', (S, S), (0.5,), '0_5'), |
| ('norm', (S, S), (1,), '1'), |
| ('norm', (S, S), (3,), '3'), |
| ('norm', (S, S), (-1,), 'neg_1'), |
| ('norm', (S, S), (-0.5,), 'neg_0_5'), |
| ('norm', (S, S), (-1.5,), 'neg_1_5'), |
| ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5'), |
| ('norm', (S, S, S), (2, 1), '2_dim', [1]), |
| ('norm', (S, S, S), (3, 1), '3_dim', [1]), |
| ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]), |
| ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', [1]), |
| ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', [1]), |
| ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', [1]), |
| ('norm', (S,), (2, 0), '2_dim_1d', [1]), |
| ('norm', (S,), (3, 0), '3_dim_1d', [1]), |
| ('norm', (S,), (2, 0, True), 'keepdim_2_dim_1d', [1]), |
| ('norm', (S,), (3, 0, True), 'keepdim_3_dim_1d', [1]), |
| ('clone', (S, M, S), NO_ARGS), |
| ('dist', (S, S, S), ((S, S, S),)), |
| ('dist', (S, S, S), ((S,),), 'broadcast_rhs'), |
| ('dist', (S,), ((S, S, S),), 'broadcast_lhs'), |
| ('dist', (S, 1, S), ((S, S),), 'broadcast_all'), |
| ('dist', (S, S, S), ((S, S, S), 4), '4'), |
| ('dist', (S, S, S), ((S,), 4), '4_broadcast_rhs'), |
| ('dist', (S,), ((S, S, S), 4), '4_broadcast_lhs'), |
| ('dist', (S, 1, S), ((S, S), 4), '4_broadcast_all'), |
| ('diag', (M, M), NO_ARGS, '2d'), |
| ('diag', (M,), NO_ARGS, '1d'), |
| ('diag', (M, M), (1,), '2d_1'), |
| ('diag', (M, M), (2,), '2d_2'), |
| ('tril', (M, M), NO_ARGS), |
| ('tril', (M, M), (2,), 'idx'), |
| ('triu', (M, M), NO_ARGS), |
| ('triu', (M, M), (2,), 'idx'), |
| ('trace', (M, M), NO_ARGS), |
| ('cross', (S, 3), ((S, 3),)), |
| ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'), |
| ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]), |
| ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', [0]), |
| ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]), |
| ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]), |
| ('inverse', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), |
| ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), |
| ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]), |
| ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', NO_ARGS, [skipIfNoLapack]), |
| ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', NO_ARGS, [skipIfNoLapack]), |
| ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', NO_ARGS, [skipIfNoLapack]), |
| ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, |
|      'distinct_positive_s', NO_ARGS, [skipIfNoLapack]), |
| ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), |
| ('gesv', (S, S), ((S, S),), '', NO_ARGS, [skipIfNoLapack]), |
| ('eq_', (S, S, S), ((S, S, S),)), |
| ('eq_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('ne_', (S, S, S), ((S, S, S),)), |
| ('ne_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('gt_', (S, S, S), ((S, S, S),)), |
| ('gt_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('ge_', (S, S, S), ((S, S, S),)), |
| ('ge_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('lt_', (S, S, S), ((S, S, S),)), |
| ('lt_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('le_', (S, S, S), ((S, S, S),)), |
| ('le_', (S, S, S), ((1,),), 'broadcast_rhs'), |
| ('eq_', (S, S, S), (0,), 'pyscalar'), |
| ('ne_', (S, S, S), (0,), 'pyscalar'), |
| ('gt_', (S, S, S), (0,), 'pyscalar'), |
| ('ge_', (S, S, S), (0,), 'pyscalar'), |
| ('lt_', (S, S, S), (0,), 'pyscalar'), |
| ('le_', (S, S, S), (0,), 'pyscalar'), |
| ('permute', (1, 2, 3, 4), (0, 2, 3, 1)), |
| ('select', (S, S, S), (1, 2), 'dim', [0]), |
| ('select', (S,), (0, 2), '1d'), |
| ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]), |
| ('slice', (S, S, S), (-2, 1, -1, 2)), |
| ('squeeze', (S, 1, S, 1), NO_ARGS), |
| ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]), |
| ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]), |
| ('squeeze', (1,), (0,), '1d_dim0', [0]), |
| ('unsqueeze', (S, S, S), (0,), 'first', [0]), |
| ('unsqueeze', (S, S, S), (1,), 'middle', [0]), |
| ('unsqueeze', (S, S, S), (3,), 'last', [0]), |
| ('chunk', (S, S, S), (2,)), |
| ('chunk', (S, S, S), (S, 1), 'dim', [1]), |
| ('split', (S, S, S), (2,)), |
| ('split', (S, S, S), (S, 1), 'dim', [1]), |
| ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', [0]), |
| ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', [0]), |
| ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), |
| ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), |
| ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), |
| ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), |
| ('masked_select', (M, M), (Variable(mask_not_all_zeros((M, M)), requires_grad=False),)), |
| ('masked_select', (M, M), (Variable(mask_not_all_zeros((M,)), requires_grad=False),), 'broadcast_rhs'), |
| ('masked_select', (M,), (Variable(mask_not_all_zeros((M, M)), requires_grad=False),), 'broadcast_lhs'), |
| ('masked_select', (M, 1, M), (Variable(mask_not_all_zeros((M, M)), requires_grad=False),), |
| 'broadcast_all'), |
| ('masked_fill', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), 10)), |
|     # no broadcast_lhs or broadcast_all variants for masked_fill or masked_scatter because they are always in-place |
| ('masked_fill', (M, M), (Variable(torch.ByteTensor(M,).bernoulli_(), requires_grad=False), 10), 'broadcast_rhs'), |
| ('masked_scatter', (M, M), (Variable(torch.ByteTensor(M, M).bernoulli_(), requires_grad=False), (M, M))), |
| ('masked_scatter', (M, M), (Variable(torch.ByteTensor(M,).bernoulli_(), requires_grad=False), (M, M)), |
| 'broadcast_rhs'), |
| ('resize', (S, S, S), (torch.Size([S * S, S])), 'fewer_dims'), |
| ('resize_as', (S, S, S), (Variable(torch.randn((S * S, S)), requires_grad=False),)), |
| ('sort', (S, M, S), NO_ARGS), |
| ('sort', (S, M, S), (1,), 'dim'), |
| ('sort', (S, M, S), (1, True), 'dim_desc'), |
| ('topk', (S, M, S), (3,)), |
| ('topk', (S, M, S), (3, 1), 'dim'), |
| ('topk', (S, M, S), (3, 1, True), 'dim_desc'), |
| ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort'), |
| ('take', (S, S, S), (Variable(torch.LongTensor([[-3, 2], [20, 2]])),)), |
| ('where', (M, M), (Variable(mask_not_all_zeros((M, M)), requires_grad=False), (M, M))), |
| ('where', (M, 1, M), (Variable(mask_not_all_zeros((M, M)), requires_grad=False), (M, M, 1)), 'broadcast_all'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)), |
| ('__getitem__', torch.randn(S, S, S), (slice(0, 3),), 'slice'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),), 'slice_index'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], [0, 0, 2]]),), 'adv_index'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 0, 3], [1, 1, 3], [0, 0, 2]]),), 'adv_index_dup'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(None), slice(None), [0, 3]]),), 'adv_index_end'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(None), [0, 3], slice(None)]),), 'adv_index_mid'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(None), slice(None)]),), 'adv_index_beg'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], [1, 2], slice(None)]),), 'adv_index_comb'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], ]),), 'adv_index_sub'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(None)]),), 'adv_index_sub_2'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], Ellipsis]),), 'adv_index_sub_3'), |
| ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], |
| Variable(torch.LongTensor([0, 0, 2]), requires_grad=False)]),), 'adv_index_var'), |
| ] |
| # TODO: clamp with min/max |
| |
| |
| def make_non_contiguous(tensor): |
| osize = list(tensor.size()) |
| |
| # randomly inflate a few dimensions in osize |
| for _ in range(2): |
| dim = random.randint(0, len(osize) - 1) |
| add = random.randint(4, 15) |
| osize[dim] = osize[dim] + add |
| |
|     # narrowing only the 0-th dimension (which is all we can do for a |
|     # 1-dimensional tensor) does not produce a non-contiguous tensor, so add a |
|     # new right-most dimension and then slice it away |
| |
| input = tensor.new(torch.Size(osize + [random.randint(2, 3)])) |
| input = input.select(len(input.size()) - 1, random.randint(0, 1)) |
| # now extract the input of correct size from 'input' |
| for i in range(len(osize)): |
| if input.size(i) != tensor.size(i): |
| bounds = random.randint(1, input.size(i) - tensor.size(i)) |
| input = input.narrow(i, bounds, tensor.size(i)) |
| |
| input.copy_(tensor) |
| return input |
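| |
| |
| # A minimal illustrative sketch (not part of the test harness; the helper name |
| # below is ours): for inputs with more than one dimension, make_non_contiguous |
| # is expected to return a tensor with the same values as its input but a |
| # non-contiguous memory layout. |
| def _make_non_contiguous_example(): |
|     original = torch.randn(3, 4) |
|     shuffled = make_non_contiguous(original) |
|     assert shuffled.size() == original.size() |
|     assert torch.equal(shuffled, original) |
|     return shuffled.is_contiguous()  # typically False for multi-dim inputs |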
| |
| |
| def create_input(call_args, requires_grad=True, non_contiguous=False): |
| if not isinstance(call_args, tuple): |
| call_args = (call_args,) |
| |
| def map_arg(arg): |
| def maybe_non_contig(tensor): |
| return tensor if not non_contiguous else make_non_contiguous(tensor) |
| |
| if isinstance(arg, torch.Size) or isinstance(arg, dont_convert): |
| return arg |
| elif isinstance(arg, tuple) and not isinstance(arg[0], Variable): |
| return Variable(maybe_non_contig(torch.randn(*arg).double()), requires_grad=requires_grad) |
| elif torch.is_tensor(arg): |
| if isinstance(arg, torch.FloatTensor): |
| return Variable(maybe_non_contig(arg.double()), requires_grad=requires_grad) |
| else: |
| return Variable(maybe_non_contig(arg), requires_grad=requires_grad) |
| elif isinstance(arg, Variable) and non_contiguous: |
| return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad) |
| elif callable(arg): |
| return map_arg(arg()) |
| else: |
| return arg |
| return tuple(map_arg(arg) for arg in call_args) |
| |
| |
| def unpack_variables(args): |
| if isinstance(args, Variable): |
| return args.data |
| elif isinstance(args, tuple): |
| return tuple(unpack_variables(elem) for elem in args) |
| else: |
| return args |
| |
| |
| def generate_gradoutput(dummy_out, non_contiguous=False): |
| def maybe_non_contig(tensor): |
| return tensor if not non_contiguous else make_non_contiguous(tensor) |
| |
| if isinstance(dummy_out, tuple): |
| grad_y = tuple(Variable(maybe_non_contig(torch.randn(x.size())), requires_grad=x.requires_grad) |
| for x in dummy_out if isinstance(x, Variable)) |
| else: |
| grad_y = (Variable(maybe_non_contig(torch.randn(dummy_out.size())), requires_grad=dummy_out.requires_grad),) |
| |
| return grad_y |
| |
| EXCLUDE_FUNCTIONAL = { |
| 'addmm', |
| 'addbmm', |
| 'baddbmm', |
| 'addmv', |
| 'addr', |
|     'where'  # torch.where takes its arguments in a different order than the method |
| } |
| EXCLUDE_GRADCHECK = { |
| } |
| EXCLUDE_GRADGRADCHECK = { |
| 'svd' |
| } |
| EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = { |
|     # Some of the following det tests only pass because a random matrix has |
|     # full rank with high probability, which we cannot rely on, so gradgrad is |
|     # only checked for test_det_distinct_positive_s. |
| 'test_det', |
| 'test_det_symmetric', |
| 'test_det_dim2_null', |
| 'test_det_rank1', |
| 'test_det_rank2' |
| } |
| |
| |
| def exclude_tensor_method(name, test_name): |
| # there are no tensor equivalents for these (inplace or out) |
| exclude_all_tensor_method_by_test_name = { |
| 'test_clamp_min', |
| 'test_clamp_max', |
| 'test_slice', |
| 'test_where', |
| 'test_where_broadcast_all' |
| } |
| # there are no out-of-place tensor equivalents for these |
| exclude_outplace_tensor_method = { |
| 'index_add', |
| 'index_copy', |
| 'index_fill', |
| 'masked_fill', |
| 'masked_scatter', |
| 'resize', |
| 'resize_as', |
| 'scatter', |
| 'scatter_add', |
| 'det', |
| } |
| if test_name in exclude_all_tensor_method_by_test_name: |
| return True |
| is_magic_method = name[:2] == '__' and name[-2:] == '__' |
| is_inplace = name[-1] == "_" and not is_magic_method |
| if not is_inplace and name in exclude_outplace_tensor_method: |
| return True |
| return False |
| |
| |
| def gradgradcheck_method_precision_override(test_name): |
|     # these tolerance overrides are purely empirical; the underlying precision should be improved |
| gradgradcheck_precision_override = { |
| 'test_norm': {'atol': 2e-2, 'rtol': 1e-2}, |
| 'test_norm_1_5': {'atol': 1.5e-2, 'rtol': 1e-2}, |
| 'test_norm_3': {'atol': 5e-2, 'rtol': 1e-2}, |
| 'test_dist': {'atol': 5e-2, 'rtol': 1e-2}, |
| 'test_dist_4': {'atol': 8e-2, 'rtol': 1e-2}, |
| } |
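|     # e.g. 'test_dist_broadcast_rhs' falls back to the 'test_dist' entry above |
|     # and then has its tolerances scaled by S by the broadcast handling below |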
| non_broadcasted_test_name = test_name.split("_broadcast")[0] |
| override = gradgradcheck_precision_override.get(non_broadcasted_test_name) |
| if override: |
| if 'broadcast_lhs' in test_name or 'broadcast_rhs' in test_name: |
| # errors accumulated across 1 dimension |
|             override = {'atol': override['atol'] * S, 'rtol': override['rtol'] * S} |
| elif 'broadcast_all' in test_name: |
| # errors accumulated across multiple dimensions |
|             override = {'atol': override['atol'] * S * S, 'rtol': override['rtol'] * S * S} |
| return override |
| |
| |
| def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable, |
| input_variables, run_gradgradcheck=True): |
| test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION)) |
| if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME: |
| return |
| grad_y = generate_gradoutput(output_variable, non_contiguous=True) |
| gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name) |
| if gradgradcheck_precision_override is not None: |
| atol = gradgradcheck_precision_override['atol'] |
| rtol = gradgradcheck_precision_override['rtol'] |
| test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y, atol=atol, rtol=rtol)) |
| else: |
| test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y)) |
| |
| |
| def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, |
| f_args_variable, f_args_tensor): |
| output_variable = apply_fn(*f_args_variable) |
| if not exclude_tensor_method(name, test_name): |
| output_tensor = apply_fn(*f_args_tensor) |
| if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple): |
| output_tensor = torch.DoubleTensor((output_tensor,)) |
| test_case.assertEqual(unpack_variables(output_variable), output_tensor) |
| |
| if run_grad_checks: |
| run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn, |
| output_variable, f_args_variable) |
| |
| self_variable = f_args_variable[0] |
| if isinstance(output_variable, torch.autograd.Variable) and self_variable is not None: |
| output_variable.backward(torch.randn(*output_variable.size()).type_as(output_variable.data)) |
| test_case.assertTrue(type(self_variable.data) == type(self_variable.grad.data)) |
| test_case.assertTrue(self_variable.size() == self_variable.grad.size()) |
| |
| for test in method_tests: |
| name, self_size, args = test[:3] |
| basic_test_name = 'test_' + name |
| if len(test) >= 4 and test[3] != '': |
| basic_test_name += '_' + test[3] |
| |
| dim_args_idx = test[4] if len(test) == 5 else [] |
| |
| skipTestIf = test[5] if len(test) == 6 else [] |
| |
| for dim_perm in product([-1, 1], repeat=len(dim_args_idx)): |
| new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)] |
| test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0) |
| new_args = tuple(new_args) |
| |
|         # for-loop bodies don't introduce a new scope and closures bind their |
|         # free variables late, so freeze the current values of the variables |
|         # we want to close over by passing them as default arguments to do_test |
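|         # (illustration: every closure in [lambda: i for i in range(3)] returns 2 |
|         #  when called later, whereas [lambda i=i: i for i in range(3)] capture 0, 1, 2) |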
| def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name): |
| def check(name): |
| is_magic_method = name[:2] == '__' and name[-2:] == '__' |
| is_inplace = name[-1] == "_" and not is_magic_method |
| self_variable = create_input((self_size,), requires_grad=not is_inplace)[0] |
| args_variable = create_input(args, requires_grad=not is_inplace) |
| self_tensor = deepcopy(self_variable.data) |
| args_tensor = deepcopy(unpack_variables(args_variable)) |
| output_variable = getattr(self_variable, name)(*args_variable) |
| if not exclude_tensor_method(name, test_name): |
| output_tensor = getattr(self_tensor, name)(*args_tensor) |
| if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple): |
| output_tensor = torch.DoubleTensor((output_tensor,)) |
| self.assertEqual(unpack_variables(output_variable), output_tensor) |
| # TODO: check that both have changed after adding all inplace ops |
| |
| if not is_inplace and name not in EXCLUDE_GRADCHECK: |
| run_grad_and_gradgrad_checks(self, name, test_name, |
| lambda *inputs: getattr(inputs[0], name)(*inputs[1:]), |
| output_variable, (self_variable,) + args_variable) |
| |
| # functional interface tests |
| if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL: |
| f_args_variable = (self_variable,) + args_variable |
| f_args_tensor = (self_tensor,) + args_tensor |
| # could run the gradchecks again, but skip since we did it for the methods above. |
| run_functional_checks(self, test_name, name, |
| lambda *inputs: getattr(torch, name)(*inputs), |
| False, f_args_variable, f_args_tensor) |
| |
| # check for correct type of input.data and input.grad.data |
| if not is_inplace: |
| self_variable = create_input((self_size,), requires_grad=True)[0] |
| args_variable = create_input(args, requires_grad=False) |
| output_variable = getattr(self_variable, name)(*args_variable) |
| if isinstance(output_variable, torch.autograd.Variable): |
| output_variable.backward(torch.randn(*output_variable.size()).type_as(output_variable.data)) |
| self.assertTrue(type(self_variable.data) == type(self_variable.grad.data)) |
| self.assertTrue(self_variable.size() == self_variable.grad.size()) |
| |
| # compare grads to inplace grads |
| inplace_name = name + '_' |
| # can't broadcast inplace to left hand side |
| skip_inplace = ('broadcast_lhs' in test_name or |
| 'broadcast_all' in test_name or |
| test_name.startswith('test_resize')) |
| if hasattr(Variable(torch.ones(1)), inplace_name) and not skip_inplace: |
| output_variable = getattr(self_variable, name)(*args_variable) |
| if not isinstance(output_variable, tuple): |
| output_variable = (output_variable,) |
| inplace_self_variable = deepcopy(self_variable) |
| inplace_self_variable_copy = tuple(i + 0 if i is not None else None |
| for i in (inplace_self_variable,)) |
| inplace_args_variable = deepcopy(args_variable) |
| inplace_args_variable_copy = tuple(i + 0 if i is not None else None |
| for i in inplace_args_variable) |
| |
| inplace_output_variable = ( |
| getattr(inplace_self_variable_copy[0], inplace_name)(*inplace_args_variable_copy)) |
| if not isinstance(inplace_output_variable, tuple): |
| inplace_output_variable = (inplace_output_variable,) |
| self.assertEqual(inplace_output_variable, output_variable) |
| # Check that gradient is the same |
| for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable, |
| (self_variable,) + args_variable): |
| if not isinstance(inp_i, Variable): |
| assert not isinstance(i, Variable) |
| continue |
| if inp_i.grad is not None: |
| inp_i.grad.data.zero_() |
| if i.grad is not None: |
| i.grad.data.zero_() |
| for io, o in zip(inplace_output_variable, output_variable): |
| grad = torch.randn(*io.size()).double() |
| io.backward(grad) |
| o.backward(grad) |
| for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable, |
| (self_variable,) + args_variable): |
| if not isinstance(inp_i, Variable): |
| continue |
| self.assertEqual(inp_i.grad, i.grad) |
| |
| check(name) |
| inplace_name = name + '_' |
| # can't broadcast inplace to left hand side |
| broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name |
| if hasattr(Variable(torch.ones(1)), inplace_name) and not broadcast_skip_inplace: |
| check(inplace_name) |
| |
| assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name |
| |
| for skip in skipTestIf: |
| do_test = skip(do_test) |
| |
| setattr(TestAutograd, test_name, do_test) |
| |
| |
| if __name__ == '__main__': |
| run_tests() |