blob: 209643e52d107e05a6ed42d4b6430d78c9367dcf [file] [log] [blame]
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the grpc.aio.UnaryUnaryCall class."""
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
class TestUnaryUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_call_ok(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.done())
response = await call
self.assertTrue(call.done())
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Response is cached at call object level, reentrance
# returns again the same response
response_retry = await call
self.assertIs(response, response_retry)
async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
call = hi(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
# Exception is cached at call object level, reentrance
# returns again the same exception
with self.assertRaises(grpc.RpcError) as exception_context_retry:
await call
self.assertIs(exception_context.exception,
exception_context_retry.exception)
async def test_call_code_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_details_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details())
async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
# The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
async def test_cancel_unary_unary_in_task(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
call = stub.EmptyCall(messages_pb2.SimpleRequest())
async def another_coro():
coro_started.set()
await call
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
class TestUnaryStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
async def test_multiple_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
async def test_early_cancel_unary_stream(self):
"""Test cancellation before receiving messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
async def test_late_cancel_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our
# expectation here is do not crash :)
call.cancel()
self.assertIn(await call.code(),
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self):
"""Test calling read after received all messages fails."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertIs(await call.read(), aio.EOF)
async def test_unary_stream_async_generator(self):
"""Sunny day test case for unary_stream."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
async for response in call:
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
await call.read()
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
async def test_cancel_unary_stream_in_task_using_async_for(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
async for _ in call:
pass
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
def test_call_credentials(self):
class DummyAuth(grpc.AuthMetadataPlugin):
def __call__(self, context, callback):
signature = context.method_name[::-1]
callback((("test", signature),), None)
async def coro():
server_target, _ = await start_test_server(secure=False) # pylint: disable=unused-variable
async with aio.insecure_channel(server_target) as channel:
hi = channel.unary_unary('/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.
SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.
SimpleResponse.FromString)
call_credentials = grpc.metadata_call_credentials(DummyAuth())
call = hi(messages_pb2.SimpleRequest(),
credentials=call_credentials)
response = await call
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.loop.run_until_complete(coro())
class TestStreamUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_early_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.CancelledError):
await call
async def test_write_after_done_writing(self):
call = self._stub.StreamingInputCall()
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
# Sends out requests
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
# Should be no-op
await call.done_writing()
with self.assertRaises(asyncio.InvalidStateError):
await call.write(messages_pb2.StreamingInputCallRequest())
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.StreamingInputCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
with self.assertRaises(asyncio.CancelledError):
await call
await request_iterator_received_the_exception.wait()
# No failures in the cancel later task!
await cancel_later_task
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_cancel(self):
# Invokes the actual RPC
call = self._stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_pending_read(self):
call = self._stub.FullDuplexCall()
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_with_ongoing_read(self):
call = self._stub.FullDuplexCall()
coro_started = asyncio.Event()
async def read_coro():
coro_started.set()
await call.read()
read_task = self.loop.create_task(read_coro())
await coro_started.wait()
self.assertFalse(read_task.done())
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_early_cancel(self):
call = self._stub.FullDuplexCall()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_cancel_after_done_writing(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
# Cancels the RPC
self.assertFalse(call.done())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
async def test_late_cancel(self):
call = self._stub.FullDuplexCall()
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Cancels the RPC
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
# Status is still OK
self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_async_generator(self):
async def request_generator():
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_too_many_reads(self):
async def request_generator():
for _ in range(_NUM_STREAM_RESPONSES):
yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
call = self._stub.FullDuplexCall(request_generator())
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# After the RPC finished, the read should also produce EOF
self.assertIs(await call.read(), aio.EOF)
async def test_read_write_after_done_writing(self):
call = self._stub.FullDuplexCall()
# Writes two requests, and pending two requests
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
await call.done_writing()
# Further write should fail
with self.assertRaises(asyncio.InvalidStateError):
await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
# But read should be unaffected
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_error_in_async_generator(self):
# Server will pause between responses
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# We expect the request iterator to receive the exception
request_iterator_received_the_exception = asyncio.Event()
async def request_iterator():
with self.assertRaises(asyncio.CancelledError):
for _ in range(_NUM_STREAM_RESPONSES):
yield request
await asyncio.sleep(test_constants.SHORT_TIMEOUT)
request_iterator_received_the_exception.set()
call = self._stub.FullDuplexCall(request_iterator())
# Cancel the RPC after at least one response
async def cancel_later():
await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
call.cancel()
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await request_iterator_received_the_exception.wait()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
# No failures in the cancel later task!
await cancel_later_task
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)