| # Owner(s): ["module: autograd"] |
| |
| import torch |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck |
| |
| |
| class TestAutogradComplex(TestCase): |
| def test_view_func_for_complex_views(self): |
| # case 1: both parent and child have view_func |
| x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) |
| y = x.detach().requires_grad_(True) |
| |
| x0 = x.clone() |
| x1 = torch.view_as_complex(x0) |
| x2 = torch.view_as_real(x1) |
| x2.mul_(2) |
| x2.sum().backward() |
| |
| y0 = y.clone() |
| y0.mul_(2) |
| y0.sum().backward() |
| |
| self.assertEqual(x.grad, y.grad) |
| |
| # case 2: parent has view_func but child does not |
| x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) |
| y = x.detach().requires_grad_(True) |
| |
| def fn(a): |
| b = a.clone() |
| b1 = torch.view_as_complex(b) |
| b2 = b1.reshape(b1.numel()) |
| return b2 |
| |
| x0 = fn(x) |
| x0.mul_(2) |
| x0.sum().backward() |
| |
| y0 = fn(y) |
| y1 = y0.mul(2) |
| y1.sum().backward() |
| |
| self.assertEqual(x.grad, y.grad) |
| |
| # case 3: parent does not have a view_func but child does |
| x = torch.randn(10, dtype=torch.cdouble, requires_grad=True) |
| y = x.detach().requires_grad_(True) |
| |
| def fn(a, dim0_size=5): |
| b = a.clone() |
| b1 = b.reshape(dim0_size, 2) |
| b2 = torch.view_as_real(b1) |
| return b2 |
| |
| x0 = fn(x) |
| x0.mul_(2) |
| x0.sum().backward() |
| |
| y0 = fn(y) |
| y1 = y0.mul(2) |
| y1.sum().backward() |
| |
| self.assertEqual(x.grad, y.grad) |
| |
| def test_view_with_multi_output(self): |
| x = torch.randn(2, 2, 2, dtype=torch.double) |
| |
| x1 = torch.view_as_complex(x) |
| # Taking an invalid view should always be allowed as long as it is not |
| # modified inplace |
| res = x1.unbind(0) |
| |
| with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): |
| res[0] += torch.rand(2, requires_grad=True) |
| |
| x.requires_grad_(True) |
| x1 = torch.view_as_complex(x) |
| # Taking an invalid view should always be allowed as long as it is not |
| # modified inplace |
| res = x1.unbind(0) |
| |
| with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): |
| res[0] += torch.rand(2, requires_grad=True) |
| |
| def as_identity(self): |
| # view_as_real and view_as_complex behavior should be like an identity |
| def func(z): |
| z_ = torch.view_as_complex(z) |
| z_select = torch.select(z_, z_.dim() - 1, 0) |
| z_select_real = torch.view_as_real(z_select) |
| return z_select_real.sum() |
| |
| z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True) |
| gradcheck(func, [z]) |
| func(z).backward() |
| |
| z1 = z.clone().detach().requires_grad_(True) |
| torch.select(z1, z1.dim() - 2, 0).sum().backward() |
| |
| self.assertEqual(z.grad, z1.grad) |
| |
| |
| if __name__ == '__main__': |
| run_tests() |