|  | """Utilities shared by tests.""" | 
|  |  | 
|  | import asyncio | 
|  | import collections | 
|  | import contextlib | 
|  | import io | 
|  | import logging | 
|  | import os | 
|  | import re | 
|  | import selectors | 
|  | import socket | 
|  | import socketserver | 
|  | import sys | 
|  | import tempfile | 
|  | import threading | 
|  | import time | 
|  | import unittest | 
|  | import weakref | 
|  |  | 
|  | from unittest import mock | 
|  |  | 
|  | from http.server import HTTPServer | 
|  | from wsgiref.simple_server import WSGIRequestHandler, WSGIServer | 
|  |  | 
|  | try: | 
|  | import ssl | 
|  | except ImportError:  # pragma: no cover | 
|  | ssl = None | 
|  |  | 
|  | from asyncio import base_events | 
|  | from asyncio import events | 
|  | from asyncio import format_helpers | 
|  | from asyncio import futures | 
|  | from asyncio import tasks | 
|  | from asyncio.log import logger | 
|  | from test import support | 
|  | from test.support import threading_helper | 
|  |  | 
|  |  | 
def data_file(filename):
    """Return the path of test data file *filename*.

    ``support.TEST_HOME_DIR`` is tried first (when the support module
    defines it), then the directory one level above this file.  The first
    candidate that is an existing regular file wins.

    Raises:
        FileNotFoundError: no candidate location holds the file.
    """
    search_dirs = []
    if hasattr(support, 'TEST_HOME_DIR'):
        search_dirs.append(support.TEST_HOME_DIR)
    search_dirs.append(os.path.join(os.path.dirname(__file__), '..'))

    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(filename)
|  |  | 
|  |  | 
# Certificate material for the TLS tests.  data_file() resolves each name
# to an existing path (or raises FileNotFoundError) at import time.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')

# Expected peer certificate in the dict layout used by
# ssl.SSLSocket.getpeercert() -- presumably the certificate stored in
# SIGNED_CERTFILE (issued by SIGNING_CA); confirm before relying on it.
PEERCERT = dict(
    OCSP=('http://testca.pythontest.net/testca/ocsp/',),
    caIssuers=('http://testca.pythontest.net/testca/pycacert.cer',),
    crlDistributionPoints=(
        'http://testca.pythontest.net/testca/revocation.crl',),
    issuer=(
        (('countryName', 'XY'),),
        (('organizationName', 'Python Software Foundation CA'),),
        (('commonName', 'our-ca-server'),),
    ),
    notAfter='Oct 28 14:23:16 2037 GMT',
    notBefore='Aug 29 14:23:16 2018 GMT',
    serialNumber='CB2D80995A69525C',
    subject=(
        (('countryName', 'XY'),),
        (('localityName', 'Castle Anthrax'),),
        (('organizationName', 'Python Software Foundation'),),
        (('commonName', 'localhost'),),
    ),
    subjectAltName=(('DNS', 'localhost'),),
    version=3,
)
|  |  | 
|  |  | 
def simple_server_sslcontext():
    """Build a TLS server context for the bundled test certificate.

    The context loads the ONLYCERT/ONLYKEY pair, disables hostname
    checking and does not request or verify client certificates.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
    # Relax verification: tests connect with self-signed material.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
|  |  | 
|  |  | 
def simple_client_sslcontext(*, disable_verify=True):
    """Build a TLS client context with hostname checking turned off.

    Args:
        disable_verify: when true (the default), also skip verification of
            the server certificate (``verify_mode = CERT_NONE``); when
            false, the protocol's default verify mode is kept.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be cleared before verify_mode may be CERT_NONE.
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
|  |  | 
|  |  | 
def dummy_ssl_context():
    """Return a non-verifying client SSL context, or None without ssl support."""
    return None if ssl is None else simple_client_sslcontext(disable_verify=True)
|  |  | 
|  |  | 
def run_briefly(loop):
    """Run *loop* just long enough to complete one do-nothing task.

    This gives callbacks that are already scheduled a chance to run.
    """
    async def once():
        return None

    coro = once()
    task = loop.create_task(coro)
    # The task may still be pending after run_until_complete() returns (the
    # loop was stopped, or a BaseException escaped); don't let it log a
    # "Task was destroyed but it is pending" warning in that case.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
|  |  | 
|  |  | 
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    Between checks of *pred* the loop runs a ``sleep(0.001)`` task, which
    lets scheduled callbacks and I/O make progress.

    Args:
        loop: the event loop to drive; it must not be running already.
        pred: zero-argument callable polled before every slice.
        timeout: seconds to wait before giving up, or None to wait
            indefinitely.

    Raises:
        asyncio.TimeoutError: *pred* still returned false after *timeout*
            seconds.
    """
    # Only compute a deadline when there is a timeout; the previous code
    # evaluated ``time.monotonic() + timeout`` unconditionally and crashed
    # with TypeError for timeout=None.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            # Use the public asyncio name: asyncio.futures does not
            # re-export TimeoutError on current Python versions.
            raise asyncio.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))
|  |  | 
|  |  | 
def run_once(loop):
    """Run a single iteration of *loop*.

    ``loop.stop()`` is scheduled first and then ``run_forever()`` is
    called.  A stop requested while ``run_forever()`` is running makes the
    loop finish its current batch of callbacks and return.  The loop
    therefore polls its selector once and runs the ready callbacks,
    including those triggered by I/O events from that poll.  This is the
    pattern test code should use to step the loop.
    """
    loop.call_soon(loop.stop)
    return loop.run_forever()
|  |  | 
|  |  | 
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards error output and access logs."""

    def get_stderr(self):
        # A fresh in-memory buffer per call, so nothing reaches sys.stderr.
        return io.StringIO()

    def log_message(self, format, *args):
        # Access-log lines are dropped entirely.
        return None
|  |  | 
|  |  | 
class SilentWSGIServer(WSGIServer):
    """WSGIServer that applies a socket timeout to every accepted
    connection and suppresses error reporting."""

    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        conn, peer = super().get_request()
        # Stop a stuck client from blocking the serving thread forever.
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Deliberately silent: errors are not reported.
        return None
|  |  | 
|  |  | 
class SSLWSGIServerMixin:
    """Server mixin that wraps every accepted connection in server-side TLS.

    List it before a socketserver-based WSGI server class in the bases.
    Its ``finish_request`` override runs the request handler on the TLS
    socket instead of the raw connection.
    """

    def finish_request(self, request, client_address):
        """Handle one connection over TLS using the ONLYCERT/ONLYKEY pair.

        The TLS socket is closed on every path, including when the handler
        raises.  OSError from the handler or from closing the socket is
        ignored, because the peer may already have closed the connection.
        """
        # The certificate/key paths were resolved by data_file(), which
        # looks in support.TEST_HOME_DIR first and then one directory
        # above this file.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # ``with`` closes ssock even if the handler raises.  Previously
            # close() was skipped on that path and the TLS socket was left
            # for the garbage collector.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
|  |  | 
|  |  | 
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """TCP WSGI test server that serves requests over TLS."""
|  |  | 
|  |  | 
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a small WSGI app on *address* from a background thread.

    This is a generator meant to be delegated to with ``yield from``
    inside a ``contextlib.contextmanager`` function.  It yields the
    running server: an instance of *server_ssl_cls* if *use_ssl* else
    *server_cls*, with an extra ``address`` attribute equal to its
    ``server_address``.  When the generator is resumed, closed or has an
    exception thrown into it, it shuts the server down, closes it and
    joins the serving thread.

    The app echoes the request body back for the ``/loop`` path and
    answers ``b'Test message'`` for every other path.
    """

    def loop(environ):
        # Echo the request body back in chunks of at most 64 KiB.  A
        # missing or empty CONTENT_LENGTH means "no body" instead of
        # crashing in int('').
        size = int(environ.get('CONTENT_LENGTH') or 0)
        while size > 0:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # Input ended before CONTENT_LENGTH bytes arrived (e.g. the
                # client went away).  Without this check ``size`` never
                # shrinks and the generator spins forever yielding b''.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
|  |  | 
|  |  | 
if hasattr(socket, 'AF_UNIX'):
    # UNIX-domain variants of the WSGI test servers; only defined where
    # the socket module provides AF_UNIX.

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX-domain socket path."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # A socket path has no host or port; store placeholders for
            # code that reads these attributes.
            self.server_name, self.server_port = '127.0.0.1', 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX-domain socket with per-connection timeouts."""

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Stdlib code expects get_request() to return a socket plus a
            # (host, port) tuple, but a UNIX socket's peer address is a
            # path, so report fake data that is just enough for the tests.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that silently drops handler errors."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer variant that serves requests over TLS."""

    def gen_unix_socket_path():
        """Return a temporary-file name to use as a socket path.

        The temporary file is created and immediately deleted again, so
        the name does not refer to an existing file when returned (another
        process could still claim it in the meantime).
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path; best-effort unlink it on exit."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            with contextlib.suppress(OSError):
                os.unlink(path)

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running the WSGI test server on a UNIX socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(
                address=path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
|  |  | 
|  |  | 
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the WSGI test server on TCP (*host*, *port*).

    With the default port 0 the OS picks a free port; the bound address
    is available as the yielded server's ``address`` attribute.
    """
    address = (host, port)
    yield from _run_test_server(
        address=address,
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
|  |  | 
|  |  | 
def make_test_protocol(base):
    """Instantiate a throwaway subclass of *base* whose attributes are mocks.

    Every name in ``dir(base)`` except ``__dunder__`` names is overridden
    with a MockCallback that returns None.  The new class lists *base*
    followed by *base*'s own bases.
    """
    overrides = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not (name.startswith('__') and name.endswith('__'))
    }
    bases = (base, *base.__bases__)
    return type('TestProtocol', bases, overrides)()
|  |  | 
|  |  | 
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub.

    It only keeps a fileobj -> SelectorKey mapping; select() never reports
    any ready file objects.
    """

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # fd is always 0: this selector never touches a real descriptor.
        new_key = selectors.SelectorKey(fileobj=fileobj, fd=0,
                                        events=events, data=data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        removed = self.keys[fileobj]
        del self.keys[fileobj]
        return removed

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys
|  |  | 
|  |  | 
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a virtual, generator-driven clock.

    The loop never reads a real clock.  ``time()`` returns an internal
    counter that starts at 0 and moves only through ``advance_time()``.

    Every ``call_at()`` records the requested deadline.  After each loop
    iteration (``_run_once()``), each recorded deadline is sent into the
    time generator created from the *gen* factory passed to ``__init__``.
    The value the generator yields back is added to the clock.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a handle
    scheduled with ``call_at()``.  The value yielded is how far to move
    the loop's time forward.  The generator is advanced once with
    ``next()`` when the loop is created, so its first yielded value is
    discarded.

    When *gen* is given, ``close()`` asserts that the generator is
    exhausted.  When it is omitted, a generator that yields only once is
    used and is not checked on close.  That default generator has nothing
    left to answer a ``send()``, so a deadline recorded by ``call_at()``
    would make ``_run_once()`` raise StopIteration.
    NOTE(review): confirm that timer-using tests are expected to pass *gen*.

    Reader/writer callbacks are stored in ``readers`` / ``writers``
    instead of being registered with a real selector, so tests can
    inspect them with the ``assert_*`` helpers.  ``remove_reader_count``
    and ``remove_writer_count`` count removal attempts per fd.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # Default time generator: yields once (consumed by the priming
            # next() below) and is not checked when the loop is closed.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        # *gen* is a factory: create the generator, then prime it so it is
        # ready to receive deadlines through send().
        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Deadlines recorded by call_at() since the last _run_once().
        self._timers = []
        # Stub selector: register/unregister bookkeeping only, and its
        # select() always returns an empty list.
        self._selector = TestSelector()

        # fd -> events.Handle for callbacks added via add_reader/add_writer.
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        # fd -> transport; add/remove_reader/writer refuse fds listed here.
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the virtual clock value (no real clock is consulted)."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        # Falsy advances (0, None) leave the clock untouched.
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a time generator was supplied, check
        that it is finished.

        Raises:
            AssertionError: the generator accepted one more ``send(0)``
                instead of raising StopIteration.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Store the callback as a Handle instead of registering a selector.
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        """Drop the reader for *fd*; return True if one was registered.

        Every call is counted in ``remove_reader_count``, even when no
        reader was registered.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a reader with exactly this
        callback and these args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if *fd* has a registered reader."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        # Store the callback as a Handle instead of registering a selector.
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        """Drop the writer for *fd*; return True if one was registered.

        Every call is counted in ``remove_writer_count``, even when no
        writer was registered.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a writer with exactly this
        callback and these args."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Refuse *fd* if it is currently owned by a transport.

        Non-int objects are converted with ``int(fd.fileno())`` for this
        check only; callers keep using the original object as the key.

        Raises:
            ValueError: *fd* is not an int and has no usable fileno().
            RuntimeError: *fd* is listed in ``self._transports``.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd removal counters to empty."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        # After the base iteration, feed each deadline recorded by call_at()
        # to the time generator and advance the clock by its reply; then
        # forget the recorded deadlines.
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Record the deadline for _run_once(), then schedule as usual.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # Nothing to dispatch: TestSelector.select() always returns [].
        return

    def _write_to_self(self):
        # No self-pipe in this loop.  NOTE(review): presumably the base
        # loop calls this to wake itself up; a no-op is enough here.
        pass
|  |  | 
|  |  | 
def MockCallback(**kwargs):
    """Return a callable Mock whose spec allows only ``__call__``.

    Accessing any other non-Mock attribute raises AttributeError.  Extra
    keyword arguments (e.g. ``return_value``) are passed to ``mock.Mock``.
    """
    return mock.Mock(**kwargs, spec=['__call__'])
|  |  | 
|  |  | 
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern matches anywhere in the other string (``re.search``) and
    ``.`` also matches newlines (``re.S``).  Comparing against a non-str
    returns NotImplemented, so Python falls back to its default (unequal)
    result instead of raising TypeError.  Instances are unhashable.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this override, str.__ne__ would compare the raw pattern
        # text literally and disagree with the fuzzy __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Defining __eq__ already disables hashing; keep that explicit, since a
    # fuzzy equality cannot be consistent with str's hash.
    __hash__ = None
|  |  | 
|  |  | 
class MockInstanceOf:
    """Compare equal to any instance of a given type.

    Handy as a wildcard in mock assertions, e.g.
    ``m.assert_called_with(MockInstanceOf(bytes))``.
    """

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Shown in mock assertion failure messages; the default repr would
        # only give an object address and hide the expected type.
        return f'{type(self).__name__}({self._type!r})'
|  |  | 
|  |  | 
def get_function_source(func):
    """Return the source location asyncio's format_helpers reports for *func*.

    Raises:
        ValueError: asyncio cannot determine a source location for *func*.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError(f"unable to get the source of {func!r}")
|  |  | 
|  |  | 
|  | class TestCase(unittest.TestCase): | 
|  | @staticmethod | 
|  | def close_loop(loop): | 
|  | if loop._default_executor is not None: | 
|  | if not loop.is_closed(): | 
|  | loop.run_until_complete(loop.shutdown_default_executor()) | 
|  | else: | 
|  | loop._default_executor.shutdown(wait=True) | 
|  | loop.close() | 
|  | policy = support.maybe_get_event_loop_policy() | 
|  | if policy is not None: | 
|  | try: | 
|  | watcher = policy.get_child_watcher() | 
|  | except NotImplementedError: | 
|  | # watcher is not implemented by EventLoopPolicy, e.g. Windows | 
|  | pass | 
|  | else: | 
|  | if isinstance(watcher, asyncio.ThreadedChildWatcher): | 
|  | threads = list(watcher._threads.values()) | 
|  | for thread in threads: | 
|  | thread.join() | 
|  |  | 
|  | def set_event_loop(self, loop, *, cleanup=True): | 
|  | if loop is None: | 
|  | raise AssertionError('loop is None') | 
|  | # ensure that the event loop is passed explicitly in asyncio | 
|  | events.set_event_loop(None) | 
|  | if cleanup: | 
|  | self.addCleanup(self.close_loop, loop) | 
|  |  | 
|  | def new_test_loop(self, gen=None): | 
|  | loop = TestLoop(gen) | 
|  | self.set_event_loop(loop) | 
|  | return loop | 
|  |  | 
|  | def setUp(self): | 
|  | self._thread_cleanup = threading_helper.threading_setup() | 
|  |  | 
|  | def tearDown(self): | 
|  | events.set_event_loop(None) | 
|  |  | 
|  | # Detect CPython bug #23353: ensure that yield/yield-from is not used | 
|  | # in an except block of a generator | 
|  | self.assertEqual(sys.exc_info(), (None, None, None)) | 
|  |  | 
|  | self.doCleanups() | 
|  | threading_helper.threading_cleanup(*self._thread_cleanup) | 
|  | support.reap_children() | 
|  |  | 
|  |  | 
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a ``with`` block.

    The logger's level is raised above CRITICAL and the previous level is
    restored on exit.  For example, this can hide warnings emitted in
    debug mode.
    """
    saved_level = logger.level
    try:
        # One above CRITICAL, so not even critical records get through.
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(saved_level)
|  |  | 
|  |  | 
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The result is a MagicMock specced on ``socket.socket``.  Its ``proto``,
    ``type`` and ``family`` attributes come from the arguments, and
    ``gettimeout()`` returns 0.0.
    """
    fake = mock.MagicMock(socket.socket)
    fake.configure_mock(proto=proto, type=type, family=family)
    # A timeout of 0.0 is what a non-blocking socket reports.
    fake.gettimeout.return_value = 0.0
    return fake