no_grad, enable_grad: support for decorating generator functions (#31792)
Summary:
Closes https://github.com/pytorch/pytorch/issues/31497
This allows `torch.no_grad` and `torch.enable_grad` to be used as decorators for generator functions. In that case grad mode is disabled/enabled only while the generator body is executing, and the surrounding grad mode is restored at every `yield`. (Previously the decorator only wrapped the call that creates the generator object, which returns immediately without running the body, so it had no effect on the body at all.)
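Concretely, a small usage sketch mirroring the new test (the generator and its contents are made up for illustration):

```python
import torch

@torch.no_grad()
def gen_squares(n):
    for i in range(n):
        # The generator body always runs with grad disabled.
        assert not torch.is_grad_enabled()
        yield i * i

with torch.enable_grad():
    for v in gen_squares(3):
        # Between yields, the caller's grad mode is back in effect.
        assert torch.is_grad_enabled()
```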
https://github.com/pytorch/pytorch/issues/31497 doesn't include a complete reproducer, but the included test, which checks `torch.is_grad_enabled` both inside and outside the decorated generators, shows this working where it previously failed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31792
Differential Revision: D19274971
Pulled By: albanD
fbshipit-source-id: fde6d3fd95d76c8d324ad02db577213a4b68ccbe
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 98226e0..e4ab9b8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -880,6 +880,27 @@
         w = adder(x, y)
         self.assertFalse(torch.is_grad_enabled())
+    def test_set_grad_generator_functions(self):
+        @torch.no_grad()
+        def gen_no_grad():
+            for i in range(10):
+                self.assertEqual(torch.is_grad_enabled(), False)
+                yield i
+
+        with torch.enable_grad():
+            for _ in gen_no_grad():
+                self.assertEqual(torch.is_grad_enabled(), True)
+
+        @torch.enable_grad()
+        def gen_enable_grad():
+            for i in range(10):
+                self.assertEqual(torch.is_grad_enabled(), True)
+                yield i
+
+        with torch.no_grad():
+            for _ in gen_enable_grad():
+                self.assertEqual(torch.is_grad_enabled(), False)
+
     def test_no_grad_python_function(self):
         """Python Functions should respect grad mode."""
         x = torch.ones(5, 5, requires_grad=True)
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index 24cc4bb..c59a4bb 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -1,8 +1,36 @@
 import torch
 import functools
+import inspect
+
+class _DecoratorContextManager:
+    """Allow a context manager to be used as a decorator"""
+
+    def __call__(self, func):
+        if inspect.isgeneratorfunction(func):
+            return self._wrap_generator(func)
+
+        @functools.wraps(func)
+        def decorate_context(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+        return decorate_context
+
+    def _wrap_generator(self, func):
+        """Wrap each generator invocation with the context manager"""
+        @functools.wraps(func)
+        def generator_context(*args, **kwargs):
+            gen = func(*args, **kwargs)
+            while True:
+                try:
+                    with self:
+                        x = next(gen)
+                    yield x
+                except StopIteration:
+                    break
+        return generator_context
-class no_grad(object):
+class no_grad(_DecoratorContextManager):
r"""Context-manager that disabled gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure
@@ -42,15 +70,8 @@
         torch.set_grad_enabled(self.prev)
         return False
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_no_grad(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_no_grad
-
-class enable_grad(object):
+class enable_grad(_DecoratorContextManager):
r"""Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via :class:`~no_grad`
@@ -89,13 +110,6 @@
         torch.set_grad_enabled(self.prev)
         return False
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_enable_grad(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_enable_grad
-
 class set_grad_enabled(object):
     r"""Context-manager that sets gradient calculation to on or off.