|  | """ | 
|  | Test suite for socketserver. | 
|  | """ | 
|  |  | 
|  | import contextlib | 
|  | import io | 
|  | import os | 
|  | import select | 
|  | import signal | 
|  | import socket | 
|  | import tempfile | 
|  | import threading | 
|  | import unittest | 
|  | import socketserver | 
|  |  | 
|  | import test.support | 
|  | from test.support import reap_children, verbose | 
|  | from test.support import os_helper | 
|  | from test.support import socket_helper | 
|  | from test.support import threading_helper | 
|  |  | 
|  |  | 
# The whole module needs the "network" test resource.
test.support.requires("network")

# Payload echoed back by every server under test.
TEST_STR = b"hello world\n"
HOST = socket_helper.HOST

# Platform capability flags and the matching skip decorators.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork")

requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
|  |  | 
def signal_alarm(n):
    """Schedule SIGALRM in *n* seconds via signal.alarm (0 cancels it).

    signal.alarm is missing on some platforms (e.g. Windows); there this
    function silently does nothing.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
|  |  | 
# Remember the real select() at import time so that receive() keeps using
# it even if a test later mocks select.select.
_real_select = select.select
|  |  | 
def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Uses the unmocked select() to wait for readability.

    Raises:
        RuntimeError: if *sock* did not become readable in time.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
|  |  | 
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # Forking variants of the AF_UNIX servers, composed from ForkingMixIn.
    # Only defined when both AF_UNIX sockets and os.fork() are available;
    # the tests using them are guarded by the matching skip decorators.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass
|  |  | 
|  |  | 
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork an unrelated child process for the duration of the block.

    Tests that a custom child process is not waited on (Issue 1540386):
    the child exits immediately with status 72, and once the ``with``
    body finishes the parent reaps it and checks for that exit code.
    If code inside the block (e.g. a forking server reaping children)
    had already waited on this child, that check is expected to fail.

    *testcase* is accepted for compatibility with existing callers and
    is not used.
    """
    pid = os.fork()
    if pid == 0:
        # Child: exit right away.  Don't raise an exception; it would be
        # caught by the test harness running in the forked process.
        os._exit(72)
    try:
        yield None
    finally:
        # Parent: always reap the child, even if the block failed.
        test.support.wait_process(pid, exitcode=72)
|  |  | 
|  |  | 
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # AF_UNIX socket paths handed out by pickaddr(); removed in tearDown.
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                # The socket file may never have been created.
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return a bind address for a server of address family *proto*.

        AF_INET gets ``(HOST, 0)`` so the OS chooses a free port.  Any other
        family (in practice AF_UNIX) gets a fresh temporary file path, which
        is recorded in self.test_files so tearDown() can remove it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a line-echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back.  The server's handle_error() closes the request and re-raises,
        so handler failures are not swallowed.  The test is skipped if the
        server cannot be created because of a PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve in a background thread and run *testfunc* against it 3 times.

        *testfunc* is called as ``testfunc(address_family, address)``.  The
        server is always shut down, its thread joined and its socket closed,
        even if *testfunc* fails, so a failing client check neither leaks the
        listening socket nor leaves serve_forever() running.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        try:
            if verbose: print("server running")
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose: print("waiting for server")
        finally:
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Keep reading until the full line has arrived or the peer closes.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound AF_UNIX datagram socket has no address the
                # server could reply to, so bind it to a temporary path.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that the server does not
        # reap an unrelated child process (Issue 1540386).
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        # Call shutdown() right after start(), possibly before
        # serve_forever() is actually running in the thread.
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with-block must close the server socket.
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
|  |  | 
|  |  | 
class ErrorHandlerTest(unittest.TestCase):
    """Check how each server flavour reacts to an exception in the handler.

    Ordinary exceptions (ValueError) must reach handle_error(); exiting
    exceptions such as SystemExit and KeyboardInterrupt must not.
    """

    def tearDown(self):
        os_helper.unlink(os_helper.TESTFN)

    def _run_and_check(self, server_class, exception, handled):
        # Construct the server (which serves one failing request) and then
        # verify the log it and its handler produced.
        server_class(exception)
        self.check_result(handled=handled)

    def test_sync_handled(self):
        self._run_and_check(BaseErrorTestServer, ValueError, True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        self._run_and_check(ThreadingErrorTestServer, ValueError, True)

    def test_threading_not_handled(self):
        self._run_and_check(ThreadingErrorTestServer, SystemExit, False)

    @requires_forking
    def test_forking_handled(self):
        self._run_and_check(ForkingErrorTestServer, ValueError, True)

    @requires_forking
    def test_forking_not_handled(self):
        self._run_and_check(ForkingErrorTestServer, SystemExit, False)

    def check_result(self, handled):
        """Assert the log file shows the handler ran, plus handle_error()
        iff *handled* is true."""
        expected = 'Handler called\n'
        if handled:
            expected += 'Error handled\n'
        with open(os_helper.TESTFN) as log:
            self.assertEqual(log.read(), expected)
|  |  | 
|  |  | 
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request from inside __init__.

    The request is handled by BadHandler, which raises ``exception``.
    handle_error() appends a marker line to os_helper.TESTFN so tests can
    tell whether the error reached it.  The server socket is always
    closed, even if connecting or handling the request fails.
    """

    def __init__(self, exception):
        # Exception class BadHandler raises (read via self.server.exception).
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Queue one connection; the client side may close immediately,
            # the server still accepts it in handle_request().
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that finish requests asynchronously."""
        pass
|  |  | 
|  |  | 
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs that it ran and then always fails.

    The exception class raised is taken from ``self.server.exception``.
    """

    def handle(self):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Handler called\n')
        failure = self.server.exception
        raise failure('Test error')
|  |  | 
|  |  | 
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant whose wait_done() blocks until the request has
    been shut down."""

    def __init__(self, *args, **kwargs):
        # Must exist before the base __init__ runs, since that handles the
        # request and then calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        # Release wait_done() now that the request is finished.
        self.done.set()

    def wait_done(self):
        self.done.wait()
|  |  | 
|  |  | 
if HAVE_FORKING:
    # Forking variant of BaseErrorTestServer; only defined where os.fork()
    # exists (the tests using it are decorated with @requires_forking).
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        pass
|  |  | 
|  |  | 
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile object StreamRequestHandler exposes to handlers."""

    def test_basics(self):
        # wfile must be a BufferedIOBase wrapping the request socket itself.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash what we want to inspect on the server object.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                # Large enough that the send is expected to block until the
                # client reads, giving the client a chance to signal us.
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        # The handler runs in this (main) thread via handle_request() below,
        # so this is the thread the client interrupts with SIGUSR1.
        main_thread = threading.get_ident()

        def run_client():
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                nonlocal response1
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=float(1)):
                        break
                nonlocal received2
                # Read the rest until EOF; together with the 100 bytes read
                # above this must add up to the full chunk.
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
|  |  | 
|  |  | 
class MiscTestCase(unittest.TestCase):
    """Module-level API checks and small BaseServer regressions."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the listening socket even if an assertion below fails.
        self.addCleanup(server.server_close)
        # Queue one connection, closing the client side before it is served.
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)
|  |  | 
|  |  | 
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()