| #!/usr/bin/env python3 |
| # Copyright 2021 The Pigweed Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy of |
| # the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations under |
| # the License. |
| """Tests using the callback client for pw_rpc.""" |
| |
| import unittest |
| from unittest import mock |
| from typing import Any, List, Optional, Tuple |
| |
| from pw_protobuf_compiler import python_protos |
| from pw_status import Status |
| |
| from pw_rpc import callback_client, client, descriptors, packets |
| from pw_rpc.internal import packet_pb2 |
| |
# Proto definitions compiled by the test fixture's setUp via
# python_protos.Library.from_strings. PublicService declares one method of
# each RPC kind: unary, server streaming, client streaming, and bidirectional
# streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# ID of the channel the test fixture registers with the client under test.
CLIENT_CHANNEL_ID: int = 489
| |
| |
| def _message_bytes(msg) -> bytes: |
| return msg if isinstance(msg, bytes) else msg.SerializeToString() |
| |
| |
class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    ``setUp`` builds a client backed by ``callback_client.Impl`` with a single
    channel whose output function is ``_handle_packet``. Tests queue server
    packets with the ``_enqueue_*`` helpers. ``_handle_packet`` replays the
    queue into the client once ``send_responses_after_packets`` outgoing
    packets have been sent. Tests may also replay it directly with
    ``_process_enqueued_packets``.
    """

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            self._protos.modules(),
        )
        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        # Every packet the client sent, decoded, in send order.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Serialized server packets, each paired with the status that
        # process_packet is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Outgoing packets to observe before replaying the queued server
        # packets. A float so that it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording packets.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recently sent packet; at least one must exist."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_packet(self, process_status: Status, **fields: Any) -> None:
        """Serializes an RpcPacket built from ``fields`` and queues it.

        Args:
          process_status: Status ``process_packet`` must return when the
            packet is replayed.
          **fields: Field values for the ``packet_pb2.RpcPacket``.
        """
        packet = packet_pb2.RpcPacket(**fields)
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: Optional[descriptors.Method] = None,
        status: Status = Status.OK,
        payload: Any = b'',
        *,
        ids: Optional[Tuple[int, int]] = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a RESPONSE packet for later replay into the client.

        Args:
          channel_id: Channel the response is addressed to.
          method: Method to respond for. Mutually exclusive with ``ids``.
          status: Status carried by the response packet.
          payload: Response message, or already-serialized bytes.
          ids: Explicit ``(service_id, method_id)``, used instead of
            ``method`` to address services or methods that do not exist.
          process_status: Status ``process_packet`` must return.
          call_id: Call ID the response is addressed to.
        """
        if method is not None:
            assert ids is None, 'pass either method or ids, not both'
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None, 'either method or ids is required'
            service_id, method_id = ids

        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            call_id=call_id,
            status=status.value,
            payload=_message_bytes(payload),
        )

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method: descriptors.Method,
        response: Any,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_STREAM packet for later replay into the client.

        Args:
          channel_id: Channel the packet is addressed to.
          method: Method the streamed message belongs to.
          response: Streamed message, or already-serialized bytes.
          process_status: Status ``process_packet`` must return.
          call_id: Call ID the packet is addressed to.
        """
        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            call_id=call_id,
            payload=_message_bytes(response),
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service: Any,
        method: Any,
        status: Status,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_ERROR packet for later replay into the client.

        ``service`` and ``method`` may each be a descriptor, in which case its
        ``id`` is used, or a raw integer ID. Raw IDs let tests address unknown
        services and methods.
        """
        service_id = service if isinstance(service, int) else service.id
        method_id = method if isinstance(method, int) else method.id
        self._enqueue_packet(
            process_status,
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            call_id=call_id,
            status=status.value,
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output function for the client under test.

        Records each outgoing packet. Once ``send_responses_after_packets``
        packets have been seen, it replays the queued server packets.
        """
        if self.output_exception is not None:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Feeds every queued server packet into the client, in queue order."""
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Always drop the queue and restore the counter. If processing
            # raised, the fixture must not be left with an infinite counter,
            # which would prevent any later replay.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the payload of the last sent packet as ``message_type``."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message
| |
| |
| # Disable docstring requirements for test functions. |
| # pylint: disable=missing-function-docstring |
| |
| |
class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        # An exception raised by a callback is expected to be logged at ERROR
        # level rather than propagated out of invoke().
        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        self._enqueue_response(
            CLIENT_CHANNEL_ID, method, process_status=Status.OK
        )

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(
            CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT
        )
        # Bad method
        self._enqueue_error(
            CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT
        )
        # For RPC not pending
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        # None of the errors above may cause the client to send anything.
        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            method,
            Status.OK,
            b'INVALID DATA!!!',
            process_status=Status.OK,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
            self._protos.modules(),
        )
        rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

        # Unary and client streaming RPCs take the unary default (99); server
        # and bidirectional streaming RPCs take the stream default (100).
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; this now checks the
        # response type.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )
| |
| |
class UnaryTest(_CallbackClientImplTestBase):
    """Exercises blocking and non-blocking invocations of a unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        request_type = self.method.request_type
        response_type = self.method.response_type

        for _ in range(3):
            self._enqueue_response(
                channel_id=CLIENT_CHANNEL_ID,
                method=self.method,
                status=Status.ABORTED,
                payload=response_type(payload='0_o'),
            )

            status, response = self._service.SomeUnary(
                request_type(magic_number=6)
            )

            sent = self._sent_payload(request_type)
            self.assertEqual(6, sent.magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        expected = self.method.response_type(payload='0_o')

        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), callback, callback
            )

            self._enqueue_response(
                channel_id=CLIENT_CHANNEL_ID,
                method=self.method,
                status=Status.ABORTED,
                payload=expected,
                call_id=call.call_id,
            )
            self._process_enqueued_packets()

            callback.assert_has_calls(
                [mock.call(call, expected), mock.call(call, Status.ABORTED)]
            )

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(5, sent.magic_number)

    def test_concurrent_nonblocking_calls(self) -> None:
        # Start several calls to the same method.
        pending: List[Tuple[mock.Mock, callback_client.call.Call]] = []
        for _ in range(3):
            on_next = mock.Mock()
            pending_call = self.rpc.invoke(
                self._request(magic_number=5), on_next
            )
            pending.append((on_next, pending_call))

        # Respond only to the most recently started call.
        last_callback, last_call = pending.pop()
        last_payload = self.method.response_type(payload='last payload')
        self._enqueue_response(
            channel_id=CLIENT_CHANNEL_ID,
            method=self.method,
            payload=last_payload,
            call_id=last_call.call_id,
        )
        self._process_enqueued_packets()

        # So far only that last caller may have seen a response.
        last_callback.assert_called_once_with(last_call, last_payload)
        for untouched_callback, _ in pending:
            untouched_callback.assert_not_called()

        # Answer the remaining callers one by one and check each receipt.
        other_payload = self.method.response_type(payload='other payload')
        for remaining_callback, remaining_call in pending:
            self._enqueue_response(
                channel_id=CLIENT_CHANNEL_ID,
                method=self.method,
                payload=other_payload,
                call_id=remaining_call.call_id,
            )
            self._process_enqueued_packets()
            remaining_callback.assert_called_once_with(
                remaining_call, other_payload
            )

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        expected = self.method.response_type(payload='0_o')

        for _ in range(3):
            self._enqueue_response(
                channel_id=CLIENT_CHANNEL_ID,
                method=self.method,
                status=Status.ABORTED,
                payload=expected,
            )

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=5), callback, callback
            )
            # open() must not have sent anything.
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [mock.call(call, expected), mock.call(call, Status.ABORTED)]
            )

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(
                channel_id=CLIENT_CHANNEL_ID,
                service=self.method.service,
                method=self.method,
                status=Status.NOT_FOUND,
            )

            request = self.method.request_type(magic_number=6)
            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(request)

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback
            )

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            cancel_packet = self.last_request()
            self.assertEqual(
                cancel_packet.type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(cancel_packet.status, Status.CANCELLED.value)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self.rpc(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self.rpc.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self.rpc()

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        for pending_call in (first_call, second_call):
            self.assertIs(pending_call.error, None)

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong! (intentionally)')
        failing_callback = mock.Mock(side_effect=exception)

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=failing_callback)

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)

    def test_unary_response(self) -> None:
        message = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        cases = [
            (
                callback_client.UnaryResponse(Status.ABORTED, message),
                '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
            ),
            (
                callback_client.UnaryResponse(Status.OK, None),
                '(Status.OK, None)',
            ),
        ]
        for unary_response, expected_repr in cases:
            self.assertEqual(repr(unary_response), expected_repr)

    def test_on_call_hook(self) -> None:
        hook_function = mock.Mock()

        # Rebuild the client so that the hook is installed on its Impl.
        self._client = client.Client.from_modules(
            callback_client.Impl(on_call_hook=hook_function),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            self._protos.modules(),
        )
        channel = self._client.channel(CLIENT_CHANNEL_ID)
        self._service = channel.rpcs.pw.test1.PublicService

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)
        self._service.SomeUnary(self.method.request_type(magic_number=6))

        hook_function.assert_called_once()
        hooked_call = hook_function.call_args[0][0]
        self.assertEqual(hooked_call.method.full_name, self.method.full_name)
| |
| |
class ServerStreamingTest(_CallbackClientImplTestBase):
    """Exercises a server streaming RPC through the callback client."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def _enqueue_replies_then_status(self, replies, status: Status) -> None:
        """Queues each reply as SERVER_STREAM, then a final RESPONSE."""
        for reply in replies:
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, status)

    def test_blocking_call(self) -> None:
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            self._enqueue_replies_then_status(replies, Status.ABORTED)

            result = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual(replies, result.responses)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(4, sent.magic_number)

    def test_nonblocking_call(self) -> None:
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            self._enqueue_replies_then_status(replies, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), callback, callback
            )

            expected_calls = [mock.call(call, reply) for reply in replies]
            expected_calls.append(mock.call(call, Status.ABORTED))
            callback.assert_has_calls(expected_calls)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(3, sent.magic_number)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            self._enqueue_replies_then_status(replies, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=3), callback, callback
            )
            # open() must not have sent anything.
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            expected_calls = [mock.call(call, reply) for reply in replies]
            expected_calls.append(mock.call(call, Status.ABORTED))
            callback.assert_has_calls(expected_calls)

    def test_nonblocking_cancel(self) -> None:
        reply = self.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.cancel()

        cancel_packet = self.last_request()
        self.assertEqual(cancel_packet.type, packet_pb2.PacketType.CLIENT_ERROR)
        self.assertEqual(cancel_packet.status, Status.CANCELLED.value)

        # The RPC must still be callable after a cancellation.
        self._enqueue_replies_then_status([reply], Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_request_completion(self) -> None:
        reply = self.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.request_completion()

        self.assertEqual(
            self.last_request().type,
            packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        )

        # The RPC must still be callable after requesting completion.
        self._enqueue_replies_then_status([reply], Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self.rpc(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self.rpc.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        for pending_call in (first_call, second_call):
            self.assertIs(pending_call.error, None)

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply] * 2)

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_replies_then_status([reply], Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        # After completion, further iteration yields nothing.
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])
| |
| |
class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.OK, response
        )

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # delivered in reply to the first packet the client sends.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            # Packet types are plain ints; compare by value, not identity.
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                31, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.send(magic_number=32)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                32, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload
            )

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            # open() must not have sent anything.
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, payload),
                    mock.call(call, Status.OK),
                ]
            )

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                37, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.finish_and_wait()
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
                self.last_request().type,
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            # Both values are ints; identity checks only passed because of
            # CPython's small-int cache.
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
            )
            self.assertEqual(
                Status.CANCELLED.value, self.last_request().status
            )
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        # Use a field SomeMessage actually has, so the failure can only come
        # from the cancelled state and not from request validation.
        with self.assertRaises(callback_client.RpcError) as context:
            call.send(magic_number=1)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        # Finishing an already completed call returns the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, None)
        self.assertIs(second_call.error, None)
| |
| |
class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs (PublicService.SomeBidiStreaming)."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call returns the server's final status and no responses."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """A server error during a blocking call is raised as an RpcError."""
        requests = [self.method.request_type(magic_number=123)]

        # Unlike test_blocking_call, send_responses_after_packets is not
        # overridden here, so the enqueued error is delivered on the base
        # class's default schedule. NOTE(review): confirm that default in
        # _CallbackClientImplTestBase.setUp.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        # Repeat the whole exchange to check that the RPC can be invoked
        # again after a previous call on it has completed.
        for _ in range(3):
            responses: list = []
            # `responses` is bound as a default argument so each iteration's
            # callback appends to that iteration's list.
            stream = self.rpc.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            # Protobuf enum values are plain ints: compare by value, not
            # identity.
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            # These server stream packets are expected to reach the callback
            # in response to the next client packet.
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            # The server's final response completes the call.
            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """open() sends no packets, yet still dispatches server packets."""
        # Set an output failure so that any packet sent by open() would
        # surface as an error. NOTE(review): this assumes the base class raises
        # output_exception from its packet output hook.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            # One mock serves as the response, completion and error callback.
            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, rep1),
                    mock.call(call, rep2),
                    mock.call(call, Status.OK),
                ]
            )

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Responses go to the default handler when invoke() gets no callback."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        self.rpc.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """A server error mid-stream ends the call and is raised on finish."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self.rpc.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.OUT_OF_RANGE,
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            # An errored call records the error but has no final status.
            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """An error sent after the client stream ends is raised on finish."""
        for _ in range(3):
            stream = self.rpc.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            # Checked on every iteration, not just the last one.
            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Sending on a cancelled call raises RpcError with CANCELLED."""
        call = self.rpc.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            # Use a field that exists on the request type (SomeMessage), so
            # the failure can only come from the cancelled state and not from
            # building the message.
            call.send(magic_number=123)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """finish_and_wait() returns the same result after completion."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """finish_and_wait() keeps raising the same error after a failure."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Invoking the RPC again does not put an error on the first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_stream_response(self) -> None:
        """StreamResponse's repr shows the status and the response list."""
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )
| |
| |
if __name__ == '__main__':
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()