Add a classmethod variant of enter_context.
PiperOrigin-RevId: 371692599
Change-Id: I6c86e34e798749459b1f38fc86a94c5d3ce8c814
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 3bf77d9..0708875 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -560,6 +560,39 @@
yield fp
+class _method(object):
+ """A decorator that supports both instance and classmethod invocations.
+
+ Using similar semantics to the @property builtin, this decorator can augment
+ an instance method to support conditional logic when invoked on a class
+ object. This breaks support for invoking an instance method via the class
+ (e.g. Cls.method(self, ...)) but is still situationally useful.
+ """
+
+ def __init__(self, finstancemethod):
+ # type: (Callable[..., Any]) -> None
+ self._finstancemethod = finstancemethod
+ self._fclassmethod = None
+
+ def classmethod(self, fclassmethod):
+ # type: (Callable[..., Any]) -> _method
+ self._fclassmethod = classmethod(fclassmethod)
+ return self
+
+ def __doc__(self):
+ # type: () -> str
+ if getattr(self._finstancemethod, '__doc__'):
+ return self._finstancemethod.__doc__
+ elif getattr(self._fclassmethod, '__doc__'):
+ return self._fclassmethod.__doc__
+ return ''
+
+ def __get__(self, obj, type_):
+ # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
+ func = self._fclassmethod if obj is None else self._finstancemethod
+ return func.__get__(obj, type_) # pytype: disable=attribute-error
+
+
class TestCase(unittest3_backport.TestCase):
"""Extension of unittest.TestCase providing more power."""
@@ -576,20 +609,31 @@
maxDiff = 80 * 20
longMessage = True
+ # Exit stacks for per-test and per-class scopes.
+ _exit_stack = None
+ _cls_exit_stack = None
+
def __init__(self, *args, **kwargs):
super(TestCase, self).__init__(*args, **kwargs)
# This is to work around missing type stubs in unittest.pyi
self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType]
- # This is re-initialized by setUp().
- self._exit_stack = None
def setUp(self):
super(TestCase, self).setUp()
- # NOTE: Only Py3 contextlib has ExitStack
+ # NOTE: Only Python 3 contextlib has ExitStack
if hasattr(contextlib, 'ExitStack'):
self._exit_stack = contextlib.ExitStack()
self.addCleanup(self._exit_stack.close)
+ @classmethod
+ def setUpClass(cls):
+ super(TestCase, cls).setUpClass()
+ # NOTE: Only Python 3 contextlib has ExitStack and only Python 3.8+ has
+ # addClassCleanup.
+ if hasattr(contextlib, 'ExitStack') and hasattr(cls, 'addClassCleanup'):
+ cls._cls_exit_stack = contextlib.ExitStack()
+ cls.addClassCleanup(cls._cls_exit_stack.close)
+
def create_tempdir(self, name=None, cleanup=None):
# type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
"""Create a temporary directory specific to the test.
@@ -700,14 +744,19 @@
self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
return tf
+ @_method
def enter_context(self, manager):
# type: (ContextManager[_T]) -> _T
"""Returns the CM's value after registering it with the exit stack.
- Entering a context pushes it onto a stack of contexts. The context is exited
- when the test completes. Contexts are are exited in the reverse order of
- entering. They will always be exited, regardless of test failure/success.
- The context stack is specific to the test being run.
+ Entering a context pushes it onto a stack of contexts. When `enter_context`
+ is called on the test instance (e.g. `self.enter_context`), the context is
+ exited after the test case's tearDown call. When called on the test class
+ (e.g. `TestCase.enter_context`), the context is exited after the test
+ class's tearDownClass call.
+
+ Contexts are are exited in the reverse order of entering. They will always
+ be exited, regardless of test failure/success.
This is useful to eliminate per-test boilerplate when context managers
are used. For example, instead of decorating every test with `@mock.patch`,
@@ -726,6 +775,15 @@
'sure that AbslTest.setUp() is called.')
return self._exit_stack.enter_context(manager)
+ @enter_context.classmethod
+ def enter_context(cls, manager): # pylint: disable=no-self-argument
+ # type: (ContextManager[_T]) -> _T
+ if not cls._cls_exit_stack:
+ raise AssertionError(
+ 'cls._cls_exit_stack is not set: cls.enter_context requires '
+ 'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
+ return cls._cls_exit_stack.enter_context(manager)
+
@classmethod
def _get_tempdir_path_cls(cls):
# type: () -> Text
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
index 0ac3009..a54ff55 100644
--- a/absl/testing/tests/absltest_test.py
+++ b/absl/testing/tests/absltest_test.py
@@ -1487,6 +1487,15 @@
self.assertRegex(stderr, 'No such file or directory')
+@contextlib.contextmanager
+def cm_for_test(obj):
+ try:
+ obj.cm_state = 'yielded'
+ yield 'value'
+ finally:
+ obj.cm_state = 'exited'
+
+
@absltest.skipIf(six.PY2, 'Python 2 does not have ExitStack')
class EnterContextTest(absltest.TestCase):
@@ -1503,15 +1512,33 @@
self.addCleanup(assert_cm_exited)
super(EnterContextTest, self).setUp()
- self.cm_value = self.enter_context(self.cm_for_test())
+ self.cm_value = self.enter_context(cm_for_test(self))
- @contextlib.contextmanager
- def cm_for_test(self):
- try:
- self.cm_state = 'yielded'
- yield 'value'
- finally:
- self.cm_state = 'exited'
+ def test_enter_context(self):
+ self.assertEqual(self.cm_value, 'value')
+ self.assertEqual(self.cm_state, 'yielded')
+
+
+@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'),
+ 'Python 3.8 required for class-level enter_context')
+class EnterContextClassmethodTest(absltest.TestCase):
+
+ cm_state = 'unset'
+ cm_value = 'unset'
+
+ @classmethod
+ def setUpClass(cls):
+
+ def assert_cm_exited():
+ assert cls.cm_state == 'exited'
+
+ # Because cleanup functions are run in reverse order, we have to add
+ # our assert-cleanup before the exit stack registers its own cleanup.
+ # This ensures we see state after the stack cleanup runs.
+ cls.addClassCleanup(assert_cm_exited)
+
+ super(EnterContextClassmethodTest, cls).setUpClass()
+ cls.cm_value = cls.enter_context(cm_for_test(cls))
def test_enter_context(self):
self.assertEqual(self.cm_value, 'value')