| # Copyright 2020 The gRPC Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests behavior around the metadata mechanism.""" |
| |
| import asyncio |
| import logging |
| import platform |
| import random |
| import unittest |
| |
| import grpc |
| from grpc.experimental import aio |
| |
| from tests_aio.unit._test_base import AioTestBase |
| from tests_aio.unit import _common |
| |
| _TEST_CLIENT_TO_SERVER = '/test/TestClientToServer' |
| _TEST_SERVER_TO_CLIENT = '/test/TestServerToClient' |
| _TEST_TRAILING_METADATA = '/test/TestTrailingMetadata' |
| _TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata' |
| _TEST_GENERIC_HANDLER = '/test/TestGenericHandler' |
| _TEST_UNARY_STREAM = '/test/TestUnaryStream' |
| _TEST_STREAM_UNARY = '/test/TestStreamUnary' |
| _TEST_STREAM_STREAM = '/test/TestStreamStream' |
| |
| _REQUEST = b'\x00\x00\x00' |
| _RESPONSE = b'\x01\x01\x01' |
| |
| _INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( |
| ('client-to-server', 'question'), |
| ('client-to-server-bin', b'\x07\x07\x07'), |
| ) |
| _INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( |
| ('server-to-client', 'answer'), |
| ('server-to-client-bin', b'\x06\x06\x06'), |
| ) |
| _TRAILING_METADATA = aio.Metadata( |
| ('a-trailing-metadata', 'stack-trace'), |
| ('a-trailing-metadata-bin', b'\x05\x05\x05'), |
| ) |
| _INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( |
| ('a-must-have-key', 'secret'),) |
| |
| _INVALID_METADATA_TEST_CASES = ( |
| ( |
| TypeError, |
| ((42, 42),), |
| ), |
| ( |
| TypeError, |
| (({}, {}),), |
| ), |
| ( |
| TypeError, |
| ((None, {}),), |
| ), |
| ( |
| TypeError, |
| (({}, {}),), |
| ), |
| ( |
| TypeError, |
| (('normal', object()),), |
| ), |
| ) |
| |
| |
| class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): |
| |
| def __init__(self): |
| self._routing_table = { |
| _TEST_CLIENT_TO_SERVER: |
| grpc.unary_unary_rpc_method_handler(self._test_client_to_server |
| ), |
| _TEST_SERVER_TO_CLIENT: |
| grpc.unary_unary_rpc_method_handler(self._test_server_to_client |
| ), |
| _TEST_TRAILING_METADATA: |
| grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata |
| ), |
| _TEST_UNARY_STREAM: |
| grpc.unary_stream_rpc_method_handler(self._test_unary_stream), |
| _TEST_STREAM_UNARY: |
| grpc.stream_unary_rpc_method_handler(self._test_stream_unary), |
| _TEST_STREAM_STREAM: |
| grpc.stream_stream_rpc_method_handler(self._test_stream_stream), |
| } |
| |
| @staticmethod |
| async def _test_client_to_server(request, context): |
| assert _REQUEST == request |
| assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, |
| context.invocation_metadata()) |
| return _RESPONSE |
| |
| @staticmethod |
| async def _test_server_to_client(request, context): |
| assert _REQUEST == request |
| await context.send_initial_metadata( |
| _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) |
| return _RESPONSE |
| |
| @staticmethod |
| async def _test_trailing_metadata(request, context): |
| assert _REQUEST == request |
| context.set_trailing_metadata(_TRAILING_METADATA) |
| return _RESPONSE |
| |
| @staticmethod |
| async def _test_unary_stream(request, context): |
| assert _REQUEST == request |
| assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, |
| context.invocation_metadata()) |
| await context.send_initial_metadata( |
| _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) |
| yield _RESPONSE |
| context.set_trailing_metadata(_TRAILING_METADATA) |
| |
| @staticmethod |
| async def _test_stream_unary(request_iterator, context): |
| assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, |
| context.invocation_metadata()) |
| await context.send_initial_metadata( |
| _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) |
| |
| async for request in request_iterator: |
| assert _REQUEST == request |
| |
| context.set_trailing_metadata(_TRAILING_METADATA) |
| return _RESPONSE |
| |
| @staticmethod |
| async def _test_stream_stream(request_iterator, context): |
| assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, |
| context.invocation_metadata()) |
| await context.send_initial_metadata( |
| _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) |
| |
| async for request in request_iterator: |
| assert _REQUEST == request |
| |
| yield _RESPONSE |
| context.set_trailing_metadata(_TRAILING_METADATA) |
| |
| def service(self, handler_call_details): |
| return self._routing_table.get(handler_call_details.method) |
| |
| |
| class _TestGenericHandlerItself(grpc.GenericRpcHandler): |
| |
| @staticmethod |
| async def _method(request, unused_context): |
| assert _REQUEST == request |
| return _RESPONSE |
| |
| def service(self, handler_call_details): |
| assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER, |
| handler_call_details.invocation_metadata) |
| return grpc.unary_unary_rpc_method_handler(self._method) |
| |
| |
| async def _start_test_server(): |
| server = aio.server() |
| port = server.add_insecure_port('[::]:0') |
| server.add_generic_rpc_handlers(( |
| _TestGenericHandlerForMethods(), |
| _TestGenericHandlerItself(), |
| )) |
| await server.start() |
| return 'localhost:%d' % port, server |
| |
| |
| class TestMetadata(AioTestBase): |
| |
| async def setUp(self): |
| address, self._server = await _start_test_server() |
| self._client = aio.insecure_channel(address) |
| |
| async def tearDown(self): |
| await self._client.close() |
| await self._server.stop(None) |
| |
| async def test_from_client_to_server(self): |
| multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) |
| call = multicallable(_REQUEST, |
| metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) |
| self.assertEqual(_RESPONSE, await call) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_from_server_to_client(self): |
| multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) |
| call = multicallable(_REQUEST) |
| |
| self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await |
| call.initial_metadata()) |
| self.assertEqual(_RESPONSE, await call) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_trailing_metadata(self): |
| multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA) |
| call = multicallable(_REQUEST) |
| self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) |
| self.assertEqual(_RESPONSE, await call) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_from_client_to_server_with_list(self): |
| multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) |
| call = multicallable( |
| _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types |
| self.assertEqual(_RESPONSE, await call) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| @unittest.skipIf(platform.system() == 'Windows', |
| 'https://github.com/grpc/grpc/issues/21943') |
| async def test_invalid_metadata(self): |
| multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) |
| for exception_type, metadata in _INVALID_METADATA_TEST_CASES: |
| with self.subTest(metadata=metadata): |
| with self.assertRaises(exception_type): |
| call = multicallable(_REQUEST, metadata=metadata) |
| await call |
| |
| async def test_generic_handler(self): |
| multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER) |
| call = multicallable(_REQUEST, |
| metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER) |
| self.assertEqual(_RESPONSE, await call) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_unary_stream(self): |
| multicallable = self._client.unary_stream(_TEST_UNARY_STREAM) |
| call = multicallable(_REQUEST, |
| metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) |
| |
| self.assertTrue( |
| _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await |
| call.initial_metadata())) |
| |
| self.assertSequenceEqual([_RESPONSE], |
| [request async for request in call]) |
| |
| self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_stream_unary(self): |
| multicallable = self._client.stream_unary(_TEST_STREAM_UNARY) |
| call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) |
| await call.write(_REQUEST) |
| await call.done_writing() |
| |
| self.assertTrue( |
| _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await |
| call.initial_metadata())) |
| self.assertEqual(_RESPONSE, await call) |
| |
| self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_stream_stream(self): |
| multicallable = self._client.stream_stream(_TEST_STREAM_STREAM) |
| call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) |
| await call.write(_REQUEST) |
| await call.done_writing() |
| |
| self.assertTrue( |
| _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await |
| call.initial_metadata())) |
| self.assertSequenceEqual([_RESPONSE], |
| [request async for request in call]) |
| self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) |
| self.assertEqual(grpc.StatusCode.OK, await call.code()) |
| |
| async def test_compatibility_with_tuple(self): |
| metadata_obj = aio.Metadata(('key', '42'), ('key-2', 'value')) |
| self.assertEqual(metadata_obj, tuple(metadata_obj)) |
| self.assertEqual(tuple(metadata_obj), metadata_obj) |
| |
| expected_sum = tuple(metadata_obj) + (('third', '3'),) |
| self.assertEqual(expected_sum, metadata_obj + (('third', '3'),)) |
| self.assertEqual(expected_sum, metadata_obj + aio.Metadata( |
| ('third', '3'))) |
| |
| |
| if __name__ == '__main__': |
| logging.basicConfig(level=logging.DEBUG) |
| unittest.main(verbosity=2) |