| # Copyright 2019 The gRPC Authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import asyncio |
| import gc |
| import logging |
| import time |
| import unittest |
| |
| import grpc |
| from grpc.experimental import aio |
| |
| from tests.unit.framework.common import test_constants |
| from tests_aio.unit._test_base import AioTestBase |
| |
# Method paths dispatched by _GenericHandler's routing table.
# Unary-unary methods with different server-side timing behaviour:
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'

# Unary request -> response stream, one per server-side handler style.
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'

# Request stream -> unary response, one per server-side handler style.
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'

# Request stream -> response stream, one per server-side handler style.
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'

# Fixed payloads and stream lengths shared by the handlers and the tests.
_REQUEST, _RESPONSE = b'\x00\x00\x00', b'\x01\x01\x01'
_NUM_STREAM_REQUESTS, _NUM_STREAM_RESPONSES = 3, 5
| |
| |
class _GenericHandler(grpc.GenericRpcHandler):
    """Dispatches the test method paths to in-process async handlers.

    Handlers cover every RPC arity, each implemented three ways: as an
    async generator, through the ``context.read``/``context.write`` API,
    and "evilly mixed" (both APIs within one call). The handler also
    exposes ``wait_for_call`` so tests can wait until the server has
    looked up a method before acting on the in-flight RPC.
    """

    def __init__(self):
        # Resolved the first time ``service`` is invoked; awaited by
        # ``wait_for_call``.
        self._called = asyncio.get_event_loop().create_future()
        self._routing_table = {
            _SIMPLE_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(self._unary_unary),
            _BLOCK_FOREVER:
                grpc.unary_unary_rpc_method_handler(self._block_forever),
            _BLOCK_BRIEFLY:
                grpc.unary_unary_rpc_method_handler(self._block_briefly),
            _UNARY_STREAM_ASYNC_GEN:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_async_gen),
            _UNARY_STREAM_READER_WRITER:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_reader_writer),
            _UNARY_STREAM_EVILLY_MIXED:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_evilly_mixed),
            _STREAM_UNARY_ASYNC_GEN:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_async_gen),
            _STREAM_UNARY_READER_WRITER:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_reader_writer),
            _STREAM_UNARY_EVILLY_MIXED:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_evilly_mixed),
            _STREAM_STREAM_ASYNC_GEN:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_async_gen),
            _STREAM_STREAM_READER_WRITER:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_reader_writer),
            _STREAM_STREAM_EVILLY_MIXED:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_evilly_mixed),
        }

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        """Immediately answer with the canned response."""
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        """Await a future nobody resolves, so the RPC never completes."""
        await asyncio.get_event_loop().create_future()

    async def _block_briefly(self, unused_request, unused_context):
        """Sleep for half of SHORT_TIMEOUT, then answer normally."""
        await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        """Yield the full response stream from an async generator."""
        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        """Send the full response stream through ``context.write``."""
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        """Yield the first response, then write the rest via the context."""
        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        """Consume every request through async iteration, then answer."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request_iterator,
                                          context):
        """Consume exactly the expected requests via ``context.read``."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        """Read one request via the context, iterate over the remainder."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS - 1 == request_count
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        """Iterate over all requests, then yield the response stream."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count

        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request_iterator,
                                           context):
        """Read all requests and write all responses via the context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        """Mix context reads/writes with async iteration and ``yield``."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS - 1 == request_count

        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    def service(self, handler_details):
        """Return the method handler for ``handler_details.method``.

        Returns None for paths not in the routing table, as the
        GenericRpcHandler contract requires for unhandled methods. Indexing
        the table would instead raise KeyError inside the server.
        """
        # Only the first lookup resolves the future; resolving it again on
        # a later RPC would raise asyncio.InvalidStateError.
        if not self._called.done():
            self._called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Wait until ``service`` has been invoked at least once."""
        await self._called
| |
| |
async def _start_test_server():
    """Bring up an insecure aio server on an OS-assigned port.

    Returns:
        A ``(target, server, handler)`` tuple: the ``localhost:<port>``
        address to connect to, the already-started server, and the
        _GenericHandler registered on it.
    """
    test_server = aio.server()
    bound_port = test_server.add_insecure_port('[::]:0')
    handler = _GenericHandler()
    test_server.add_generic_rpc_handlers((handler,))
    await test_server.start()
    target = 'localhost:%d' % bound_port
    return target, test_server, handler
| |
| |
class TestServer(AioTestBase):
    """End-to-end tests of the aio server against _GenericHandler."""

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_unary_unary(self):
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER)
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED)
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API
        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_async_generator(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_reader_writer(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        """An in-flight RPC that finishes within the grace period succeeds."""
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Monotonic clock: elapsed-time measurement must not be affected by
        # wall-clock adjustments.
        shutdown_start_time = time.monotonic()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        """An RPC still running when the grace period ends fails UNAVAILABLE."""
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(grpc.RpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())
        self.assertIn('GOAWAY', exception_context.exception.details())

    async def test_concurrent_graceful_shutdown(self):
        """Concurrent stop() calls still let a brief in-flight RPC finish."""
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        shutdown_start_time = time.monotonic()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        """A concurrent stop(None) aborts an in-flight RPC with UNAVAILABLE."""
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())
        self.assertIn('GOAWAY', exception_context.exception.details())

    @unittest.skip('https://github.com/grpc/grpc/issues/20818')
    async def test_shutdown_before_call(self):
        # _start_test_server is a coroutine and must be awaited before its
        # (target, server, handler) result can be unpacked.
        server_target, server, _ = await _start_test_server()
        await server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        async with aio.insecure_channel(server_target) as channel:
            await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
| |
| |
if __name__ == '__main__':
    # When run as a script: root logging at DEBUG, verbose test output.
    logging.basicConfig(level='DEBUG')
    unittest.main(verbosity=2)