| import asyncio |
| from contextlib import ( |
| asynccontextmanager, AbstractAsyncContextManager, |
| AsyncExitStack, nullcontext, aclosing, contextmanager) |
| import functools |
| from test import support |
| import unittest |
| |
| from test.test_contextlib import TestBaseExitStack |
| |
| |
| def _async_test(func): |
| """Decorator to turn an async function into a test case.""" |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| coro = func(*args, **kwargs) |
| loop = asyncio.new_event_loop() |
| asyncio.set_event_loop(loop) |
| try: |
| return loop.run_until_complete(coro) |
| finally: |
| loop.close() |
| asyncio.set_event_loop_policy(None) |
| return wrapper |
| |
| |
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the AbstractAsyncContextManager abstract base class."""

    @_async_test
    async def test_enter(self):
        # A subclass that only supplies __aexit__ relies on the inherited
        # __aenter__, which is expected to hand back the manager itself.
        class OnlyAexit(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        cm = OnlyAexit()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def noop_cm():
            yield

        async def one_item():
            async with noop_cm():
                yield 11

        collected = []
        error = ValueError(22)
        with self.assertRaises(ValueError):
            async with noop_cm():
                async for item in one_item():
                    collected.append(item)
                    raise error

        self.assertEqual(collected, [11])

    def test_exit_is_abstract(self):
        # Instantiating a subclass that does not override __aexit__ must fail.
        class NoAexit(AbstractAsyncContextManager):
            pass

        with self.assertRaises(TypeError):
            NoAexit()

    def test_structural_subclassing(self):
        # A class defining both protocol methods counts as a subclass
        # without inheriting from the ABC.
        class FromScratch:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(
            issubclass(FromScratch, AbstractAsyncContextManager))

        class InheritsEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        self.assertTrue(
            issubclass(InheritsEnter, AbstractAsyncContextManager))

        # Setting either protocol method to None opts the class out.
        class AenterDisabled(FromScratch):
            __aenter__ = None

        self.assertFalse(
            issubclass(AenterDisabled, AbstractAsyncContextManager))

        class AexitDisabled(FromScratch):
            __aexit__ = None

        self.assertFalse(
            issubclass(AexitDisabled, AbstractAsyncContextManager))
| |
| |
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based managers built with @asynccontextmanager."""

    @_async_test
    async def test_contextmanager_plain(self):
        events = []

        @asynccontextmanager
        async def tracked():
            events.append(1)
            yield 42
            events.append(999)

        async with tracked() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)
        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        events = []

        @asynccontextmanager
        async def tracked():
            events.append(1)
            try:
                yield 42
            finally:
                events.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with tracked() as value:
                self.assertEqual(events, [1])
                self.assertEqual(value, 42)
                events.append(value)
                raise ZeroDivisionError()
        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def single_yield():
            yield

        manager = single_yield()
        await manager.__aenter__()
        # Calling __aexit__ should not result in an exception
        suppressed = await manager.__aexit__(TypeError, TypeError("foo"), None)
        self.assertFalse(suppressed)

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def yields_in_handler():
            try:
                yield
            except BaseException:
                # Yielding again after an exception was thrown in is invalid.
                yield

        manager = yields_in_handler()
        await manager.__aenter__()
        with self.assertRaises(RuntimeError):
            await manager.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def never_yields():
            if False:
                yield

        manager = never_yields()
        with self.assertRaises(RuntimeError):
            await manager.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def yields_twice():
            yield
            yield

        manager = yields_twice()
        await manager.__aenter__()
        with self.assertRaises(RuntimeError):
            await manager.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def translates_error():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        manager = translates_error()
        await manager.__aenter__()
        with self.assertRaises(SyntaxError):
            # Only the exception type is passed; the value is left as None.
            await manager.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        events = []

        @asynccontextmanager
        async def tracked():
            events.append(1)
            try:
                yield 42
            except ZeroDivisionError as err:
                events.append(err.args[0])
                self.assertEqual(events, [1, 42, 999])

        async with tracked() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        @asynccontextmanager
        async def passthrough():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        candidates = (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam'),
        )
        for stop_exc in candidates:
            with self.subTest(type=type(stop_exc)):
                try:
                    async with passthrough():
                        raise stop_exc
                except Exception as caught:
                    # The very same exception object must come back out.
                    self.assertIs(caught, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def wraps_errors():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with wraps_errors():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with wraps_errors():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an async context manager factory named ``baz`` whose wrapped
        # function carries an extra attribute and a docstring.
        def attribs(**extra):
            def apply(target):
                for name, value in extra.items():
                    setattr(target, name, value)
                return target
            return apply

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield

        return baz

    def test_contextmanager_attribs(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__name__, 'baz')
        self.assertEqual(factory.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        factory = self._create_contextmanager_attribs()
        instance = factory(None)
        self.assertEqual(instance.__doc__, "Whee!")
        async with instance:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def echo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with echo(self=11, func=22, args=33, kwds=44) as received:
            self.assertEqual(received, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        level = 0
        entries = 0

        @asynccontextmanager
        async def nesting():
            nonlocal entries, level
            entries += 1

            outer_level = level
            level += 1
            yield
            level -= 1
            self.assertEqual(level, outer_level)

        # The manager instance is used as a decorator on a coroutine that
        # calls itself.
        @nesting()
        async def descend():
            if level < 10:
                await descend()

        await descend()

        self.assertEqual(entries, 10)
        self.assertEqual(level, 0)
| |
| |
class AclosingTestCase(unittest.TestCase):
    """Tests for contextlib.aclosing."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # An instance is expected to expose the class docstring.
        expected_doc = aclosing.__doc__
        instance = aclosing(None)
        self.assertEqual(instance.__doc__, expected_doc)

    @_async_test
    async def test_aclosing(self):
        calls = []

        class Closable:
            async def aclose(self):
                calls.append(1)

        target = Closable()
        self.assertEqual(calls, [])
        async with aclosing(target) as bound:
            self.assertEqual(target, bound)
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_error(self):
        calls = []

        class Closable:
            async def aclose(self):
                calls.append(1)

        target = Closable()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(target) as bound:
                self.assertEqual(target, bound)
                1 / 0
        # aclose() must have run even though the body raised.
        self.assertEqual(calls, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        calls = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                calls.append(1)

        async def produce():
            with sync_resource():
                yield -1
                yield -2

        agen = produce()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(agen) as bound:
                self.assertEqual(agen, bound)
                first = await agen.__anext__()
                self.assertEqual(-1, first)
                1 / 0
        # Closing the generator must have run sync_resource's finally clause.
        self.assertEqual(calls, [1])
| |
| |
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Tests for AsyncExitStack.

    Inherits from TestBaseExitStack (imported from test.test_contextlib) and
    sets ``exit_stack`` to a synchronous adapter around AsyncExitStack.
    NOTE(review): presumably the inherited tests instantiate
    ``self.exit_stack`` -- confirm against TestBaseExitStack.
    """

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack exposing a blocking close/__enter__/__exit__ API."""

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop and return its result.

            If the coroutine raised, the exception is re-raised with the
            ``__context__`` it had when the future completed.
            """
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            # Compare against None explicitly: an exception type may define
            # __bool__/__len__ and be falsy.
            if exc is None:
                return f.result()
            else:
                context = exc.__context__

                try:
                    raise exc
                except BaseException:
                    # Put back the context captured above before re-raising.
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # run_coroutine() uses asyncio.get_event_loop(), so install a loop
        # for each test and undo that when the test finishes.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Callbacks run in reverse registration order, so register them
            # reversed to get ``result`` in the order of ``expected``.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # The second positional argument of assertIsNone is the
                # failure message, so the docstring is not passed here.
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            # Each call omits the positional callback and must be rejected.
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        # ``self`` is shadowed inside ExitCM's methods; keep a reference to
        # the TestCase so that fail() is reachable from there.
        testcase = self

        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                testcase.fail("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            # Exit callbacks run last-in, first-out: those pushed after
            # _suppress_exc see the ZeroDivisionError, those pushed before
            # it see no exception.
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    # Explicitly clear the context while exc propagates.
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            # Label the subtest so a failure names the manager involved.
            with self.subTest(cm=cm.__name__):
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")
| |
| |
class TestAsyncNullcontext(unittest.TestCase):
    """Tests for nullcontext used in an ``async with`` statement."""

    @_async_test
    async def test_async_nullcontext(self):
        # The object given to nullcontext is what ``as`` binds.
        class Marker:
            pass

        marker = Marker()
        async with nullcontext(marker) as bound:
            self.assertIs(bound, marker)
| |
| |
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()