Reset grad state across unittests (#126345)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126345
Approved by: https://github.com/ezyang
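Tests that flip the global autograd mode with `torch.set_grad_enabled(...)` and then fail (or simply forget to restore it) leak that state into every test that runs afterwards. Previously only `AOTTestCase` guarded against this with its own setUp/tearDown pair; this change moves the snapshot/restore into the common `TestCase` in `torch/testing/_internal/common_utils.py` so all unittests get it, and deletes the now-redundant override. A minimal standalone sketch of the pattern (the class name here is hypothetical; the real change is in the diff below):

```python
import unittest

import torch


class GradStateResettingTestCase(unittest.TestCase):
    """Hypothetical sketch of the behavior added to the common TestCase:
    snapshot the global grad mode in setUp and restore it in tearDown."""

    def setUp(self):
        super().setUp()
        # Record the global grad mode so tearDown can put it back.
        self._prev_grad_state = torch.is_grad_enabled()

    def tearDown(self):
        # A subclass may override setUp without calling super(), so the
        # attribute is not guaranteed to exist (hence the hasattr check,
        # mirroring the patch below).
        if hasattr(self, "_prev_grad_state"):
            torch.set_grad_enabled(self._prev_grad_state)
        super().tearDown()
```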
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index cfbd96e..6f990f1 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -102,13 +102,7 @@
class AOTTestCase(TestCase):
- def setUp(self):
- self.prev_grad_state = torch.is_grad_enabled()
- super().setUp()
-
- def tearDown(self):
- torch.set_grad_enabled(self.prev_grad_state)
- super().tearDown()
+ pass
class TestPythonKey(AOTTestCase):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index af5dcf3..b0ea4dc 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2903,6 +2903,9 @@
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float
+ # record global state here so it can be reset at the end of the test
+ self._prev_grad_state = torch.is_grad_enabled()
+
def tearDown(self):
# There exist test cases that override TestCase.setUp
# definition, so we cannot assume that _check_invariants
@@ -2917,6 +2920,10 @@
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float
+ # attribute may not be defined, per above
+ if hasattr(self, '_prev_grad_state'):
+ torch.set_grad_enabled(self._prev_grad_state)
+
@staticmethod
def _make_crow_indices(n_rows, n_cols, nnz,
*, device, dtype, random=True):
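
For illustration, a hypothetical pair of tests (not part of this PR) showing the leak the new tearDown hook prevents, assuming unittest's default alphabetical execution order within a class:

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class TestGradStateIsolation(TestCase):
    # Hypothetical tests, for illustration only.
    def test_a_disables_grad_without_restoring(self):
        # Simulates a test that flips the global grad mode and never
        # restores it (e.g. because it raises before its own cleanup).
        torch.set_grad_enabled(False)
        self.assertFalse(torch.is_grad_enabled())

    def test_b_expects_grad_enabled(self):
        # With the common tearDown above, the previous test's change to the
        # global grad mode no longer leaks into this one.
        self.assertTrue(torch.is_grad_enabled())


if __name__ == "__main__":
    run_tests()
```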