no_grad, enable_grad: support for decorating generator functions (#31792)

Summary:
Closes https://github.com/pytorch/pytorch/issues/31497

This allows `torch.no_grad` and `torch.enable_grad` to be used as decorators for generator functions. In which case it disables/enables grad only inside the body of the generator and restores the context outside of the generator.

https://github.com/pytorch/pytorch/issues/31497 doesn't include a complete reproducer but the included test with `torch.is_grad_enabled` show this is working where it failed before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31792

Differential Revision: D19274971

Pulled By: albanD

fbshipit-source-id: fde6d3fd95d76c8d324ad02db577213a4b68ccbe
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 98226e0..e4ab9b8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -880,6 +880,27 @@
             w = adder(x, y)
             self.assertFalse(torch.is_grad_enabled())
 
+    def test_set_grad_generator_functions(self):
+        @torch.no_grad()
+        def gen_no_grad():
+            for i in range(10):
+                self.assertEqual(torch.is_grad_enabled(), False)
+                yield i
+
+        with torch.enable_grad():
+            for _ in gen_no_grad():
+                self.assertEqual(torch.is_grad_enabled(), True)
+
+        @torch.enable_grad()
+        def gen_enable_grad():
+            for i in range(10):
+                self.assertEqual(torch.is_grad_enabled(), True)
+                yield i
+
+        with torch.no_grad():
+            for _ in gen_enable_grad():
+                self.assertEqual(torch.is_grad_enabled(), False)
+
     def test_no_grad_python_function(self):
         """Python Functions should respect grad mode."""
         x = torch.ones(5, 5, requires_grad=True)
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index 24cc4bb..c59a4bb 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -1,8 +1,36 @@
 import torch
 import functools
+import inspect
+
+class _DecoratorContextManager:
+    """Allow a context manager to be used as a decorator"""
+
+    def __call__(self, func):
+        if inspect.isgeneratorfunction(func):
+            return self._wrap_generator(func)
+
+        @functools.wraps(func)
+        def decorate_context(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+        return decorate_context
+
+    def _wrap_generator(self, func):
+        """Wrap each generator invocation with the context manager"""
+        @functools.wraps(func)
+        def generator_context(*args, **kwargs):
+            gen = func(*args, **kwargs)
+            while True:
+                try:
+                    with self:
+                        x = next(gen)
+                    yield x
+                except StopIteration:
+                    break
+        return generator_context
 
 
-class no_grad(object):
+class no_grad(_DecoratorContextManager):
     r"""Context-manager that disabled gradient calculation.
 
     Disabling gradient calculation is useful for inference, when you are sure
@@ -42,15 +70,8 @@
         torch.set_grad_enabled(self.prev)
         return False
 
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_no_grad(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_no_grad
 
-
-class enable_grad(object):
+class enable_grad(_DecoratorContextManager):
     r"""Context-manager that enables gradient calculation.
 
     Enables gradient calculation, if it has been disabled via :class:`~no_grad`
@@ -89,13 +110,6 @@
         torch.set_grad_enabled(self.prev)
         return False
 
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_enable_grad(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_enable_grad
-
 
 class set_grad_enabled(object):
     r"""Context-manager that sets gradient calculation to on or off.