| #!/usr/bin/python3 |
| # |
| # Copyright 2015 Google Inc. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """Tests for the example portserver.""" |
| |
| import asyncio |
| import os |
| import signal |
| import socket |
| import subprocess |
| import sys |
| import time |
| import unittest |
| from unittest import mock |
| from multiprocessing import Process |
| |
| import portpicker |
| |
# On Windows, portserver.py is located in the "Scripts" folder, which isn't
# added to the import path by default, so make the directory that holds the
# running interpreter importable before importing portserver.
if sys.platform == 'win32':
    sys.path.append(os.path.dirname(sys.executable))
| |
| import portserver |
| |
| |
def setUpModule():
    """Enable verbose portserver logging once for the whole test module."""
    portserver._configure_logging(verbose=True)
| |
def exit_immediately():
    """Terminate the calling process right away with exit status 0.

    Used as a child-process target, so it must live at module level.
    """
    # os._exit() ends the process without running normal interpreter cleanup.
    os._exit(0)
| |
class PortserverFunctionsTest(unittest.TestCase):
    """Tests for portserver's module-level helpers and the real binary."""

    def setUp(self):
        """Pick a fresh unused port for every test."""
        self.port = portpicker.PickUnusedPort()

    def test_get_process_command_line(self):
        """Smoke test: looking up our own command line must not raise."""
        portserver._get_process_command_line(os.getpid())

    def test_get_process_start_time(self):
        """Our own process must report a start time greater than zero."""
        self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)

    def test_is_port_free(self):
        """This might be flaky unless this test is run with a portserver."""
        # The port should be free initially.
        self.assertTrue(portserver._is_port_free(self.port))

        # (address family, socket type, IPV6_V6ONLY value or None to skip).
        cases = [
            (socket.AF_INET, socket.SOCK_STREAM, None),
            (socket.AF_INET6, socket.SOCK_STREAM, 1),
            (socket.AF_INET, socket.SOCK_DGRAM, None),
            (socket.AF_INET6, socket.SOCK_DGRAM, 1),
        ]

        # Using v6only=0 on Windows doesn't result in collisions
        if sys.platform != 'win32':
            cases.extend([
                (socket.AF_INET6, socket.SOCK_STREAM, 0),
                (socket.AF_INET6, socket.SOCK_DGRAM, 0),
            ])

        for sock_family, sock_type, v6only in cases:
            # Occupy the port on a subset of possible protocols.
            try:
                sock = socket.socket(sock_family, sock_type, 0)
            except socket.error:
                print('Kernel does not support sock_family=%d' % sock_family,
                      file=sys.stderr)
                # Skip this case, since we cannot occupy a port.
                continue

            try:
                if not hasattr(socket, 'IPPROTO_IPV6'):
                    v6only = None

                if v6only is not None:
                    try:
                        sock.setsockopt(socket.IPPROTO_IPV6,
                                        socket.IPV6_V6ONLY, v6only)
                    except socket.error:
                        print('Kernel does not support IPV6_V6ONLY=%d' %
                              v6only, file=sys.stderr)
                        # Don't care; just proceed with the default.
                sock.bind(('', self.port))

                # The port should be busy.
                self.assertFalse(portserver._is_port_free(self.port))
            finally:
                # Release the port even when an assertion above fails, so
                # the socket is not leaked into later tests.
                sock.close()

            # Now it's free again.
            self.assertTrue(portserver._is_port_free(self.port))

    def test_is_port_free_exception(self):
        """A socket creation error must make the port count as not free."""
        with mock.patch.object(socket, 'socket') as mock_sock:
            mock_sock.side_effect = socket.error('fake socket error', 0)
            self.assertFalse(portserver._is_port_free(self.port))

    def test_should_allocate_port(self):
        """Invalid, init and dead pids are refused; a live pid is accepted."""
        self.assertFalse(portserver._should_allocate_port(0))
        self.assertFalse(portserver._should_allocate_port(1))
        # Call the function: passing it uncalled to assertTrue would always
        # succeed because a function object is truthy.
        # NOTE: pid 1 is refused above, so this presumes the test runner is
        # not itself running as pid 1 (e.g. as a container entrypoint).
        self.assertTrue(portserver._should_allocate_port(os.getpid()))

        p = Process(target=exit_immediately)
        p.start()
        child_pid = p.pid
        p.join()

        # This test assumes that after waitpid returns the kernel has finished
        # cleaning the process.  We also assume that the kernel will not reuse
        # the former child's pid before our next call checks for its existence.
        # Likely assumptions, but not guaranteed.
        self.assertFalse(portserver._should_allocate_port(child_pid))

    def test_parse_command_line(self):
        """Smoke test: the known flags parse without raising."""
        with mock.patch.object(
                sys, 'argv', ['program_name', '--verbose',
                              '--portserver_static_pool=1-1,3-8',
                              '--portserver_unix_socket_address=@hello-test']):
            portserver._parse_command_line()

    def test_parse_port_ranges(self):
        """Port range strings turn into sets; bad parts are dropped."""
        self.assertFalse(portserver._parse_port_ranges(''))
        self.assertCountEqual(portserver._parse_port_ranges('1-1'), {1})
        self.assertCountEqual(portserver._parse_port_ranges('1-1,3-8,375-378'),
                              {1, 3, 4, 5, 6, 7, 8, 375, 376, 377, 378})
        # Unparsable parts are logged but ignored.
        self.assertEqual({1, 2},
                         portserver._parse_port_ranges('1-2,not,numbers'))
        self.assertEqual(set(), portserver._parse_port_ranges('8080-8081x'))
        # Port ranges that go out of bounds are logged but ignored.
        self.assertEqual(set(), portserver._parse_port_ranges('0-1138'))
        self.assertEqual(set(range(19, 84 + 1)),
                         portserver._parse_port_ranges('1138-65536,19-84'))

    def test_configure_logging(self):
        """Just code coverage really."""
        portserver._configure_logging(False)
        portserver._configure_logging(True)

    # Server address unique to this test process.  The leading '@' is turned
    # into a NUL byte when this test connects to the socket directly.
    _test_socket_addr = f'@TST-{os.getpid()}'

    @mock.patch.object(
        sys, 'argv', ['PortserverFunctionsTest.test_main',
                      f'--portserver_unix_socket_address={_test_socket_addr}']
    )
    @mock.patch.object(portserver, '_parse_port_ranges')
    def test_main_no_ports(self, *unused_mocks):
        """main() must exit when the parsed port pool is empty."""
        portserver._parse_port_ranges.return_value = set()
        with self.assertRaises(SystemExit):
            portserver.main()

    @unittest.skipUnless(sys.executable, 'Requires a stand alone interpreter')
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'AF_UNIX required')
    def test_portserver_binary(self):
        """Launch python portserver.py and test it."""
        # Blindly assuming tree layout is src/tests/portserver_test.py
        # with src/portserver.py.
        portserver_py = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            'portserver.py')
        anon_addr = self._test_socket_addr.replace('@', '\0')

        # Nothing may be listening on our address before we start the server.
        # The socket is closed by the ``with`` block whether or not the
        # expected ConnectionRefusedError is raised.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as conn:
            with self.assertRaises(
                    ConnectionRefusedError,
                    msg=f'{self._test_socket_addr} should not listen yet.'):
                conn.connect(anon_addr)

        server = subprocess.Popen(
            [sys.executable, portserver_py,
             f'--portserver_unix_socket_address={self._test_socket_addr}'],
            stderr=subprocess.PIPE,
        )
        try:
            # Wait a few seconds for the server to start listening.
            start_time = time.monotonic()
            while True:
                time.sleep(0.05)
                try:
                    # Use a fresh socket per attempt instead of retrying on
                    # one that already failed (or was already closed).
                    with socket.socket(socket.AF_UNIX,
                                       socket.SOCK_STREAM) as conn:
                        conn.connect(anon_addr)
                except ConnectionRefusedError:
                    delta = time.monotonic() - start_time
                    if delta < 4:
                        continue
                    server.kill()
                    self.fail('Failed to connect to portserver '
                              f'{self._test_socket_addr} within '
                              f'{delta} seconds. STDERR:\n' +
                              server.stderr.read().decode('utf-8'))
                else:
                    break

            ports = set()
            port = portpicker.get_port_from_port_server(
                portserver_address=self._test_socket_addr)
            ports.add(port)
            port = portpicker.get_port_from_port_server(
                portserver_address=self._test_socket_addr)
            ports.add(port)

            with subprocess.Popen('exit 0', shell=True) as quick_process:
                quick_process.wait()
            # This process doesn't exist so it should be a denied alloc.
            # We use the pid from the above quick_process under the assumption
            # that most OSes try to avoid rapid pid recycling.
            denied_port = portpicker.get_port_from_port_server(
                portserver_address=self._test_socket_addr,
                pid=quick_process.pid)  # A now unused pid.
            self.assertIsNone(denied_port)

            # Two successful requests must have produced two distinct ports.
            self.assertEqual(len(ports), 2, msg=ports)

            # Check statistics from portserver
            server.send_signal(signal.SIGUSR1)
            # TODO implement an I/O timeout
            for line in server.stderr:
                if b'denied-allocations ' in line:
                    denied_allocations = int(
                        line.split(b'denied-allocations ', 2)[1])
                    self.assertEqual(1, denied_allocations, msg=line)
                elif b'total-allocations ' in line:
                    total_allocations = int(
                        line.split(b'total-allocations ', 2)[1])
                    self.assertEqual(2, total_allocations, msg=line)
                    # Stop reading here; without a break this loop would
                    # block waiting for more server output.
                    break

            rejected_port = portpicker.get_port_from_port_server(
                portserver_address=self._test_socket_addr,
                pid=99999999999999999999999999999999999)  # Out of range.
            self.assertIsNone(rejected_port)

            # Done.  shutdown gracefully.
            server.send_signal(signal.SIGINT)
            server.communicate(timeout=2)
        finally:
            # Make sure the child is gone and its stderr pipe is released
            # even when an assertion above failed part-way through.
            server.kill()
            server.wait()
            server.stderr.close()
| |
| |
class PortPoolTest(unittest.TestCase):
    """Tests for the bookkeeping and allocation of portserver._PortPool."""

    @classmethod
    def setUpClass(cls):
        # One unused port shared by every test in this class.
        cls.port = portpicker.PickUnusedPort()

    def setUp(self):
        # Each test starts from an empty pool.
        self.pool = portserver._PortPool()

    def test_initialization(self):
        """Valid ports grow the pool; ports 0 and 65536 raise ValueError."""
        pool = self.pool
        self.assertEqual(0, pool.num_ports())
        for expected_size, new_port in enumerate((self.port, 1138), start=1):
            pool.add_port_to_free_pool(new_port)
            self.assertEqual(expected_size, pool.num_ports())
        for invalid_port in (0, 65536):
            with self.assertRaises(ValueError):
                pool.add_port_to_free_pool(invalid_port)

    @mock.patch.object(portserver, '_is_port_free', return_value=True)
    def test_get_port_for_process_ok(self, mock_is_port_free):
        """With the only port reported free, it is returned after one check."""
        self.pool.add_port_to_free_pool(self.port)
        allocated = self.pool.get_port_for_process(os.getpid())
        self.assertEqual(self.port, allocated)
        self.assertEqual(1, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free', return_value=False)
    def test_get_port_for_process_none_left(self, mock_is_port_free):
        """With no port reported free, 0 is returned and the pool is kept."""
        for candidate in (self.port, 22):
            self.pool.add_port_to_free_pool(candidate)
        self.assertEqual(2, self.pool.num_ports())
        allocated = self.pool.get_port_for_process(os.getpid())
        self.assertEqual(0, allocated)
        self.assertEqual(2, self.pool.num_ports())
        self.assertEqual(2, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid', return_value=12345)
    def test_get_port_for_process_pid_eq_port(self, mock_getpid,
                                              mock_is_port_free):
        """Only the port equal to the (mocked) pid is reported free."""
        # os.getpid is mocked to 12345, so only port 12345 counts as free.
        mock_is_port_free.side_effect = lambda port: port == os.getpid()
        for candidate in (12345, 12344):
            self.pool.add_port_to_free_pool(candidate)
        self.assertEqual(2, self.pool.num_ports())
        allocated = self.pool.get_port_for_process(os.getpid())
        self.assertEqual(12345, allocated)
        self.assertEqual(2, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid', return_value=12345)
    def test_get_port_for_process_pid_ne_port(self, mock_getpid,
                                              mock_is_port_free):
        """Only the port different from the (mocked) pid is reported free."""
        # os.getpid is mocked to 12345, so port 12345 counts as busy.
        mock_is_port_free.side_effect = lambda port: port != os.getpid()
        for candidate in (12344, 12345):
            self.pool.add_port_to_free_pool(candidate)
        self.assertEqual(2, self.pool.num_ports())
        allocated = self.pool.get_port_for_process(os.getpid())
        self.assertEqual(12344, allocated)
        self.assertEqual(2, self.pool.ports_checked_for_last_request)
| |
| |
@mock.patch.object(portserver, '_get_process_command_line')
@mock.patch.object(portserver, '_should_allocate_port')
@mock.patch.object(portserver._PortPool, 'get_port_for_process')
class PortServerRequestHandlerTest(unittest.TestCase):
    """Tests for portserver._PortServerRequestHandler request handling.

    The class-level patches replace the process lookup, the allocation
    policy and the pool's port selection with mocks in every test method.
    """

    def setUp(self):
        portserver._configure_logging(verbose=True)
        self.rh = portserver._PortServerRequestHandler([23, 42, 54])

    @staticmethod
    def _new_writer():
        """Return a mock constrained to the asyncio.StreamWriter API."""
        return mock.Mock(asyncio.StreamWriter)

    def test_stats_reporting(self, *unused_mocks):
        """The last statistic logged by dump_stats is total-allocations."""
        with mock.patch.object(portserver, 'log') as mock_logger:
            self.rh.dump_stats()
        mock_logger.info.assert_called_with('total-allocations 0')

    def test_handle_port_request_bad_data(self, *unused_mocks):
        """Malformed requests never reach the process command line lookup."""
        for payload in (b'', b'\n', b'99Z\n', b'99 8\n'):
            self._test_bad_data_from_client(payload)
        self.assertEqual([], portserver._get_process_command_line.mock_calls)

    def _test_bad_data_from_client(self, data):
        """Feed one malformed request; the allocation check must not run."""
        writer = self._new_writer()
        self.rh._handle_port_request(data, writer)
        self.assertFalse(portserver._should_allocate_port.mock_calls)

    def test_handle_port_request_denied_allocation(self, *unused_mocks):
        """A refused pid bumps the denied-allocations counter."""
        portserver._should_allocate_port.return_value = False
        self.assertEqual(0, self.rh._denied_allocations)
        writer = self._new_writer()
        self.rh._handle_port_request(b'5\n', writer)
        self.assertEqual(1, self.rh._denied_allocations)

    def test_handle_port_request_bad_port_returned(self, *unused_mocks):
        """A pool result of 0 is counted as a denied allocation."""
        portserver._should_allocate_port.return_value = True
        pick_port = self.rh._port_pool.get_port_for_process
        pick_port.return_value = 0
        writer = self._new_writer()
        self.rh._handle_port_request(b'6\n', writer)
        pick_port.assert_called_once_with(6)
        self.assertEqual(1, self.rh._denied_allocations)

    def test_handle_port_request_success(self, *unused_mocks):
        """A granted request is counted and the port is written back."""
        portserver._should_allocate_port.return_value = True
        pick_port = self.rh._port_pool.get_port_for_process
        pick_port.return_value = 999
        writer = self._new_writer()
        self.assertEqual(0, self.rh._total_allocations)
        # The request deliberately carries no trailing newline.
        self.rh._handle_port_request(b'8', writer)
        portserver._should_allocate_port.assert_called_once_with(8)
        pick_port.assert_called_once_with(8)
        self.assertEqual(1, self.rh._total_allocations)
        self.assertEqual(0, self.rh._denied_allocations)
        writer.write.assert_called_once_with(b'999\n')
| |
| |
| if __name__ == '__main__': |
| unittest.main() |