| import asyncio |
| import asyncio.sslproto |
| import contextlib |
| import gc |
| import logging |
| import select |
| import socket |
| import tempfile |
| import threading |
| import time |
| import weakref |
| import unittest |
| |
| try: |
| import ssl |
| except ImportError: |
| ssl = None |
| |
| from test import support |
| from test.test_asyncio import utils as test_utils |
| |
| |
def tearDownModule():
    """Reset the global asyncio event loop policy after this module's tests."""
    # Passing None asks asyncio to fall back to its default policy.
    restored_policy = None
    asyncio.set_event_loop_policy(restored_policy)
| |
| |
class MyBaseProto(asyncio.Protocol):
    """Protocol that records its lifecycle as a small state machine.

    ``state`` moves INITIAL -> CONNECTED -> (EOF ->) CLOSED, and each
    callback asserts that it is invoked from a legal state.  ``nbytes``
    accumulates the length of every chunk passed to ``data_received``.
    When a loop is given, ``connected`` and ``done`` are futures resolved
    by ``connection_made`` and ``connection_lost``; otherwise both remain
    the class-level ``None``.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = asyncio.Future(loop=loop)
        self.done = asyncio.Future(loop=loop)

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        waiter = self.connected
        if waiter:
            waiter.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        legal_states = ('CONNECTED', 'EOF')
        assert self.state in legal_states, self.state
        self.state = 'CLOSED'
        waiter = self.done
        if waiter:
            waiter.set_result(None)
| |
| |
| @unittest.skipIf(ssl is None, 'No ssl module') |
| class TestSSL(test_utils.TestCase): |
| |
| PAYLOAD_SIZE = 1024 * 100 |
| TIMEOUT = support.LONG_TIMEOUT |
| |
def setUp(self):
    """Create a fresh event loop for the test and register its cleanup."""
    super().setUp()
    loop = asyncio.new_event_loop()
    self.loop = loop
    self.set_event_loop(loop)
    # The loop is closed through the cleanup stack, not directly in tearDown.
    self.addCleanup(loop.close)
| |
def tearDown(self):
    """Spin the loop once if still open, then run cleanups and collect garbage."""
    loop = self.loop
    # just in case if we have transport close callbacks
    if not loop.is_closed():
        test_utils.run_briefly(loop)

    self.doCleanups()
    support.gc_collect()
    super().tearDown()
| |
def tcp_server(self, server_prog, *,
               family=socket.AF_INET,
               addr=None,
               timeout=support.SHORT_TIMEOUT,
               backlog=1,
               max_clients=10):
    """Create a bound, listening socket and wrap it in a TestThreadedServer.

    server_prog, timeout and max_clients are forwarded unchanged to
    TestThreadedServer together with this test case and the socket.

    When addr is None an address is chosen from family: for AF_UNIX the
    name of a just-deleted temporary file, otherwise ('127.0.0.1', 0).

    Raises RuntimeError if timeout is None or not positive.  Any error
    raised while configuring, binding or listening is re-raised after the
    socket has been closed.
    """
    # Validate the timeout before allocating the socket, so that a bad
    # argument cannot leak an open descriptor.
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')

    if addr is None:
        if family == getattr(socket, "AF_UNIX", None):
            # Only the name is wanted: the file is removed on leaving the
            # ``with`` block, which frees the path for bind().
            with tempfile.NamedTemporaryFile() as tmp:
                addr = tmp.name
        else:
            addr = ('127.0.0.1', 0)

    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
        sock.bind(addr)
        sock.listen(backlog)
    except BaseException:
        # Cleanup-and-reraise: never leave the socket open on failure.
        sock.close()
        raise

    return TestThreadedServer(
        self, sock, server_prog, timeout, max_clients)
| |
def tcp_client(self, client_prog,
               family=socket.AF_INET,
               timeout=support.SHORT_TIMEOUT):
    """Create an unconnected stream socket wrapped in a TestThreadedClient.

    client_prog and timeout are forwarded unchanged to TestThreadedClient
    together with this test case and the socket.

    Raises RuntimeError if timeout is None or not positive.
    """
    # Validate the timeout before allocating the socket, so that a bad
    # argument cannot leak an open descriptor.
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')

    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
    except BaseException:
        # Cleanup-and-reraise: never leave the socket open on failure.
        sock.close()
        raise

    return TestThreadedClient(
        self, sock, client_prog, timeout)
| |
def unix_server(self, *args, **kwargs):
    """Same as tcp_server(), with the socket family fixed to AF_UNIX."""
    unix_family = socket.AF_UNIX
    return self.tcp_server(*args, family=unix_family, **kwargs)
| |
def unix_client(self, *args, **kwargs):
    """Same as tcp_client(), with the socket family fixed to AF_UNIX."""
    unix_family = socket.AF_UNIX
    return self.tcp_client(*args, family=unix_family, **kwargs)
| |
| def _create_server_ssl_context(self, certfile, keyfile=None): |
| sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
| sslcontext.options |= ssl.OP_NO_SSLv2 |
| sslcontext.load_cert_chain(certfile, keyfile) |
| return sslcontext |
| |
| def _create_client_ssl_context(self, *, disable_verify=True): |
| sslcontext = ssl.create_default_context() |
| sslcontext.check_hostname = False |
| if disable_verify: |
| sslcontext.verify_mode = ssl.CERT_NONE |
| return sslcontext |
| |
| @contextlib.contextmanager |
| def _silence_eof_received_warning(self): |
| # TODO This warning has to be fixed in asyncio. |
| logger = logging.getLogger('asyncio') |
| filter = logging.Filter('has no effect when using ssl') |
| logger.addFilter(filter) |
| try: |
| yield |
| finally: |
| logger.removeFilter(filter) |
| |
| def _abort_socket_test(self, ex): |
| try: |
| self.loop.stop() |
| finally: |
| self.fail(ex) |
| |
| def new_loop(self): |
| return asyncio.new_event_loop() |
| |
| def new_policy(self): |
| return asyncio.DefaultEventLoopPolicy() |
| |
| async def wait_closed(self, obj): |
| if not isinstance(obj, asyncio.StreamWriter): |
| return |
| try: |
| await obj.wait_closed() |
| except (BrokenPipeError, ConnectionError): |
| pass |
| |
def test_create_server_ssl_1(self):
    """asyncio SSL server exchanging data with many threaded TLS clients.

    An asyncio.start_server() SSL server is started on 127.0.0.1 and
    TOTAL_CNT blocking clients, each on its own thread, run the exchange
    A_DATA -> b'OK' -> B_DATA -> b'SPAM'.  The test passes when every
    server-side handler ran to completion.
    """
    CNT = 0           # number of clients that were successful
    TOTAL_CNT = 25    # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * 1024
    B_DATA = b'B' * 1024 * 1024

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        # Server side of the exchange; CNT is bumped only after the full
        # conversation, so it counts fully served clients.
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        # Mixed buffer types together spell b'SPAM'.
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    async def test_client(addr):
        # Completed from the client thread via call_soon_threadsafe().
        fut = asyncio.Future()

        def prog(sock):
            try:
                sock.starttls(client_sslctx)
                sock.connect(addr)
                sock.send(A_DATA)

                data = sock.recv_all(2)
                self.assertEqual(data, b'OK')

                sock.send(B_DATA)
                data = sock.recv_all(4)
                self.assertEqual(data, b'SPAM')

                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

        srv = await asyncio.start_server(
            handle_client,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx,
            **extras)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    try:
        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        self.assertEqual(CNT, TOTAL_CNT)
    finally:
        # Stop the client threads even when the run or the assertion
        # above fails, so a failing test does not leave them behind.
        for client in clients:
            client.stop()
| |
def test_create_connection_ssl_1(self):
    """asyncio SSL clients talking to a threaded blocking TLS server.

    Runs the exchange A_DATA -> b'OK' -> B_DATA -> b'SPAM' with TOTAL_CNT
    concurrent clients, twice: once connecting by address and once over a
    socket that was connected beforehand.
    """
    self.loop.set_exception_handler(None)

    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * 1024
    B_DATA = b'B' * 1024 * 1024

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        # Blocking server side of the exchange, run by the threaded server.
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        data = sock.recv_all(len(B_DATA))
        self.assertEqual(data, B_DATA)
        sock.send(b'SPAM')

        sock.close()

    async def client(addr):
        # Variant 1: open_connection() resolves and connects by address.
        extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            **extras)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(B_DATA)
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    async def client_sock(addr):
        # Variant 2: hand open_connection() an already connected socket.
        sock = socket.socket()
        sock.connect(addr)
        reader, writer = await asyncio.open_connection(
            sock=sock,
            ssl=client_sslctx,
            server_hostname='')

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(B_DATA)
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)
        sock.close()

    def run(coro):
        # Reset the counter so each variant is checked independently.
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            # trampoline
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(_gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)

    with self._silence_eof_received_warning():
        run(client_sock)
| |
def test_create_connection_ssl_slow_handshake(self):
    """open_connection() must abort when the TLS handshake exceeds its timeout.

    The server thread never starts TLS; it only reads from the raw socket.
    The client uses ssl_handshake_timeout=1.0 and is expected to fail with
    ConnectionAbortedError whose message matches
    'SSL handshake.*is taking longer'.
    """
    client_sslctx = self._create_client_ssl_context()

    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    def server(sock):
        # No starttls() here, so the client's handshake gets no TLS reply.
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=1.0)
        # Reached only if the handshake unexpectedly succeeds.
        writer.close()
        await self.wait_closed(writer)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaisesRegex(
                ConnectionAbortedError,
                r'SSL handshake.*is taking longer'):

            self.loop.run_until_complete(client(srv.addr))
| |
def test_create_connection_ssl_failed_certificate(self):
    """A verifying client must raise SSLCertVerificationError.

    The client context is created with disable_verify=False, so
    certificate verification stays enabled; connecting to the test server
    is expected to raise ssl.SSLCertVerificationError.
    """
    # silence error logger
    self.loop.set_exception_handler(lambda *args: None)

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context(disable_verify=False)

    def server(sock):
        try:
            sock.starttls(
                sslctx,
                server_side=True)
            # NOTE(review): connect() is called without an address; this
            # line presumably is never reached because starttls() fails
            # when the client rejects the certificate -- confirm.
            sock.connect()
        except (ssl.SSLError, OSError):
            pass
        finally:
            sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            ssl_handshake_timeout=support.SHORT_TIMEOUT)
        # Reached only if the handshake unexpectedly succeeds.
        writer.close()
        await self.wait_closed(writer)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ssl.SSLCertVerificationError):
            self.loop.run_until_complete(client(srv.addr))
| |
def test_ssl_handshake_timeout(self):
    """Cancelling a stalled handshake must abort the connection silently.

    The handshake timeout is 10 seconds but the surrounding wait_for()
    gives up after 0.5 seconds, so the connection attempt is cancelled
    first.  The server thread must observe ConnectionAbortedError, and the
    loop exception handler must not have been called.
    """
    # bpo-29970: Check that a connection is aborted if handshake is not
    # completed in timeout period, instead of remaining open indefinitely
    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger
    # (every context passed to the handler is also recorded for the final check)
    messages = []
    self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

    server_side_aborted = False

    def server(sock):
        # Reads raw bytes without starting TLS, so the handshake stalls.
        nonlocal server_side_aborted
        try:
            sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            server_side_aborted = True
        finally:
            sock.close()

    async def client(addr):
        await asyncio.wait_for(
            self.loop.create_connection(
                asyncio.Protocol,
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=10.0),
            0.5)

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(client(srv.addr))

    self.assertTrue(server_side_aborted)

    # Python issue #23197: cancelling a handshake must not raise an
    # exception or log an error, even if the handshake failed
    self.assertEqual(messages, [])
| |
def test_ssl_handshake_connection_lost(self):
    """A connection broken mid-handshake must not reach the app protocol.

    The server closes the socket after reading the first bytes, so the TLS
    handshake cannot finish.  create_connection() is expected to raise
    ConnectionResetError, and neither connection_made() nor
    connection_lost() may have been called on the client protocol.
    """
    # #246: make sure that no connection_lost() is called before
    # connection_made() is called first

    client_sslctx = test_utils.simple_client_sslcontext()

    # silence error logger
    self.loop.set_exception_handler(lambda loop, ctx: None)

    connection_made_called = False
    connection_lost_called = False

    def server(sock):
        sock.recv(1024)
        # break the connection during handshake
        sock.close()

    class ClientProto(asyncio.Protocol):
        def connection_made(self, transport):
            nonlocal connection_made_called
            connection_made_called = True

        def connection_lost(self, exc):
            nonlocal connection_lost_called
            connection_lost_called = True

    async def client(addr):
        await self.loop.create_connection(
            ClientProto,
            *addr,
            ssl=client_sslctx,
            server_hostname='')

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        with self.assertRaises(ConnectionResetError):
            self.loop.run_until_complete(client(srv.addr))

    if connection_lost_called:
        if connection_made_called:
            self.fail("unexpected call to connection_lost()")
        else:
            # The trailing space keeps the two literals from fusing
            # into "withoutcalling".
            self.fail("unexpected call to connection_lost() without "
                      "calling connection_made()")
    elif connection_made_called:
        self.fail("unexpected call to connection_made()")
| |
def test_ssl_connect_accepted_socket(self):
    """Build the server and client SSL contexts for the accepted-socket test.

    NOTE(review): the two contexts are configured but never passed to
    test_connect_accepted_socket(), so this test currently only checks
    that they can be constructed.  Confirm whether a call was dropped;
    note that the client context is built with PROTOCOL_TLS_SERVER, which
    would need revisiting before it is used for a client handshake.
    """
    proto = ssl.PROTOCOL_TLS_SERVER
    server_context = ssl.SSLContext(proto)
    server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
    if hasattr(server_context, 'check_hostname'):
        server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE

    client_context = ssl.SSLContext(proto)
    # Probe the context that is being configured (was server_context).
    if hasattr(client_context, 'check_hostname'):
        client_context.check_hostname = False
    client_context.verify_mode = ssl.CERT_NONE
| |
def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
    """loop.connect_accepted_socket() must serve an already accepted socket.

    A listening socket is created by hand, a client thread connects and
    sends a message, and the accepted connection is handed to the loop.
    The protocol replies to any data and stops the loop once the
    connection is lost.

    server_ssl / client_ssl: optional SSL contexts; when given, the
    server side passes server_ssl to connect_accepted_socket() and the
    client thread wraps its socket with client_ssl.
    """
    loop = self.loop

    class MyProto(MyBaseProto):

        def connection_lost(self, exc):
            super().connection_lost(exc)
            # Ends the run_forever() call below.
            loop.call_soon(loop.stop)

        def data_received(self, data):
            super().data_received(data)
            self.transport.write(expected_response)

    lsock = socket.socket(socket.AF_INET)
    lsock.bind(('127.0.0.1', 0))
    lsock.listen(1)
    addr = lsock.getsockname()

    message = b'test data'
    response = None
    expected_response = b'roger'

    def client():
        # Runs in a separate thread; stores the server's reply in `response`.
        nonlocal response
        try:
            csock = socket.socket(socket.AF_INET)
            if client_ssl is not None:
                csock = client_ssl.wrap_socket(csock)
            csock.connect(addr)
            csock.sendall(message)
            response = csock.recv(99)
            csock.close()
        except Exception as exc:
            # A failure here is only printed; it surfaces later through
            # the assertions on `response` and the protocol state.
            print(
                "Failure in client thread in test_connect_accepted_socket",
                exc)

    thread = threading.Thread(target=client, daemon=True)
    thread.start()

    # Blocking accept on the test thread, before the loop runs.
    conn, _ = lsock.accept()
    proto = MyProto(loop=loop)
    proto.loop = loop

    extras = {}
    if server_ssl:
        extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

    f = loop.create_task(
        loop.connect_accepted_socket(
            (lambda: proto), conn, ssl=server_ssl,
            **extras))
    # Runs until MyProto.connection_lost() schedules loop.stop().
    loop.run_forever()
    conn.close()
    lsock.close()

    # Wait at most one second for the client thread.
    thread.join(1)
    self.assertFalse(thread.is_alive())
    self.assertEqual(proto.state, 'CLOSED')
    self.assertEqual(proto.nbytes, len(message))
    self.assertEqual(response, expected_response)
    tr, _ = f.result()

    if server_ssl:
        self.assertIn('SSL', tr.__class__.__name__)

    tr.close()
    # let it close
    self.loop.run_until_complete(asyncio.sleep(0.1))
| |
def test_start_tls_client_corrupted_ssl(self):
    """Plaintext injected into a TLS stream must surface as ssl.SSLError.

    After a normal TLS exchange the server writes unencrypted bytes on a
    duplicate of the underlying socket.  The client's next read is
    expected to raise ssl.SSLError, and the client coroutine must still
    complete and return 'OK'.
    """
    self.loop.set_exception_handler(lambda loop, ctx: None)

    sslctx = test_utils.simple_server_sslcontext()
    client_sslctx = test_utils.simple_client_sslcontext()

    def server(sock):
        # Duplicate taken before starttls(); it is used below to send
        # the raw bytes.
        orig_sock = sock.dup()
        try:
            sock.starttls(
                sslctx,
                server_side=True)
            sock.sendall(b'A\n')
            sock.recv_all(1)
            orig_sock.send(b'please corrupt the SSL connection')
        except ssl.SSLError:
            pass
        finally:
            sock.close()
            orig_sock.close()

    async def client(addr):
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='')

        self.assertEqual(await reader.readline(), b'A\n')
        writer.write(b'B')
        with self.assertRaises(ssl.SSLError):
            await reader.readline()
        writer.close()
        try:
            await self.wait_closed(writer)
        except ssl.SSLError:
            pass
        return 'OK'

    with self.tcp_server(server,
                         max_clients=1,
                         backlog=1) as srv:

        res = self.loop.run_until_complete(client(srv.addr))

    self.assertEqual(res, 'OK')
| |
def test_start_tls_client_reg_proto_1(self):
    """loop.start_tls() upgrading a plain client connection (asyncio.Protocol).

    The client sends HELLO_MSG in the clear, upgrades with start_tls(),
    then receives b'O' and sends HELLO_MSG again over TLS.  The server
    mirrors this with blocking socket calls.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        # Plaintext phase.
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.starttls(server_context, server_side=True)

        # Encrypted phase.
        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # The instance parameter is named `proto` so that `self` below
        # still refers to the enclosing test case.
        def connection_made(proto, tr):
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        tr.write(HELLO_MSG)
        new_tr = await self.loop.start_tls(tr, proto, client_context)

        self.assertEqual(await on_data, b'O')
        new_tr.write(HELLO_MSG)
        await on_eof

        new_tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))
| |
def test_create_connection_memory_leak(self):
    """The client SSL context must be freed once the connection is done.

    After an SSL create_connection() exchange completes, the local name
    client_context is rebound to a weak reference to the context, and
    that reference is asserted to be dead.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_context = self._create_client_ssl_context()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        sock.starttls(server_context, server_side=True)

        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.unwrap()
        sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # The instance parameter is named `proto` so that `self` below
        # still refers to the enclosing test case.
        def connection_made(proto, tr):
            # XXX: We assume user stores the transport in protocol
            proto.tr = tr
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr,
            ssl=client_context)

        self.assertEqual(await on_data, b'O')
        tr.write(HELLO_MSG)
        await on_eof

        tr.close()

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))

    # No garbage is left for SSL client from loop.create_connection, even
    # if user stores the SSLTransport in corresponding protocol instance
    # (rebinding the name drops this function's own strong reference)
    client_context = weakref.ref(client_context)
    self.assertIsNone(client_context())
| |
def test_start_tls_client_buf_proto_1(self):
    """loop.start_tls() with a BufferedProtocol, then a protocol switch.

    The client starts with a BufferedProtocol, upgrades to TLS, receives
    b'O', then replaces the protocol with a regular asyncio.Protocol via
    set_protocol() and receives b'2'.  connection_made() must have been
    called exactly once across both protocols.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    # Shared by both protocol classes below.
    client_con_made_calls = 0

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        # Plaintext phase.
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.starttls(server_context, server_side=True)

        # Encrypted phase, answered by the first client protocol.
        sock.sendall(b'O')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        # Encrypted phase, answered by the second client protocol.
        sock.sendall(b'2')
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.unwrap()
        sock.close()

    class ClientProtoFirst(asyncio.BufferedProtocol):
        def __init__(self, on_data):
            self.on_data = on_data
            # One-byte receive buffer handed out by get_buffer().
            self.buf = bytearray(1)

        def connection_made(self, tr):
            nonlocal client_con_made_calls
            client_con_made_calls += 1

        def get_buffer(self, sizehint):
            return self.buf

        def buffer_updated(self, nsize):
            assert nsize == 1
            self.on_data.set_result(bytes(self.buf[:nsize]))

        def eof_received(self):
            pass

    class ClientProtoSecond(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        def connection_made(self, tr):
            nonlocal client_con_made_calls
            client_con_made_calls += 1

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data1 = self.loop.create_future()
        on_data2 = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProtoFirst(on_data1), *addr)

        tr.write(HELLO_MSG)
        new_tr = await self.loop.start_tls(tr, proto, client_context)

        self.assertEqual(await on_data1, b'O')
        new_tr.write(HELLO_MSG)

        new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
        self.assertEqual(await on_data2, b'2')
        new_tr.write(HELLO_MSG)
        await on_eof

        new_tr.close()

        # connection_made() should be called only once -- when
        # we establish connection for the first time. Start TLS
        # doesn't call connection_made() on application protocols.
        self.assertEqual(client_con_made_calls, 1)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=self.TIMEOUT))
| |
def test_start_tls_slow_client_cancel(self):
    """A start_tls() stalled by the server must be cancellable.

    The server reads the plaintext greeting, signals the client through
    a future, and then keeps reading without starting TLS.  The client's
    start_tls() is wrapped in wait_for(..., 0.5) and is expected to raise
    asyncio.TimeoutError.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    client_context = test_utils.simple_client_sslcontext()
    # Resolved from the server thread just before it starts its second read.
    server_waits_on_handshake = self.loop.create_future()

    def serve(sock):
        sock.settimeout(self.TIMEOUT)

        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        try:
            self.loop.call_soon_threadsafe(
                server_waits_on_handshake.set_result, None)
            # No starttls() on this side, so the client handshake stalls.
            data = sock.recv_all(1024 * 1024)
        except ConnectionAbortedError:
            pass
        finally:
            sock.close()

    class ClientProto(asyncio.Protocol):
        def __init__(self, on_data, on_eof):
            self.on_data = on_data
            self.on_eof = on_eof
            self.con_made_cnt = 0

        # The instance parameter is named `proto` so that `self` below
        # still refers to the enclosing test case.
        def connection_made(proto, tr):
            proto.con_made_cnt += 1
            # Ensure connection_made gets called only once.
            self.assertEqual(proto.con_made_cnt, 1)

        def data_received(self, data):
            self.on_data.set_result(data)

        def eof_received(self):
            self.on_eof.set_result(True)

    async def client(addr):
        await asyncio.sleep(0.5)

        on_data = self.loop.create_future()
        on_eof = self.loop.create_future()

        tr, proto = await self.loop.create_connection(
            lambda: ClientProto(on_data, on_eof), *addr)

        tr.write(HELLO_MSG)

        await server_waits_on_handshake

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(
                self.loop.start_tls(tr, proto, client_context),
                0.5)

    with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
        self.loop.run_until_complete(
            asyncio.wait_for(client(srv.addr),
                             timeout=support.SHORT_TIMEOUT))
| |
def test_start_tls_server_1(self):
    """loop.start_tls(server_side=True) upgrading an accepted connection.

    A plain asyncio server accepts a threaded client, sends HELLO_MSG in
    the clear, then upgrades the transport to TLS.  The client sends
    HELLO_MSG over TLS and shuts down; the server protocol must have
    received exactly HELLO_MSG.
    """
    HELLO_MSG = b'1' * self.PAYLOAD_SIZE

    server_context = test_utils.simple_server_sslcontext()
    client_context = test_utils.simple_client_sslcontext()

    def client(sock, addr):
        sock.settimeout(self.TIMEOUT)

        sock.connect(addr)
        # Plaintext greeting from the server.
        data = sock.recv_all(len(HELLO_MSG))
        self.assertEqual(len(data), len(HELLO_MSG))

        sock.starttls(client_context)
        sock.sendall(HELLO_MSG)

        sock.unwrap()
        sock.close()

    class ServerProto(asyncio.Protocol):
        def __init__(self, on_con, on_eof, on_con_lost):
            self.on_con = on_con
            self.on_eof = on_eof
            self.on_con_lost = on_con_lost
            self.data = b''

        def connection_made(self, tr):
            self.on_con.set_result(tr)

        def data_received(self, data):
            self.data += data

        def eof_received(self):
            self.on_eof.set_result(1)

        def connection_lost(self, exc):
            if exc is None:
                self.on_con_lost.set_result(None)
            else:
                self.on_con_lost.set_exception(exc)

    async def main(proto, on_con, on_eof, on_con_lost):
        tr = await on_con
        tr.write(HELLO_MSG)

        # Nothing has been received before the upgrade.
        self.assertEqual(proto.data, b'')

        new_tr = await self.loop.start_tls(
            tr, proto, server_context,
            server_side=True,
            ssl_handshake_timeout=self.TIMEOUT)

        await on_eof
        await on_con_lost
        self.assertEqual(proto.data, HELLO_MSG)
        new_tr.close()

    async def run_main():
        on_con = self.loop.create_future()
        on_eof = self.loop.create_future()
        on_con_lost = self.loop.create_future()
        proto = ServerProto(on_con, on_eof, on_con_lost)

        # The same protocol instance is returned for every connection.
        server = await self.loop.create_server(
            lambda: proto, '127.0.0.1', 0)
        addr = server.sockets[0].getsockname()

        with self.tcp_client(lambda sock: client(sock, addr),
                             timeout=self.TIMEOUT):
            await asyncio.wait_for(
                main(proto, on_con, on_eof, on_con_lost),
                timeout=self.TIMEOUT)

        server.close()
        await server.wait_closed()

    self.loop.run_until_complete(run_main())
| |
def test_create_server_ssl_over_ssl(self):
    """Two nested TLS layers: an SSL server whose protocol calls start_tls().

    The outer layer comes from create_server(ssl=sslctx_1).  The server
    protocol then wraps the resulting transport again with sslctx_2.
    Each threaded client wraps its socket with client_sslctx_1 and drives
    the second layer by hand through memory BIOs.  The exchange is
    A_DATA -> b'OK' -> B_DATA -> b'SPAM', and every server-side handler
    must run to completion.
    """
    CNT = 0           # number of clients that were successful
    TOTAL_CNT = 25    # total number of clients that test will create
    TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

    A_DATA = b'A' * 1024 * 1024
    B_DATA = b'B' * 1024 * 1024

    sslctx_1 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_1 = self._create_client_ssl_context()
    sslctx_2 = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx_2 = self._create_client_ssl_context()

    clients = []

    async def handle_client(reader, writer):
        nonlocal CNT

        data = await reader.readexactly(len(A_DATA))
        self.assertEqual(data, A_DATA)
        writer.write(b'OK')

        data = await reader.readexactly(len(B_DATA))
        self.assertEqual(data, B_DATA)
        # Mixed buffer types together spell b'SPAM'.
        writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

        await writer.drain()
        writer.close()

        CNT += 1

    class ServerProtocol(asyncio.StreamReaderProtocol):
        def connection_made(self, transport):
            # super() is captured here for use inside the nested callback.
            super_ = super()
            transport.pause_reading()
            fut = self._loop.create_task(self._loop.start_tls(
                transport, self, sslctx_2, server_side=True))

            def cb(_):
                # Forward the outcome of the second handshake to the
                # base StreamReaderProtocol.
                try:
                    tr = fut.result()
                except Exception as ex:
                    super_.connection_lost(ex)
                else:
                    super_.connection_made(tr)
            fut.add_done_callback(cb)

    def server_protocol_factory():
        reader = asyncio.StreamReader()
        protocol = ServerProtocol(reader, handle_client)
        return protocol

    async def test_client(addr):
        # Completed from the client thread via call_soon_threadsafe().
        fut = asyncio.Future()

        def prog(sock):
            try:
                sock.connect(addr)
                sock.starttls(client_sslctx_1)

                # because wrap_socket() doesn't work correctly on
                # SSLSocket, we have to do the 2nd level SSL manually
                incoming = ssl.MemoryBIO()
                outgoing = ssl.MemoryBIO()
                sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                def do(func, *args):
                    # Retry func until it stops asking for more input,
                    # shuttling bytes between the BIOs and the socket.
                    while True:
                        try:
                            rv = func(*args)
                            break
                        except ssl.SSLWantReadError:
                            if outgoing.pending:
                                sock.send(outgoing.read())
                            incoming.write(sock.recv(65536))
                    # Flush whatever the successful call produced.
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    return rv

                do(sslobj.do_handshake)

                do(sslobj.write, A_DATA)
                data = do(sslobj.read, 2)
                self.assertEqual(data, b'OK')

                do(sslobj.write, B_DATA)
                data = b''
                # Read until an empty chunk is returned.
                while True:
                    chunk = do(sslobj.read, 4)
                    if not chunk:
                        break
                    data += chunk
                self.assertEqual(data, b'SPAM')

                do(sslobj.unwrap)
                sock.close()

            except Exception as ex:
                self.loop.call_soon_threadsafe(fut.set_exception, ex)
                sock.close()
            else:
                self.loop.call_soon_threadsafe(fut.set_result, None)

        client = self.tcp_client(prog)
        client.start()
        clients.append(client)

        await fut

    async def start_server():
        extras = {}

        srv = await self.loop.create_server(
            server_protocol_factory,
            '127.0.0.1', 0,
            family=socket.AF_INET,
            ssl=sslctx_1,
            **extras)

        try:
            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(test_client(addr))

            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

        finally:
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

    with self._silence_eof_received_warning():
        self.loop.run_until_complete(start_server())

    self.assertEqual(CNT, TOTAL_CNT)

    for client in clients:
        client.stop()
| |
def test_shutdown_cleanly(self):
    """A TLS shutdown by the peer must read as a clean EOF.

    The threaded server answers A_DATA with b'OK', then calls unwrap()
    and closes.  Each of the TOTAL_CNT asyncio clients must then read b''
    from its stream.
    """
    CNT = 0
    TOTAL_CNT = 25

    A_DATA = b'A' * 1024 * 1024

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx = self._create_client_ssl_context()

    def server(sock):
        sock.starttls(
            sslctx,
            server_side=True)

        data = sock.recv_all(len(A_DATA))
        self.assertEqual(data, A_DATA)
        sock.send(b'OK')

        # TLS-level shutdown before closing the socket.
        sock.unwrap()

        sock.close()

    async def client(addr):
        extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='',
            **extras)

        writer.write(A_DATA)
        self.assertEqual(await reader.readexactly(2), b'OK')

        # read() with no size reads until EOF; nothing else may arrive.
        self.assertEqual(await reader.read(), b'')

        nonlocal CNT
        CNT += 1

        writer.close()
        await self.wait_closed(writer)

    def run(coro):
        nonlocal CNT
        CNT = 0

        async def _gather(*tasks):
            return await asyncio.gather(*tasks)

        with self.tcp_server(server,
                             max_clients=TOTAL_CNT,
                             backlog=TOTAL_CNT) as srv:
            tasks = []
            for _ in range(TOTAL_CNT):
                tasks.append(coro(srv.addr))

            self.loop.run_until_complete(
                _gather(*tasks))

        self.assertEqual(CNT, TOTAL_CNT)

    with self._silence_eof_received_warning():
        run(client)
| |
def test_flush_before_shutdown(self):
    """Data written before close() must still reach the peer.

    The client pauses the SSL protocol's writing, queues SIZE chunks of
    CHUNK bytes, calls close(), and only then resumes writing.  The
    server must still receive all CHUNK * SIZE bytes.
    """
    CHUNK = 1024 * 128
    SIZE = 32

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT, test_utils.ONLYKEY)
    client_sslctx = self._create_client_ssl_context()

    # Created inside client(); resolved from the server thread by run().
    future = None

    def server(sock):
        sock.starttls(sslctx, server_side=True)
        self.assertEqual(sock.recv_all(4), b'ping')
        sock.send(b'pong')
        time.sleep(0.5)  # hopefully stuck the TCP buffer
        data = sock.recv_all(CHUNK * SIZE)
        self.assertEqual(len(data), CHUNK * SIZE)
        sock.close()

    def run(meth):
        # Wrap a server program so that its outcome (success or the
        # raised exception) is delivered to `future` on the loop thread.
        def wrapper(sock):
            try:
                meth(sock)
            except Exception as ex:
                self.loop.call_soon_threadsafe(future.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(future.set_result, None)
        return wrapper

    async def client(addr):
        nonlocal future
        future = self.loop.create_future()
        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='')
        # Private attribute: the SSL protocol behind the transport.
        sslprotocol = writer.transport._ssl_protocol
        writer.write(b'ping')
        data = await reader.readexactly(4)
        self.assertEqual(data, b'pong')

        sslprotocol.pause_writing()
        for _ in range(SIZE):
            writer.write(b'x' * CHUNK)

        # close() is requested while writing is still paused.
        writer.close()
        sslprotocol.resume_writing()

        await self.wait_closed(writer)
        try:
            data = await reader.read()
            self.assertEqual(data, b'')
        except ConnectionResetError:
            pass
        # Propagates any failure raised in the server thread.
        await future

    with self.tcp_server(run(server)) as srv:
        self.loop.run_until_complete(client(srv.addr))
| |
def test_remote_shutdown_receives_trailing_data(self):
    """The peer must still get the write backlog after it starts shutdown.

    Two server variants are run one after the other: ``server`` drives
    TLS by hand through memory BIOs and sends ``close_notify`` without
    waiting for the answer, ``eof_server`` half-closes the TCP socket.
    Both then assert that all ``CHUNK * SIZE`` bytes queued on the client
    arrive.  The server outcome travels back to the loop via ``future``.
    """
    CHUNK = 1024 * 128
    SIZE = 32

    sslctx = self._create_server_ssl_context(
        test_utils.ONLYCERT,
        test_utils.ONLYKEY
    )
    client_sslctx = self._create_client_ssl_context()
    # Created by the client coroutine; resolved from the server thread.
    future = None

    def server(sock):
        incoming = ssl.MemoryBIO()
        outgoing = ssl.MemoryBIO()
        sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

        def flush_pending():
            # Push produced TLS records to the socket, if there are any.
            if outgoing.pending:
                sock.send(outgoing.read())

        def feed_incoming():
            # Move the next batch of raw bytes from the socket into TLS.
            incoming.write(sock.recv(16384))

        # Handshake: loop until do_handshake() stops asking for input.
        while True:
            try:
                sslobj.do_handshake()
            except ssl.SSLWantReadError:
                flush_pending()
                feed_incoming()
                continue
            flush_pending()
            break

        while True:
            try:
                data = sslobj.read(4)
            except ssl.SSLWantReadError:
                feed_incoming()
            else:
                break

        self.assertEqual(data, b'ping')
        sslobj.write(b'pong')
        sock.send(outgoing.read())

        time.sleep(0.2)  # wait for the peer to fill its backlog

        # send close_notify but don't wait for response
        with self.assertRaises(ssl.SSLWantReadError):
            sslobj.unwrap()
        sock.send(outgoing.read())

        # should receive all data
        received = 0
        while True:
            try:
                received += len(sslobj.read(16384))
            except ssl.SSLWantReadError:
                feed_incoming()
            except ssl.SSLZeroReturnError:
                break

        self.assertEqual(received, CHUNK * SIZE)

        # verify that close_notify is received
        sslobj.unwrap()

        sock.close()

    def eof_server(sock):
        sock.starttls(sslctx, server_side=True)
        self.assertEqual(sock.recv_all(4), b'ping')
        sock.send(b'pong')

        time.sleep(0.2)  # wait for the peer to fill its backlog

        # send EOF
        sock.shutdown(socket.SHUT_WR)

        # should receive all data
        payload = sock.recv_all(CHUNK * SIZE)
        self.assertEqual(len(payload), CHUNK * SIZE)

        sock.close()

    async def client(addr):
        nonlocal future
        future = self.loop.create_future()

        reader, writer = await asyncio.open_connection(
            *addr,
            ssl=client_sslctx,
            server_hostname='')
        writer.write(b'ping')
        reply = await reader.readexactly(4)
        self.assertEqual(reply, b'pong')

        # fill write backlog in a hacky way - renegotiation won't help
        chunk = b'x' * CHUNK
        for _ in range(SIZE):
            writer.transport._test__append_write_backlog(chunk)

        with contextlib.suppress(BrokenPipeError, ConnectionResetError):
            rest = await reader.read()
            self.assertEqual(rest, b'')

        await future

        writer.close()
        await self.wait_closed(writer)

    def run(meth):
        # Wrap *meth* so that its result or exception lands in `future`.
        def wrapper(sock):
            try:
                meth(sock)
            except Exception as ex:
                self.loop.call_soon_threadsafe(future.set_exception, ex)
            else:
                self.loop.call_soon_threadsafe(future.set_result, None)
        return wrapper

    for handler in (server, eof_server):
        with self.tcp_server(run(handler)) as srv:
            self.loop.run_until_complete(client(srv.addr))
| |
def test_connect_timeout_warning(self):
    """A failed/timed-out SSL connect must not emit a ResourceWarning.

    The target socket is bound but never listens, so the connection
    attempt ends in ``ConnectionRefusedError`` or, after 0.1 s, in
    ``asyncio.TimeoutError``.  The test then requires that
    ``assertWarns(ResourceWarning)`` fails with its
    "not triggered" message, i.e. that no such warning was raised.
    """
    async def test(addr):
        try:
            await asyncio.wait_for(
                self.loop.create_connection(asyncio.Protocol,
                                            *addr, ssl=True),
                0.1)
        except (ConnectionRefusedError, asyncio.TimeoutError):
            pass
        else:
            self.fail('TimeoutError is not raised')

    # The socket is created in the ``with`` header so that it is closed
    # even when bind()/getsockname() raise.
    with socket.socket(socket.AF_INET) as s:
        s.bind(('127.0.0.1', 0))
        addr = s.getsockname()

        try:
            with self.assertWarns(ResourceWarning) as cm:
                self.loop.run_until_complete(test(addr))
                # Several passes, so that anything left unclosed is
                # finalized (and warns) inside the assertWarns block.
                for _ in range(3):
                    gc.collect()
        except AssertionError as e:
            self.assertEqual(str(e), 'ResourceWarning not triggered')
        else:
            self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
| |
def test_handshake_timeout_handler_leak(self):
    """A cancelled handshake must not keep the SSL context alive.

    The target socket listens but never accepts, so no TLS handshake
    takes place and ``wait_for`` gives up after 0.1 s.  Afterwards only
    a weak reference to the context is kept, and it has to be dead
    without any garbage-collector run.
    """
    async def test(addr, ctx):
        try:
            await asyncio.wait_for(
                self.loop.create_connection(asyncio.Protocol, *addr,
                                            ssl=ctx),
                0.1)
        except (ConnectionRefusedError, asyncio.TimeoutError):
            pass
        else:
            self.fail('TimeoutError is not raised')

    # The socket is created in the ``with`` header so that it is closed
    # even when bind()/listen() raise.
    with socket.socket(socket.AF_INET) as s:
        s.bind(('127.0.0.1', 0))
        s.listen(1)
        addr = s.getsockname()

        ctx = ssl.create_default_context()
        self.loop.run_until_complete(test(addr, ctx))
        # Rebinding the name drops this function's strong reference.
        ctx = weakref.ref(ctx)

    # SSLProtocol should be DECREF to 0
    self.assertIsNone(ctx())
| |
def test_shutdown_timeout_handler_leak(self):
    """Closing an SSL transport must not leak the client SSL context.

    After a connection has been opened and closed, only a weak reference
    to the context is kept; it has to be dead once the garbage collector
    has run.
    """
    loop = self.loop

    def server(sock):
        server_ctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        tls_sock = server_ctx.wrap_socket(sock, server_side=True)
        tls_sock.recv(32)
        tls_sock.close()

    class Protocol(asyncio.Protocol):
        """Protocol whose only job is to signal ``connection_lost``."""

        def __init__(self):
            self.fut = asyncio.Future(loop=loop)

        def connection_lost(self, exc):
            self.fut.set_result(None)

    async def client(addr, ctx):
        tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
        tr.close()
        await pr.fut

    with self.tcp_server(server) as srv:
        ctx = self._create_client_ssl_context()
        loop.run_until_complete(client(srv.addr, ctx))
        ctx_ref = weakref.ref(ctx)
        # No strong reference may survive in this frame.
        del ctx

    # asyncio has no shutdown timeout, but it ends up with a circular
    # reference loop - not ideal (introduces gc glitches), but at least
    # not leaking
    for _ in range(3):
        gc.collect()

    # SSLProtocol should be DECREF to 0
    self.assertIsNone(ctx_ref())
| |
def test_shutdown_timeout_handler_not_set(self):
    """Data left unread when the peer sends EOF must still be delivered.

    The client protocol pauses reading after answering ``b'hello'``, so
    the server's ``b'extra bytes'`` and its EOF arrive while reading is
    paused.  Once reading is resumed the protocol must receive the extra
    bytes and the connection must be lost without an error.
    """
    loop = self.loop
    eof = asyncio.Event()
    extra = None

    def server(sock):
        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        sock = sslctx.wrap_socket(sock, server_side=True)
        sock.send(b'hello')
        # recv() is deliberately kept out of any ``assert`` statement:
        # under ``python -O`` asserts are removed, and with them the
        # reads this exchange depends on.
        reply = sock.recv(1024)
        self.assertEqual(reply, b'world')
        sock.send(b'extra bytes')
        # sending EOF here
        sock.shutdown(socket.SHUT_WR)
        loop.call_soon_threadsafe(eof.set)
        # make sure we have enough time to reproduce the issue
        tail = sock.recv(1024)
        self.assertEqual(tail, b'')
        sock.close()

    class Protocol(asyncio.Protocol):
        """Answer ``b'hello'``, pause reading, record any later data."""

        def __init__(self):
            self.fut = asyncio.Future(loop=loop)
            self.transport = None

        def connection_made(self, transport):
            self.transport = transport

        def data_received(self, data):
            nonlocal extra
            if data == b'hello':
                self.transport.write(b'world')
                # pause reading would make incoming data stay in the sslobj
                self.transport.pause_reading()
            else:
                extra = data

        def connection_lost(self, exc):
            if exc is None:
                self.fut.set_result(None)
            else:
                self.fut.set_exception(exc)

    async def client(addr):
        ctx = self._create_client_ssl_context()
        tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
        await eof.wait()
        tr.resume_reading()
        await pr.fut
        tr.close()
        # assertEqual rather than a bare assert, so the check also runs
        # under ``python -O``.
        self.assertEqual(extra, b'extra bytes')

    with self.tcp_server(server) as srv:
        loop.run_until_complete(client(srv.addr))
| |
| |
| ############################################################################### |
| # Socket Testing Utilities |
| ############################################################################### |
| |
| |
class TestSocketWrapper:
    """Thin proxy around a socket with a few helpers for the test peers.

    Every attribute that is not defined here is looked up on the wrapped
    socket, so an instance can be used like the socket itself.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ``ConnectionAbortedError`` if ``recv()`` returns ``b''``
        before *n* bytes have arrived.
        """
        # A bytearray grows in place; repeated ``bytes +=`` would copy
        # everything received so far on each iteration.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the current socket with *ssl_context* and switch to it.

        The keyword arguments are passed on to
        ``ssl_context.wrap_socket()``.  For ``server_side=True`` the
        handshake is additionally run explicitly.
        """
        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        if server_side:
            ssl_sock.do_handshake()

        # From here on all I/O goes through the TLS socket; the previous
        # socket object is closed and dropped.
        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only reached for names not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
| |
| |
class SocketThread(threading.Thread):
    """Thread that doubles as a context manager.

    Entering the ``with`` block starts the thread; leaving it calls
    ``stop()``, which clears ``_active`` and joins the thread.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        self.stop()

    def stop(self):
        # Subclasses consult `_active` to decide whether to keep running.
        self._active = False
        self.join()
| |
| |
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program against one socket.

    *prog* is called with the socket wrapped in ``TestSocketWrapper``.
    Any exception other than ``KeyboardInterrupt``/``SystemExit`` is
    handed to ``test._abort_socket_test()``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as ex:
            self._test._abort_socket_test(ex)
| |
| |
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs *prog* on each.

    Clients are served one after another inside the server thread.  The
    loop ends when ``max_clients`` connections have been accepted, when
    ``stop()`` is called, or when a client program fails.  An internal
    socket pair is used to wake the ``select()`` call up on ``stop()``.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._max_clients = max_clients

        self._clients = 0
        self._finished_clients = 0
        self._active = True

        # _s2 is written to by stop(); _s1 is watched by _run().
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    @property
    def addr(self):
        """Address of the listening socket, as given by getsockname()."""
        return self._sock.getsockname()

    def stop(self):
        try:
            if self._s2 and self._s2.fileno() != -1:
                # Best effort: the thread may already have closed the pair.
                with contextlib.suppress(OSError):
                    self._s2.send(b'stop')
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            # Wake-up from stop().
            if self._s1 in readable:
                return

            if self._sock not in readable:
                continue

            try:
                conn, _peer = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                if not self._active:
                    return
                raise

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as ex:
                self._active = False
                # Re-raise in this thread, but report to the test first.
                try:
                    raise
                finally:
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))