blob: 70cd140f479669c69fc4c413b896b6a97d40377b [file] [log] [blame]
import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading
class FunctionalTestCaseMixin:
    """Mixin for functional asyncio tests that talk to blocking socket peers
    running in helper threads.

    The host class must provide ``fail`` (as ``unittest.TestCase`` does);
    ``tearDown`` and ``_abort_socket_test`` call it.
    """

    def new_loop(self):
        """Create the event loop used by the test (override to customize)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until an ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Exception handler installed on ``self.loop`` by ``setUp``.

        Records *context* and then delegates to the loop's default handler.
        Any recorded context makes ``tearDown`` fail the test.
        """
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        """Create ``self.loop`` without making it the current event loop."""
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

        # Disable `_get_running_loop`: the private asyncio hook is replaced
        # with one that always reports "no running loop"; tearDown restores
        # the saved original.
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        """Close the loop and fail if its exception handler was ever invoked.

        The ``_get_running_loop`` patch is undone and the current event loop
        is reset even if closing the loop or the check itself raises.
        """
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _check_blocking_timeout(timeout):
        """Raise RuntimeError if *timeout* is None or not greater than 0."""
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket and wrap it in a ``TestThreadedServer``.

        *server_prog* is handed to the server thread, which calls it with a
        ``TestSocketWrapper`` for each accepted connection. When *addr* is
        None, an AF_UNIX server binds a fresh temporary path and any other
        family binds ``127.0.0.1`` on an ephemeral port. This method does
        not start the thread.

        Raises RuntimeError if *timeout* is None or <= 0.
        """
        # Validate before creating the socket so a rejected timeout cannot
        # leave an open (and, for AF_UNIX, bound-to-disk) socket behind.
        self._check_blocking_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # NOTE(review): the temporary file is deleted when this
                # block exits, so the name is free but not reserved; another
                # process could claim the path before bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Ownership was never handed to the thread object; release it.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create an unconnected stream socket wrapped in a
        ``TestThreadedClient`` that will run *client_prog*.

        This method does not start the thread.

        Raises RuntimeError if *timeout* is None or <= 0.
        """
        # Validate first so an invalid timeout does not leak a socket.
        self._check_blocking_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """``tcp_server`` with ``family=socket.AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """``tcp_client`` with ``family=socket.AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a path named ``sock`` inside a fresh temporary directory.

        On exit the path is unlinked if present and the directory removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop the loop, then fail the test with *ex*.

        Called by the threaded client/server when their program raises.
        """
        # NOTE(review): this runs on the helper thread; loop.stop() is not
        # documented as thread-safe. Consider call_soon_threadsafe — confirm.
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
##############################################################################
# Socket Testing Utilities
##############################################################################
class TestSocketWrapper:
    """Proxy around a blocking socket, handed to the test's client/server
    programs.

    Attributes not defined here are forwarded to the wrapped socket, which
    ``start_tls`` may replace with an SSL-wrapped socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if ``recv`` returns ``b''`` (peer
        closed) before *n* bytes have arrived.
        """
        # A bytearray grows in place; ``bytes +=`` would copy the whole
        # buffer for every chunk received, which is quadratic in *n*.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if data == b'':
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the current socket with *ssl_context* and handshake now.

        The pre-TLS socket object is closed whether or not the handshake
        succeeds. On failure the SSL socket is closed too and the exception
        propagates; on success subsequent calls go to the SSL socket.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only consulted for names not found on the wrapper itself, so every
        # socket method other than those above is delegated.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
class SocketThread(threading.Thread):
    """Thread base class that doubles as a context manager.

    Entering the ``with`` block starts the thread; leaving it calls
    ``stop``, which clears ``_active`` (a flag a subclass's ``run`` may
    poll) and waits for the thread to finish.
    """

    def stop(self):
        """Clear the ``_active`` flag, then block until ``run`` returns."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        # Returns None, so an exception raised in the with-body propagates.
        self.stop()
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program against one socket.

    *prog* receives the socket wrapped in ``TestSocketWrapper``; if it
    raises an ``Exception``, the owning *test* is aborted with it.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, name='test-client')
        self.daemon = True

        self._test = test
        self._prog = prog
        self._sock = sock
        # NOTE(review): stored but not read anywhere in this class.
        self._timeout = timeout
        self._active = True

    def run(self):
        # The socket is not closed here; that is left to *prog* (or to
        # whoever else holds it).
        try:
            wrapped = TestSocketWrapper(self._sock)
            self._prog(wrapped)
        except Exception as exc:
            self._test._abort_socket_test(exc)
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and serves them one at a time.

    For each accepted connection *prog* is called with a
    ``TestSocketWrapper`` around it. Serving ends when ``stop`` is called,
    when *max_clients* connections have been accepted, or when a client
    program raises (which also aborts the owning *test*).
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, name='test-server')
        self.daemon = True

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._max_clients = max_clients
        self._active = True

        self._clients = 0
        # NOTE(review): not read anywhere in the visible code.
        self._finished_clients = 0

        # Wake-up channel: stop() writes to _s2 so select() in _run()
        # returns with _s1 readable.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        """Wake the serving loop through the socketpair, then join."""
        try:
            waker = self._s2
            if waker and waker.fileno() != -1:
                try:
                    waker.send(b'stop')
                except OSError:
                    # run() may be closing the pair at the same moment.
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            # Release both ends of the wake-up pair however _run() exits.
            self._s1.close()
            self._s2.close()

    def _run(self):
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                # stop() was called.
                return
            if self._sock not in readable:
                # select() timed out; re-check _active and wait again.
                continue

            try:
                conn, _peer = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                # NOTE(review): the listening socket is non-blocking (see
                # run()), so this branch looks defensive only.
                if not self._active:
                    return
                raise

            self._clients += 1
            conn.settimeout(self._timeout)

            try:
                with conn:
                    self._handle_client(conn)
            except Exception as ex:
                self._active = False
                # Re-raise the original error; the finally clause aborts the
                # test first (and its own exception, if any, takes over).
                try:
                    raise
                finally:
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Local address of the listening socket."""
        return self._sock.getsockname()