| # Owner(s): ["module: autograd"] | 
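"""Tests for the autograd fallback that runs when an operator has no autograd
kernel registered.

The fallback's behavior is controlled by a global mode: "nothing" adds no
extra autograd behavior, while "warn" emits a UserWarning ("an autograd
kernel was not registered") when gradients flow through the op's outputs.
Every test below is parametrized to run under both modes.
"""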
 |  | 
 | import contextlib | 
 | import warnings | 
 |  | 
 | import numpy as np | 
 |  | 
 | import torch | 
 | from torch.library import _scoped_library, Library | 
 | from torch.testing._internal.common_utils import ( | 
 |     instantiate_parametrized_tests, | 
 |     parametrize, | 
 |     run_tests, | 
 |     TestCase, | 
 | ) | 
 |  | 
 |  | 
 | @contextlib.contextmanager | 
 | def autograd_fallback_mode(mode): | 
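    """Temporarily set the global autograd fallback mode ("nothing" or
    "warn"), restoring the previous mode on exit."""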
 |     prev = torch._C._get_autograd_fallback_mode() | 
 |     try: | 
 |         torch._C._set_autograd_fallback_mode(mode) | 
 |         yield | 
 |     finally: | 
 |         torch._C._set_autograd_fallback_mode(prev) | 
 |  | 
 |  | 
 | class TestAutogradFallback(TestCase): | 
 |     test_ns = "_test_autograd_fallback" | 
 |  | 
 |     def tearDown(self): | 
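        # Remove the ops defined by this test so they don't leak into other
        # tests, and drop the Library handle so its registrations go away.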
 |         if hasattr(torch.ops, self.test_ns): | 
 |             delattr(torch.ops, self.test_ns) | 
 |         if hasattr(self, "lib"): | 
 |             del self.lib.m | 
 |             del self.lib | 
 |  | 
 |     def get_op(self, name): | 
 |         return getattr(getattr(torch.ops, self.test_ns), name).default | 
 |  | 
 |     def get_lib(self): | 
 |         lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901 | 
 |         self.lib = lib | 
 |         return lib | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_no_grad(self, mode): | 
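        """The fallback must stay silent when grad is disabled or when no
        input requires grad, regardless of mode (simplefilter("error") turns
        any warning into a test failure)."""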
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") | 
 |             lib.impl("foo", lambda a, b, c: a + b + c, "CPU") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             with warnings.catch_warnings(): | 
 |                 warnings.simplefilter("error") | 
 |                 with torch.no_grad(): | 
 |                     a = torch.randn([], requires_grad=True) | 
 |                     b = torch.randn([], requires_grad=True) | 
 |                     out = op(a, b, 1) | 
 |                 self.assertFalse(out.requires_grad) | 
 |  | 
 |             with warnings.catch_warnings(): | 
 |                 warnings.simplefilter("error") | 
 |                 a = torch.randn([]) | 
 |                 b = torch.randn([]) | 
 |                 out = op(a, b, 1) | 
 |                 self.assertFalse(out.requires_grad) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_no_autograd_kernel(self, mode): | 
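        """Backward through an op that has only a CPU kernel: "warn" emits
        the fallback's UserWarning and produces no gradient; "nothing" raises
        because the output does not require grad."""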
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a, Tensor b, int c) -> Tensor") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a, b, c): | 
 |                 result = a.detach().numpy() + b.detach().numpy() + c | 
 |                 return torch.tensor(result) | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |  | 
            # Some inputs require grad: backward should hit the fallback's
            # diagnostic.
 |             a = torch.randn([], requires_grad=False) | 
 |             b = torch.randn([], requires_grad=True) | 
 |             out = op(a, b, 1).sum() | 
 |             with self._check_ctx(mode, mode_nothing_raises=True): | 
 |                 out.backward() | 
 |             self.assertIsNone(b.grad) | 
 |  | 
 |     def _check_ctx(self, mode, *, mode_nothing_raises=False): | 
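        """Return the assertion context expected for the given mode: a
        UserWarning in "warn" mode; in "nothing" mode, a RuntimeError if the
        differentiated output should not require grad, else no diagnostic."""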
 |         if mode == "warn": | 
 |             return self.assertWarnsRegex( | 
 |                 UserWarning, "an autograd kernel was not registered" | 
 |             ) | 
 |         assert mode == "nothing" | 
 |         if mode_nothing_raises: | 
 |             return self.assertRaisesRegex(RuntimeError, "does not require grad") | 
 |         return contextlib.nullcontext() | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_no_autograd_kernel_inplace(self, mode): | 
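        """The fallback must handle kernels that mutate their inputs in-place
        under no_grad and return them: backward through the mutated views and
        their bases runs with the mode's expected diagnostic."""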
 |         with autograd_fallback_mode(mode): | 
            # An input modified in-place gets returned as an output.
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(x, y): | 
 |                 with torch.no_grad(): | 
 |                     x.sin_() | 
 |                     y.cos_() | 
 |                 return x, y | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |  | 
 |             x = torch.randn(3, requires_grad=True) | 
 |             w = x.clone() | 
 |             v = x.clone() | 
 |             y0 = w[0] | 
 |             y1 = v[1] | 
 |             z0, z1 = op(y0, y1) | 
 |             for tensor in [w, v, z0, z1, y0, y1]: | 
 |                 with self._check_ctx(mode): | 
 |                     tensor.sum().backward(retain_graph=True) | 
 |  | 
            # No outputs: the fallback does nothing. Maybe it should in the
            # future, but this is not a common failure mode.
 |             lib.define("bar(Tensor(a!) self) -> ()") | 
 |             op = self.get_op("bar") | 
 |  | 
 |             def bar_impl(x): | 
 |                 with torch.no_grad(): | 
 |                     x.sin_() | 
 |  | 
 |             lib.impl("bar", bar_impl, "CPU") | 
 |             with warnings.catch_warnings(): | 
 |                 warnings.simplefilter("error") | 
 |                 x = torch.randn([], requires_grad=True) | 
 |                 y = x.clone() | 
 |                 z = op(y) | 
 |                 y.backward() | 
 |                 self.assertEqual(x.grad, torch.ones_like(x)) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_cpu_return_self(self, mode): | 
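        """Kernels that return their input as-is: foo without and bar with
        the (a!) alias annotation."""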
 |         with autograd_fallback_mode(mode): | 
            # To be clear, none of these situations is OK and they will lead
            # to other problems down the line. We test them because users
            # fairly commonly write kernels that do these things.
 |             with _scoped_library(self.test_ns, "FRAGMENT") as lib: | 
 |                 lib.define("foo(Tensor self) -> Tensor") | 
 |                 lib.impl("foo", lambda x: x, "CPU") | 
 |                 op = self.get_op("foo") | 
 |  | 
 |                 x = torch.randn(3, requires_grad=True) | 
 |                 y = op(x).sum() | 
 |                 with self._check_ctx(mode): | 
 |                     y.backward() | 
 |                     self.assertEqual(x.grad, torch.ones_like(x)) | 
 |  | 
 |                 lib.define("bar(Tensor(a!) self) -> Tensor(a!)") | 
 |                 lib.impl("bar", lambda x: x, "CPU") | 
 |                 op = self.get_op("bar") | 
 |  | 
 |                 x = torch.randn(3, requires_grad=True) | 
 |                 y = op(x).sum() | 
 |                 with self._check_ctx(mode): | 
 |                     y.backward() | 
 |                     self.assertEqual(x.grad, torch.ones_like(x)) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_composite_registered_to_cpu(self, mode): | 
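        """A composite built from differentiable ops but registered to the
        CPU key still gets correct gradients: autograd tracks the ops inside
        the kernel."""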
 |         with autograd_fallback_mode(mode): | 
 |             with _scoped_library(self.test_ns, "FRAGMENT") as lib: | 
 |                 lib.define("foo(Tensor self) -> Tensor") | 
 |                 lib.impl("foo", lambda x: x.sin().sum(), "CPU") | 
 |                 op = self.get_op("foo") | 
 |  | 
 |                 x = torch.randn(3, requires_grad=True) | 
 |                 y = op(x) | 
 |                 with self._check_ctx(mode): | 
 |                     y.backward() | 
 |                     self.assertEqual(x.grad, x.cos()) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_autograd_function_registered_to_cpu(self, mode): | 
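        """An autograd.Function registered to the CPU key builds its own
        graph, so gradients come out correct even though "warn" mode still
        issues the fallback's diagnostic."""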
 |         with autograd_fallback_mode(mode): | 
 |             with _scoped_library(self.test_ns, "FRAGMENT") as lib: | 
 |                 lib.define("foo(Tensor self) -> Tensor") | 
 |  | 
 |                 class NumpySin(torch.autograd.Function): | 
 |                     @staticmethod | 
 |                     def forward(ctx, x): | 
 |                         ctx.save_for_backward(x) | 
 |                         return torch.tensor(np.sin(x.cpu().numpy())) | 
 |  | 
 |                     @staticmethod | 
 |                     def backward(ctx, gx): | 
 |                         (x,) = ctx.saved_tensors | 
 |                         return gx * x.cos() | 
 |  | 
 |                 lib.impl("foo", NumpySin.apply, "CPU") | 
 |                 op = self.get_op("foo") | 
 |  | 
 |                 x = torch.randn(3, requires_grad=True) | 
 |                 y = op(x).sum() | 
 |                 with self._check_ctx(mode): | 
 |                     y.backward() | 
 |                     self.assertEqual(x.grad, x.cos()) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_inplace_autograd_function_registered_to_cpu(self, mode): | 
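        """An autograd.Function that mutates its input in-place (and calls
        ctx.mark_dirty), registered to the CPU key and applied to a view of
        a leaf's clone; gradients through both the view and its base are
        checked."""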
 |         with autograd_fallback_mode(mode): | 
 |             with _scoped_library(self.test_ns, "FRAGMENT") as lib: | 
 |                 lib.define("foo(Tensor(a!) self) -> Tensor(a!)") | 
 |  | 
 |                 class NumpySin_(torch.autograd.Function): | 
 |                     @staticmethod | 
 |                     def forward(ctx, x): | 
 |                         ctx.save_for_backward(x.clone()) | 
 |                         x_np = x.detach().numpy() | 
 |                         np.sin(x_np, out=x_np) | 
 |                         ctx.mark_dirty(x) | 
 |                         return x | 
 |  | 
 |                     @staticmethod | 
 |                     def backward(ctx, gx): | 
 |                         (x,) = ctx.saved_tensors | 
 |                         return gx * x.cos() | 
 |  | 
 |                 lib.impl("foo", NumpySin_.apply, "CPU") | 
 |                 op = self.get_op("foo") | 
 |  | 
 |                 x = torch.randn(3, requires_grad=True) | 
 |                 z = x.clone() | 
 |                 w = z[0] | 
 |                 y = op(w) | 
 |  | 
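                # y is sin applied in-place to w = z[0]; x enters only
                # through that element, so dy/dx is cos(x[0]) at index 0 and
                # zero elsewhere.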
 |                 expected = torch.zeros_like(x) | 
 |                 expected[0] = x[0].cos() | 
 |                 with self._check_ctx(mode): | 
 |                     (gx,) = torch.autograd.grad( | 
 |                         y, x, torch.ones_like(y), retain_graph=True | 
 |                     ) | 
 |                     self.assertEqual(gx, expected) | 
 |  | 
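                # Through z, indices 1 and 2 pass straight through the clone
                # (gradient 1); index 0 was overwritten by the in-place sin,
                # so its gradient is cos(x[0]).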
 |                 expected = torch.ones_like(x) | 
 |                 expected[0] = x[0].cos() | 
 |                 with self._check_ctx(mode): | 
 |                     (gx,) = torch.autograd.grad(z, x, torch.ones_like(z)) | 
 |                     self.assertEqual(gx, expected) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_inplace_on_tensor_that_does_not_require_grad(self, mode): | 
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why.
 |         with autograd_fallback_mode(mode): | 
 |             with _scoped_library(self.test_ns, "FRAGMENT") as lib: | 
 |                 # Correct usage of (a!) | 
 |                 lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)") | 
 |  | 
 |                 def foo_impl(x, y): | 
 |                     x_d = x.detach() | 
 |                     y = y.detach() | 
 |                     x_d.add_(y) | 
 |                     return x | 
 |  | 
 |                 lib.impl("foo", foo_impl, "CPU") | 
 |                 foo = self.get_op("foo") | 
 |  | 
 |                 # Incorrect usage of (a!): user doesn't return tensor as-is | 
 |                 lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)") | 
 |  | 
 |                 def bar_impl(x, y): | 
 |                     x_d = x.detach() | 
 |                     y = y.detach() | 
 |                     x_d.add_(y) | 
 |                     return x_d.clone() | 
 |  | 
 |                 lib.impl("bar", bar_impl, "CPU") | 
 |                 bar = self.get_op("bar") | 
 |  | 
                # Incorrect usage of (a!): user mutates the input tensor
                # but doesn't return it.
 |                 lib.define("baz(Tensor(a!) self, Tensor other) -> ()") | 
 |  | 
 |                 def baz_impl(x, y): | 
 |                     x_d = x.detach() | 
 |                     y = y.detach() | 
 |                     x_d.add_(y) | 
 |  | 
 |                 lib.impl("baz", baz_impl, "CPU") | 
 |                 baz = self.get_op("baz") | 
 |  | 
 |                 # Test in-place on non-view | 
 |                 for op in (foo, bar, baz): | 
 |                     x = torch.randn(3) | 
 |                     y = torch.randn(3, requires_grad=True) | 
 |                     with self.assertRaisesRegex(RuntimeError, "does not require grad"): | 
 |                         z = x.clone() | 
 |                         op(z, y) | 
 |                         torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True) | 
 |  | 
 |                 # Test in-place on view | 
 |                 for op in (foo, bar, baz): | 
 |                     x = torch.randn(3) | 
 |                     y = torch.randn(3, requires_grad=True) | 
 |                     with self.assertRaisesRegex(RuntimeError, "does not require grad"): | 
 |                         z = x[:] | 
 |                         op(z, y) | 
 |                         torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_post_autograd_returns_leaf(self, mode): | 
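        """The fallback must handle a kernel that returns a fresh leaf
        requiring grad alongside a regular output."""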
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a) -> (Tensor, Tensor)") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             lib.impl( | 
 |                 "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU" | 
 |             ) | 
 |             x = torch.randn(3, requires_grad=True) | 
 |             y, z = op(x) | 
 |             with self._check_ctx(mode): | 
 |                 z.sum().backward() | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_undefined_inputs_outputs(self, mode): | 
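        """The fallback must tolerate undefined (None) tensors both as
        inputs and as outputs."""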
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a, b): | 
 |                 return None, b.clone() | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |  | 
 |             x = torch.randn(3, requires_grad=True) | 
            # NB: The PyTorch dispatcher treats None as an undefined Tensor.
 |             y, z = op(None, x) | 
 |             with self._check_ctx(mode): | 
 |                 z.sum().backward() | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_undefined_grads(self, mode): | 
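        """The fallback must tolerate undefined gradients flowing back
        through the op's outputs."""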
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a, b): | 
 |                 return a.sin(), b.cos() | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |  | 
 |             x = torch.randn(3, requires_grad=True) | 
 |             y = torch.randn(3) | 
 |             w, z = op(x, y) | 
 |             w = torch._C._functions.UndefinedGrad()(w) | 
 |             z = torch._C._functions.UndefinedGrad()(z) | 
 |             with self._check_ctx(mode): | 
 |                 (z + w).sum().backward() | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_base_does_not_require_grad(self, mode): | 
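        """In-place op on a view w whose base x does not require grad
        (requires_grad was enabled on an intermediate view): the fallback's
        hook must land on w, not on w._base."""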
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor(a!) x) -> Tensor(a!)") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a): | 
 |                 with torch.no_grad(): | 
 |                     return a.zero_() | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |             x = torch.randn(3) | 
 |             y = x[:] | 
 |             y.requires_grad_() | 
 |             w = y[:] | 
 |             self.assertTrue(w._base is x) | 
 |  | 
            # The hook should be registered on w, but not on w._base.
 |             op(w) | 
 |             with self._check_ctx(mode): | 
 |                 w.sum().backward() | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode): | 
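        """Outputs created under no_grad (x, z) hit the fallback's diagnostic
        when differentiated; the tracked output y = a * b differentiates
        normally."""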
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a, b): | 
 |                 with torch.no_grad(): | 
 |                     x = a.clone() | 
 |                     z = b.clone() | 
 |                 y = a * b | 
 |                 return x, y, z | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |             a = torch.randn(3, requires_grad=True) | 
 |             b = torch.randn(3, requires_grad=True) | 
 |             x, y, z = op(a, b) | 
 |  | 
 |             with self._check_ctx(mode, mode_nothing_raises=True): | 
 |                 torch.autograd.grad( | 
 |                     x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True | 
 |                 ) | 
 |  | 
 |             with self._check_ctx(mode, mode_nothing_raises=False): | 
 |                 torch.autograd.grad( | 
 |                     y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True | 
 |                 ) | 
 |  | 
 |             with self._check_ctx(mode, mode_nothing_raises=True): | 
 |                 torch.autograd.grad( | 
 |                     z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True | 
 |                 ) | 
 |  | 
 |     @parametrize("mode", ("nothing", "warn")) | 
 |     def test_supports_tensor_lists(self, mode): | 
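        """The fallback must handle ops that take and return Tensor[]."""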
 |         with autograd_fallback_mode(mode): | 
 |             lib = self.get_lib() | 
 |             lib.define("foo(Tensor[] a) -> Tensor[]") | 
 |             op = self.get_op("foo") | 
 |  | 
 |             def foo_impl(a): | 
 |                 x, y, z = a | 
 |                 with torch.no_grad(): | 
 |                     return x + y + z, x * y * z | 
 |  | 
 |             lib.impl("foo", foo_impl, "CPU") | 
 |             x = torch.randn(3, requires_grad=True) | 
 |             y = torch.randn(1, requires_grad=True) | 
 |             z = torch.randn(2, 1, requires_grad=True) | 
 |             a, b = op([x, y, z]) | 
 |             with self._check_ctx(mode, mode_nothing_raises=True): | 
 |                 torch.autograd.grad( | 
 |                     a, | 
 |                     (x, y, z), | 
 |                     torch.ones_like(a), | 
 |                     allow_unused=True, | 
 |                     retain_graph=True, | 
 |                 ) | 
 |             with self._check_ctx(mode, mode_nothing_raises=True): | 
 |                 torch.autograd.grad( | 
 |                     b, | 
 |                     (x, y, z), | 
 |                     torch.ones_like(b), | 
 |                     allow_unused=True, | 
 |                     retain_graph=True, | 
 |                 ) | 
 |  | 
 |  | 
 | instantiate_parametrized_tests(TestAutogradFallback) | 
 |  | 
 | if __name__ == "__main__": | 
 |     run_tests() |