Reset grad state across unittests (#126345)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126345
Approved by: https://github.com/ezyang
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index cfbd96e..6f990f1 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -102,13 +102,7 @@
 
 
 class AOTTestCase(TestCase):
-    def setUp(self):
-        self.prev_grad_state = torch.is_grad_enabled()
-        super().setUp()
-
-    def tearDown(self):
-        torch.set_grad_enabled(self.prev_grad_state)
-        super().tearDown()
+    pass
 
 
 class TestPythonKey(AOTTestCase):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index af5dcf3..b0ea4dc 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2903,6 +2903,9 @@
         if self._default_dtype_check_enabled:
             assert torch.get_default_dtype() == torch.float
 
+        # attempt to reset some global state at the end of the test
+        self._prev_grad_state = torch.is_grad_enabled()
+
     def tearDown(self):
         # There exists test cases that override TestCase.setUp
         # definition, so we cannot assume that _check_invariants
@@ -2917,6 +2920,10 @@
         if self._default_dtype_check_enabled:
             assert torch.get_default_dtype() == torch.float
 
+        # attribute may not be defined, per above
+        if hasattr(self, '_prev_grad_state'):
+            torch.set_grad_enabled(self._prev_grad_state)
+
     @staticmethod
     def _make_crow_indices(n_rows, n_cols, nnz,
                            *, device, dtype, random=True):