Add a classmethod variant of enter_context.

PiperOrigin-RevId: 371692599
Change-Id: I6c86e34e798749459b1f38fc86a94c5d3ce8c814
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 3bf77d9..0708875 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -560,6 +560,39 @@
       yield fp
 
 
+class _method(object):
+  """A decorator that supports both instance and classmethod invocations.
+
+  Using similar semantics to the @property builtin, this decorator can augment
+  an instance method to support conditional logic when invoked on a class
+  object. This breaks support for invoking an instance method via the class
+  (e.g. Cls.method(self, ...)) but is still situationally useful.
+  """
+
+  def __init__(self, finstancemethod):
+    # type: (Callable[..., Any]) -> None
+    self._finstancemethod = finstancemethod
+    self._fclassmethod = None
+
+  def classmethod(self, fclassmethod):
+    # type: (Callable[..., Any]) -> _method
+    self._fclassmethod = classmethod(fclassmethod)
+    return self
+
+  def __doc__(self):
+    # type: () -> str
+    if getattr(self._finstancemethod, '__doc__'):
+      return self._finstancemethod.__doc__
+    elif getattr(self._fclassmethod, '__doc__'):
+      return self._fclassmethod.__doc__
+    return ''
+
+  def __get__(self, obj, type_):
+    # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
+    func = self._fclassmethod if obj is None else self._finstancemethod
+    return func.__get__(obj, type_)  # pytype: disable=attribute-error
+
+
 class TestCase(unittest3_backport.TestCase):
   """Extension of unittest.TestCase providing more power."""
 
@@ -576,20 +609,31 @@
   maxDiff = 80 * 20
   longMessage = True
 
+  # Exit stacks for per-test and per-class scopes.
+  _exit_stack = None
+  _cls_exit_stack = None
+
   def __init__(self, *args, **kwargs):
     super(TestCase, self).__init__(*args, **kwargs)
     # This is to work around missing type stubs in unittest.pyi
     self._outcome = getattr(self, '_outcome')  # type: Optional[_OutcomeType]
-    # This is re-initialized by setUp().
-    self._exit_stack = None
 
   def setUp(self):
     super(TestCase, self).setUp()
-    # NOTE: Only Py3 contextlib has ExitStack
+    # NOTE: Only Python 3 contextlib has ExitStack
     if hasattr(contextlib, 'ExitStack'):
       self._exit_stack = contextlib.ExitStack()
       self.addCleanup(self._exit_stack.close)
 
+  @classmethod
+  def setUpClass(cls):
+    super(TestCase, cls).setUpClass()
+    # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has
+    # addClassCleanup.
+    if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'):
+      cls._cls_exit_stack = contextlib.ExitStack()
+      cls.addClassCleanup(cls._cls_exit_stack.close)
+
   def create_tempdir(self, name=None, cleanup=None):
     # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
     """Create a temporary directory specific to the test.
@@ -700,14 +744,19 @@
     self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
     return tf
 
+  @_method
   def enter_context(self, manager):
     # type: (ContextManager[_T]) -> _T
     """Returns the CM's value after registering it with the exit stack.
 
-    Entering a context pushes it onto a stack of contexts. The context is exited
-    when the test completes. Contexts are are exited in the reverse order of
-    entering. They will always be exited, regardless of test failure/success.
-    The context stack is specific to the test being run.
+    Entering a context pushes it onto a stack of contexts. When `enter_context`
+    is called on the test instance (e.g. `self.enter_context`), the context is
+    exited after the test case's tearDown call. When called on the test class
+    (e.g. `TestCase.enter_context`), the context is exited after the test
+    class's tearDownClass call.
+
+    Contexts are are exited in the reverse order of entering. They will always
+    be exited, regardless of test failure/success.
 
     This is useful to eliminate per-test boilerplate when context managers
     are used. For example, instead of decorating every test with `@mock.patch`,
@@ -726,6 +775,15 @@
           'sure that AbslTest.setUp() is called.')
     return self._exit_stack.enter_context(manager)
 
+  @enter_context.classmethod
+  def enter_context(cls, manager):  # pylint: disable=no-self-argument
+    # type: (ContextManager[_T]) -> _T
+    if not cls._cls_exit_stack:
+      raise AssertionError(
+          'cls._cls_exit_stack is not set: cls.enter_context requires '
+          'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
+    return cls._cls_exit_stack.enter_context(manager)
+
   @classmethod
   def _get_tempdir_path_cls(cls):
     # type: () -> Text
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
index 0ac3009..a54ff55 100644
--- a/absl/testing/tests/absltest_test.py
+++ b/absl/testing/tests/absltest_test.py
@@ -1487,6 +1487,15 @@
     self.assertRegex(stderr, 'No such file or directory')
 
 
+@contextlib.contextmanager
+def cm_for_test(obj):
+  try:
+    obj.cm_state = 'yielded'
+    yield 'value'
+  finally:
+    obj.cm_state = 'exited'
+
+
 @absltest.skipIf(six.PY2, 'Python 2 does not have ExitStack')
 class EnterContextTest(absltest.TestCase):
 
@@ -1503,15 +1512,33 @@
     self.addCleanup(assert_cm_exited)
 
     super(EnterContextTest, self).setUp()
-    self.cm_value = self.enter_context(self.cm_for_test())
+    self.cm_value = self.enter_context(cm_for_test(self))
 
-  @contextlib.contextmanager
-  def cm_for_test(self):
-    try:
-      self.cm_state = 'yielded'
-      yield 'value'
-    finally:
-      self.cm_state = 'exited'
+  def test_enter_context(self):
+    self.assertEqual(self.cm_value, 'value')
+    self.assertEqual(self.cm_state, 'yielded')
+
+
+@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'),
+                 'Python 3.8 required for class-level enter_context')
+class EnterContextClassmethodTest(absltest.TestCase):
+
+  cm_state = 'unset'
+  cm_value = 'unset'
+
+  @classmethod
+  def setUpClass(cls):
+
+    def assert_cm_exited():
+      assert cls.cm_state == 'exited'
+
+    # Because cleanup functions are run in reverse order, we have to add
+    # our assert-cleanup before the exit stack registers its own cleanup.
+    # This ensures we see state after the stack cleanup runs.
+    cls.addClassCleanup(assert_cm_exited)
+
+    super(EnterContextClassmethodTest, cls).setUpClass()
+    cls.cm_value = cls.enter_context(cm_for_test(cls))
 
   def test_enter_context(self):
     self.assertEqual(self.cm_value, 'value')