| """Tests for asyncio/threads.py""" |
| |
| import asyncio |
| import unittest |
| |
| from contextvars import ContextVar |
| from unittest import mock |
| from test.test_asyncio import utils as test_utils |
| |
| |
def tearDownModule():
    """Reset the process-wide asyncio event loop policy after this module.

    Passing ``None`` to ``set_event_loop_policy`` restores the default
    policy, so loop state installed by these tests does not leak into
    other test modules.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
| |
| |
class ToThreadTests(test_utils.TestCase):
    """Tests for ``asyncio.to_thread()``, each run on its own event loop."""

    def setUp(self):
        super().setUp()
        # Give every test a fresh loop and install it as the current loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        # to_thread() dispatches work to the loop's default executor; shut
        # that executor down (joining its worker threads) before closing
        # the loop so no threads outlive the test.
        self.loop.run_until_complete(
            self.loop.shutdown_default_executor())
        self.loop.close()
        asyncio.set_event_loop(None)
        self.loop = None
        super().tearDown()

    def test_to_thread(self):
        """The callable's return value is propagated back to the awaiter."""
        async def main():
            return await asyncio.to_thread(sum, [40, 2])

        result = self.loop.run_until_complete(main())
        self.assertEqual(result, 42)

    def test_to_thread_exception(self):
        """An exception raised in the worker thread re-raises at the await."""
        def raise_runtime():
            raise RuntimeError("test")

        async def main():
            await asyncio.to_thread(raise_runtime)

        with self.assertRaisesRegex(RuntimeError, "test"):
            self.loop.run_until_complete(main())

    def test_to_thread_once(self):
        """A single to_thread() call invokes the callable exactly once."""
        func = mock.Mock()

        async def main():
            await asyncio.to_thread(func)

        self.loop.run_until_complete(main())
        func.assert_called_once()

    def test_to_thread_concurrent(self):
        """Many concurrent to_thread() calls each invoke the callable."""
        # Record calls in a plain list rather than a Mock: the calls below
        # run simultaneously in several executor threads, and Mock's call
        # bookkeeping (call_count += 1, call_args_list, ...) is not
        # thread-safe, so updates can be lost and the count come up short
        # (observed as flakiness on free-threaded builds). list.append()
        # is atomic in CPython.
        calls = []

        def func():
            calls.append(1)

        async def main():
            coros = [asyncio.to_thread(func) for _ in range(10)]
            await asyncio.gather(*coros)

        self.loop.run_until_complete(main())
        self.assertEqual(len(calls), 10)

    def test_to_thread_args_kwargs(self):
        """Positional and keyword arguments are forwarded to the callable."""
        # Unlike run_in_executor(), to_thread() should directly accept kwargs.
        func = mock.Mock()

        async def main():
            await asyncio.to_thread(func, 'test', something=True)

        self.loop.run_until_complete(main())
        func.assert_called_once_with('test', something=True)

    def test_to_thread_contextvars(self):
        """Context variables set by the caller are visible in the thread."""
        test_ctx = ContextVar('test_ctx')

        def get_ctx():
            return test_ctx.get()

        async def main():
            test_ctx.set('parrot')
            return await asyncio.to_thread(get_ctx)

        result = self.loop.run_until_complete(main())
        self.assertEqual(result, 'parrot')
| |
| |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()