Implement stream-unary and stream-stream RPC
* Includes both client-side and server-side
* Adding many tests in multiple files
* Introduces EOF as stream terminator
* Fixing crashes from Core in many ways
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
index 64f89bb..468a8f4 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
@@ -23,18 +23,18 @@
_UNKNOWN_CANCELLATION_DETAILS = 'RPC cancelled for unknown reason.'
-cdef class _AioCall:
+cdef class _AioCall(GrpcCallWrapper):
def __cinit__(self,
AioChannel channel,
object deadline,
bytes method,
- CallCredentials credentials):
+ CallCredentials call_credentials):
self.call = NULL
self._channel = channel
self._references = []
self._loop = asyncio.get_event_loop()
- self._create_grpc_call(deadline, method, credentials)
+ self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False
def __dealloc__(self):
@@ -196,9 +196,25 @@
self,
self._loop
)
- return received_message
+ if received_message:
+ return received_message
+ else:
+ return EOF
- async def unary_stream(self,
+ async def send_serialized_message(self, bytes message):
+ """Sends one single raw message in bytes."""
+ await _send_message(self,
+ message,
+ True,
+ self._loop)
+
+ async def send_receive_close(self):
+ """Half close the RPC on the client-side."""
+ cdef SendCloseFromClientOperation op = SendCloseFromClientOperation(_EMPTY_FLAGS)
+ cdef tuple ops = (op,)
+ await execute_batch(self, ops, self._loop)
+
+ async def initiate_unary_stream(self,
bytes request,
object initial_metadata_observer,
object status_observer):
@@ -233,3 +249,80 @@
await _receive_initial_metadata(self,
self._loop),
)
+
+ async def stream_unary(self,
+ tuple metadata,
+ object metadata_sent_observer,
+ object initial_metadata_observer,
+ object status_observer):
+ """Actual implementation of the complete unary-stream call.
+
+ Needs to pay extra attention to the raise mechanism. If we want to
+ propagate the final status exception, then we have to raise it.
+ Othersize, it would end normally and raise `StopAsyncIteration()`.
+ """
+ # Sends out initial_metadata ASAP.
+ await _send_initial_metadata(self,
+ metadata,
+ self._loop)
+ # Notify upper level that sending messages are allowed now.
+ metadata_sent_observer()
+
+ # Receives initial metadata.
+ initial_metadata_observer(
+ await _receive_initial_metadata(self,
+ self._loop),
+ )
+
+ cdef tuple inbound_ops
+ cdef ReceiveMessageOperation receive_message_op = ReceiveMessageOperation(_EMPTY_FLAGS)
+ cdef ReceiveStatusOnClientOperation receive_status_on_client_op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
+ inbound_ops = (receive_message_op, receive_status_on_client_op)
+
+ # Executes all operations in one batch.
+ await execute_batch(self,
+ inbound_ops,
+ self._loop)
+
+ status = AioRpcStatus(
+ receive_status_on_client_op.code(),
+ receive_status_on_client_op.details(),
+ receive_status_on_client_op.trailing_metadata(),
+ receive_status_on_client_op.error_string(),
+ )
+ # Reports the final status of the RPC to Python layer. The observer
+ # pattern is used here to unify unary and streaming code path.
+ status_observer(status)
+
+ if status.code() == StatusCode.ok:
+ return receive_message_op.message()
+ else:
+ return None
+
+ async def initiate_stream_stream(self,
+ tuple metadata,
+ object metadata_sent_observer,
+ object initial_metadata_observer,
+ object status_observer):
+ """Actual implementation of the complete stream-stream call.
+
+ Needs to pay extra attention to the raise mechanism. If we want to
+ propagate the final status exception, then we have to raise it.
+ Othersize, it would end normally and raise `StopAsyncIteration()`.
+ """
+ # Peer may prematurely end this RPC at any point. We need a corutine
+ # that watches if the server sends the final status.
+ self._loop.create_task(self._handle_status_once_received(status_observer))
+
+ # Sends out initial_metadata ASAP.
+ await _send_initial_metadata(self,
+ metadata,
+ self._loop)
+ # Notify upper level that sending messages are allowed now.
+ metadata_sent_observer()
+
+ # Receives initial metadata.
+ initial_metadata_observer(
+ await _receive_initial_metadata(self,
+ self._loop),
+ )
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
index 4022c89..c938f55 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
@@ -96,7 +96,7 @@
def call(self,
bytes method,
object deadline,
- CallCredentials credentials):
+ object python_call_credentials):
"""Assembles a Cython Call object.
Returns:
@@ -105,5 +105,12 @@
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
- cdef _AioCall call = _AioCall(self, deadline, method, credentials)
+
+ cdef CallCredentials cython_call_credentials
+ if python_call_credentials is not None:
+ cython_call_credentials = python_call_credentials._credentials
+ else:
+ cython_call_credentials = None
+
+ cdef _AioCall call = _AioCall(self, deadline, method, cython_call_credentials)
return call
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
index fbb6598..2fb5f04 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
@@ -33,3 +33,24 @@
return serializer(message)
else:
return message
+
+
+class _EOF:
+
+ def __bool__(self):
+ return False
+
+ def __len__(self):
+ return 0
+
+ def _repr(self) -> str:
+ return '<grpc.aio.EOF>'
+
+ def __repr__(self) -> str:
+ return self._repr()
+
+ def __str__(self) -> str:
+ return self._repr()
+
+
+EOF = _EOF()
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
index b8ae832..15f6bba 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
@@ -21,6 +21,10 @@
cdef grpc_call_details details
cdef grpc_metadata_array request_metadata
cdef AioServer server
+ # NOTE(lidiz) Under certain corner case, receiving the client close
+ # operation won't immediately fail ongoing RECV_MESSAGE operations. Here I
+ # added a flag to workaround this unexpected behavior.
+ cdef bint client_closed
cdef object abort_exception
cdef bint metadata_sent
cdef bint status_sent
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
index 5e01084..b8c635c 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
@@ -20,7 +20,8 @@
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0
-
+# TODO(lidiz) Use a designated value other than None.
+cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef class _HandlerCallDetails:
def __cinit__(self, str method, tuple invocation_metadata):
@@ -35,6 +36,7 @@
self.server = server
grpc_metadata_array_init(&self.request_metadata)
grpc_call_details_init(&self.details)
+ self.client_closed = False
self.abort_exception = None
self.metadata_sent = False
self.status_sent = False
@@ -83,13 +85,23 @@
self._loop = loop
async def read(self):
+ cdef bytes raw_message
+ if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
+ raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
- cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
- return deserialize(self._request_deserializer,
- raw_message)
+ if self._rpc_state.client_closed:
+ return EOF
+ raw_message = await _receive_message(self._rpc_state, self._loop)
+ if raw_message is None:
+ return EOF
+ else:
+ return deserialize(self._request_deserializer,
+ raw_message)
async def write(self, object message):
+ if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
+ raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
await _send_message(self._rpc_state,
@@ -102,6 +114,8 @@
async def send_initial_metadata(self, tuple metadata):
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
+ elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
+ raise RuntimeError(_SERVER_STOPPED_DETAILS)
elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
@@ -145,27 +159,23 @@
return None
-async def _handle_unary_unary_rpc(object method_handler,
- RPCState rpc_state,
- object loop):
- # Receives request message
- cdef bytes request_raw = await _receive_message(rpc_state, loop)
-
- # Deserializes the request message
- cdef object request_message = deserialize(
- method_handler.request_deserializer,
- request_raw,
- )
-
+async def _finish_handler_with_unary_response(RPCState rpc_state,
+ object unary_handler,
+ object request,
+ _ServicerContext servicer_context,
+ object response_serializer,
+ object loop):
+ """Finishes server method handler with a single response.
+
+ This function executes the application handler, and handles response
+ sending, as well as errors. It is shared between unary-unary and
+ stream-unary handlers.
+ """
# Executes application logic
- cdef object response_message = await method_handler.unary_unary(
- request_message,
- _ServicerContext(
- rpc_state,
- None,
- None,
- loop,
- ),
+
+ cdef object response_message = await unary_handler(
+ request,
+ servicer_context,
)
# Raises exception if aborted
@@ -173,50 +183,50 @@
# Serializes the response message
cdef bytes response_raw = serialize(
- method_handler.response_serializer,
+ response_serializer,
response_message,
)
- # Sends response message
- cdef tuple send_ops = (
- SendStatusFromServerOperation(
- tuple(),
+ # Assembles the batch operations
+ cdef Operation send_status_op = SendStatusFromServerOperation(
+ tuple(),
StatusCode.ok,
b'',
_EMPTY_FLAGS,
- ),
- SendInitialMetadataOperation(None, _EMPTY_FLAGS),
- SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
+ cdef tuple finish_ops
+ if not rpc_state.metadata_sent:
+ finish_ops = (
+ send_status_op,
+ SendInitialMetadataOperation(None, _EMPTY_FLAGS),
+ SendMessageOperation(response_raw, _EMPTY_FLAGS),
+ )
+ else:
+ finish_ops = (
+ send_status_op,
+ SendMessageOperation(response_raw, _EMPTY_FLAGS),
+ )
rpc_state.status_sent = True
- await execute_batch(rpc_state, send_ops, loop)
+ await execute_batch(rpc_state, finish_ops, loop)
-async def _handle_unary_stream_rpc(object method_handler,
- RPCState rpc_state,
- object loop):
- # Receives request message
- cdef bytes request_raw = await _receive_message(rpc_state, loop)
-
- # Deserializes the request message
- cdef object request_message = deserialize(
- method_handler.request_deserializer,
- request_raw,
- )
-
- cdef _ServicerContext servicer_context = _ServicerContext(
- rpc_state,
- method_handler.request_deserializer,
- method_handler.response_serializer,
- loop,
- )
-
+async def _finish_handler_with_stream_responses(RPCState rpc_state,
+ object stream_handler,
+ object request,
+ _ServicerContext servicer_context,
+ object loop):
+ """Finishes server method handler with multiple responses.
+
+ This function executes the application handler, and handles response
+ sending, as well as errors. It is shared between unary-stream and
+ stream-stream handlers.
+ """
cdef object async_response_generator
cdef object response_message
- if inspect.iscoroutinefunction(method_handler.unary_stream):
+ if inspect.iscoroutinefunction(stream_handler):
# The handler uses reader / writer API, returns None.
- await method_handler.unary_stream(
- request_message,
+ await stream_handler(
+ request,
servicer_context,
)
@@ -224,8 +234,8 @@
_raise_if_aborted(rpc_state)
else:
# The handler uses async generator API
- async_response_generator = method_handler.unary_stream(
- request_message,
+ async_response_generator = stream_handler(
+ request,
servicer_context,
)
@@ -250,9 +260,132 @@
_EMPTY_FLAGS,
)
- cdef tuple ops = (op,)
+ cdef tuple finish_ops = (op,)
+ if not rpc_state.metadata_sent:
+ finish_ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAGS))
rpc_state.status_sent = True
- await execute_batch(rpc_state, ops, loop)
+ await execute_batch(rpc_state, finish_ops, loop)
+
+
+async def _handle_unary_unary_rpc(object method_handler,
+ RPCState rpc_state,
+ object loop):
+ # Receives request message
+ cdef bytes request_raw = await _receive_message(rpc_state, loop)
+
+ # Deserializes the request message
+ cdef object request_message = deserialize(
+ method_handler.request_deserializer,
+ request_raw,
+ )
+
+ # Creates a dedecated ServicerContext
+ cdef _ServicerContext servicer_context = _ServicerContext(
+ rpc_state,
+ None,
+ None,
+ loop,
+ )
+
+ # Finishes the application handler
+ await _finish_handler_with_unary_response(
+ rpc_state,
+ method_handler.unary_unary,
+ request_message,
+ servicer_context,
+ method_handler.response_serializer,
+ loop
+ )
+
+
+async def _handle_unary_stream_rpc(object method_handler,
+ RPCState rpc_state,
+ object loop):
+ # Receives request message
+ cdef bytes request_raw = await _receive_message(rpc_state, loop)
+
+ # Deserializes the request message
+ cdef object request_message = deserialize(
+ method_handler.request_deserializer,
+ request_raw,
+ )
+
+ # Creates a dedecated ServicerContext
+ cdef _ServicerContext servicer_context = _ServicerContext(
+ rpc_state,
+ method_handler.request_deserializer,
+ method_handler.response_serializer,
+ loop,
+ )
+
+ # Finishes the application handler
+ await _finish_handler_with_stream_responses(
+ rpc_state,
+ method_handler.unary_stream,
+ request_message,
+ servicer_context,
+ loop,
+ )
+
+
+async def _message_receiver(_ServicerContext servicer_context):
+ """Bridge between the async generator API and the reader-writer API."""
+ cdef object message
+ while True:
+ message = await servicer_context.read()
+ if message is not EOF:
+ yield message
+ else:
+ break
+
+
+async def _handle_stream_unary_rpc(object method_handler,
+ RPCState rpc_state,
+ object loop):
+ # Creates a dedecated ServicerContext
+ cdef _ServicerContext servicer_context = _ServicerContext(
+ rpc_state,
+ method_handler.request_deserializer,
+ None,
+ loop,
+ )
+
+ # Prepares the request generator
+ cdef object request_async_iterator = _message_receiver(servicer_context)
+
+ # Finishes the application handler
+ await _finish_handler_with_unary_response(
+ rpc_state,
+ method_handler.stream_unary,
+ request_async_iterator,
+ servicer_context,
+ method_handler.response_serializer,
+ loop
+ )
+
+
+async def _handle_stream_stream_rpc(object method_handler,
+ RPCState rpc_state,
+ object loop):
+ # Creates a dedecated ServicerContext
+ cdef _ServicerContext servicer_context = _ServicerContext(
+ rpc_state,
+ method_handler.request_deserializer,
+ method_handler.response_serializer,
+ loop,
+ )
+
+ # Prepares the request generator
+ cdef object request_async_iterator = _message_receiver(servicer_context)
+
+ # Finishes the application handler
+ await _finish_handler_with_stream_responses(
+ rpc_state,
+ method_handler.stream_stream,
+ request_async_iterator,
+ servicer_context,
+ loop,
+ )
async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
@@ -293,6 +426,7 @@
# Awaits cancellation from peer.
await execute_batch(rpc_state, ops, loop)
+ rpc_state.client_closed = True
if op.cancelled() and not rpc_task.done():
# Injects `CancelledError` to halt the RPC coroutine
rpc_task.cancel()
@@ -311,8 +445,9 @@
async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
+ cdef object method_handler
# Finds the method handler (application logic)
- cdef object method_handler = _find_method_handler(
+ method_handler = _find_method_handler(
rpc_state.method().decode(),
generic_handlers,
)
@@ -328,20 +463,33 @@
)
return
- # TODO(lidiz) extend to all 4 types of RPC
- if not method_handler.request_streaming and method_handler.response_streaming:
- try:
- await _handle_unary_stream_rpc(method_handler,
+ # Handles unary-unary case
+ if not method_handler.request_streaming and not method_handler.response_streaming:
+ await _handle_unary_unary_rpc(method_handler,
rpc_state,
loop)
- except Exception as e:
- raise
- elif not method_handler.request_streaming and not method_handler.response_streaming:
- await _handle_unary_unary_rpc(method_handler,
- rpc_state,
- loop)
- else:
- raise NotImplementedError()
+ return
+
+ # Handles unary-stream case
+ if not method_handler.request_streaming and method_handler.response_streaming:
+ await _handle_unary_stream_rpc(method_handler,
+ rpc_state,
+ loop)
+ return
+
+ # Handles stream-unary case
+ if method_handler.request_streaming and not method_handler.response_streaming:
+ await _handle_stream_unary_rpc(method_handler,
+ rpc_state,
+ loop)
+ return
+
+ # Handles stream-stream case
+ if method_handler.request_streaming and method_handler.response_streaming:
+ await _handle_stream_stream_rpc(method_handler,
+ rpc_state,
+ loop)
+ return
class _RequestCallError(Exception): pass
diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py
index 2f162d5..8dc52b8 100644
--- a/src/python/grpcio/grpc/experimental/aio/__init__.py
+++ b/src/python/grpcio/grpc/experimental/aio/__init__.py
@@ -22,7 +22,7 @@
import six
import grpc
-from grpc._cython.cygrpc import init_grpc_aio, AbortError
+from grpc._cython.cygrpc import EOF, AbortError, init_grpc_aio
from ._base_call import Call, RpcContext, UnaryStreamCall, UnaryUnaryCall
from ._call import AioRpcError
@@ -86,5 +86,5 @@
'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
- 'insecure_channel', 'secure_channel', 'server', 'Server',
+ 'insecure_channel', 'server', 'Server', 'EOF', 'secure_channel',
'AbortError')
diff --git a/src/python/grpcio/grpc/experimental/aio/_base_call.py b/src/python/grpcio/grpc/experimental/aio/_base_call.py
index de63729..bdd6902 100644
--- a/src/python/grpcio/grpc/experimental/aio/_base_call.py
+++ b/src/python/grpcio/grpc/experimental/aio/_base_call.py
@@ -19,11 +19,12 @@
"""
from abc import ABCMeta, abstractmethod
-from typing import Any, AsyncIterable, Awaitable, Callable, Generic, Text, Optional
+from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional,
+ Text, Union)
import grpc
-from ._typing import MetadataType, RequestType, ResponseType
+from ._typing import EOFType, MetadataType, RequestType, ResponseType
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@@ -146,14 +147,85 @@
"""
@abstractmethod
- async def read(self) -> ResponseType:
- """Reads one message from the RPC.
+ async def read(self) -> Union[EOFType, ResponseType]:
+ """Reads one message from the stream.
- For each streaming RPC, concurrent reads in multiple coroutines are not
- allowed. If you want to perform read in multiple coroutines, you needs
- synchronization. So, you can start another read after current read is
- finished.
+ Read operations must be serialized when called from multiple
+ coroutines.
Returns:
- A response message of the RPC.
+ A response message, or an `grpc.aio.EOF` to indicate the end of the
+ stream.
+ """
+
+
+class StreamUnaryCall(Generic[RequestType, ResponseType],
+ Call,
+ metaclass=ABCMeta):
+
+ @abstractmethod
+ async def write(self, request: RequestType) -> None:
+ """Writes one message to the stream.
+
+ Raises:
+ An RpcError exception if the write failed.
+ """
+
+ @abstractmethod
+ async def done_writing(self) -> None:
+ """Notifies server that the client is done sending messages.
+
+ After done_writing is called, any additional invocation to the write
+ function will fail. This function is idempotent.
+ """
+
+ @abstractmethod
+ def __await__(self) -> Awaitable[ResponseType]:
+ """Await the response message to be ready.
+
+ Returns:
+ The response message of the stream.
+ """
+
+
+class StreamStreamCall(Generic[RequestType, ResponseType],
+ Call,
+ metaclass=ABCMeta):
+
+ @abstractmethod
+ def __aiter__(self) -> AsyncIterable[ResponseType]:
+ """Returns the async iterable representation that yields messages.
+
+ Under the hood, it is calling the "read" method.
+
+ Returns:
+ An async iterable object that yields messages.
+ """
+
+ @abstractmethod
+ async def read(self) -> Union[EOFType, ResponseType]:
+ """Reads one message from the stream.
+
+ Read operations must be serialized when called from multiple
+ coroutines.
+
+ Returns:
+ A response message, or an `grpc.aio.EOF` to indicate the end of the
+ stream.
+ """
+
+ @abstractmethod
+ async def write(self, request: RequestType) -> None:
+ """Writes one message to the stream.
+
+ Raises:
+ An RpcError exception if the write failed.
+ """
+
+ @abstractmethod
+ async def done_writing(self) -> None:
+ """Notifies server that the client is done sending messages.
+
+ After done_writing is called, any additional invocation to the write
+ function will fail. This function is idempotent.
"""
diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py
index 3080d23..25ea89c 100644
--- a/src/python/grpcio/grpc/experimental/aio/_call.py
+++ b/src/python/grpcio/grpc/experimental/aio/_call.py
@@ -29,6 +29,7 @@
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
+_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
@@ -146,31 +147,48 @@
class Call(_base_call.Call):
+ """Base implementation of client RPC Call object.
+
+ Implements logic around final status, metadata and cancellation.
+ """
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
+ _cython_call: cygrpc._AioCall
- def __init__(self) -> None:
+ def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
+ self._cython_call = cython_call
- def cancel(self) -> bool:
- """Placeholder cancellation method.
-
- The implementation of this method needs to pass the cancellation reason
- into self._cancellation, using `set_result` instead of
- `set_exception`.
- """
- raise NotImplementedError()
+ def __del__(self) -> None:
+ if not self._status.done():
+ self._cancel(
+ cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
+ _GC_CANCELLATION_DETAILS, None, None))
def cancelled(self) -> bool:
return self._code == grpc.StatusCode.CANCELLED
+ def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
+ """Forwards the application cancellation reasoning."""
+ if not self._status.done():
+ self._set_status(status)
+ self._cython_call.cancel(status)
+ return True
+ else:
+ return False
+
+ def cancel(self) -> bool:
+ return self._cancel(
+ cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
+ _LOCAL_CANCELLATION_DETAILS, None, None))
+
def done(self) -> bool:
return self._status.done()
@@ -247,6 +265,7 @@
return self._repr()
+# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
@@ -254,37 +273,29 @@
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
- _channel: cygrpc.AioChannel
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
- _cython_call: cygrpc._AioCall
- def __init__( # pylint: disable=R0913
- self, request: RequestType, deadline: Optional[float],
- credentials: Optional[grpc.CallCredentials],
- channel: cygrpc.AioChannel, method: bytes,
- request_serializer: SerializingFunction,
- response_deserializer: DeserializingFunction) -> None:
- super().__init__()
+ # pylint: disable=too-many-arguments
+ def __init__(self, request: RequestType, deadline: Optional[float],
+ credentials: Optional[grpc.CallCredentials],
+ channel: cygrpc.AioChannel, method: bytes,
+ request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction) -> None:
+ channel.call(method, deadline, credentials)
+ super().__init__(channel.call(method, deadline, credentials))
self._request = request
- self._channel = channel
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
-
- if credentials is not None:
- grpc_credentials = credentials._credentials
- else:
- grpc_credentials = None
- self._cython_call = self._channel.call(method, deadline,
- grpc_credentials)
self._call = self._loop.create_task(self._invoke())
- def __del__(self) -> None:
- if not self._call.done():
- self._cancel(
- cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
- _GC_CANCELLATION_DETAILS, None, None))
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._call.cancel()
+ return True
+ else:
+ return False
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
@@ -300,7 +311,7 @@
self._set_status,
)
except asyncio.CancelledError:
- if self._code != grpc.StatusCode.CANCELLED:
+ if not self.cancelled():
self.cancel()
# Raises here if RPC failed or cancelled
@@ -309,21 +320,6 @@
return _common.deserialize(serialized_response,
self._response_deserializer)
- def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
- """Forwards the application cancellation reasoning."""
- if not self._status.done():
- self._set_status(status)
- self._cython_call.cancel(status)
- self._call.cancel()
- return True
- else:
- return False
-
- def cancel(self) -> bool:
- return self._cancel(
- cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
- _LOCAL_CANCELLATION_DETAILS, None, None))
-
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
@@ -339,6 +335,7 @@
return response
+# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
@@ -346,91 +343,53 @@
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
- _channel: cygrpc.AioChannel
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
- _cython_call: cygrpc._AioCall
_send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
- def __init__( # pylint: disable=R0913
- self, request: RequestType, deadline: Optional[float],
- credentials: Optional[grpc.CallCredentials],
- channel: cygrpc.AioChannel, method: bytes,
- request_serializer: SerializingFunction,
- response_deserializer: DeserializingFunction) -> None:
- super().__init__()
+ # pylint: disable=too-many-arguments
+ def __init__(self, request: RequestType, deadline: Optional[float],
+ credentials: Optional[grpc.CallCredentials],
+ channel: cygrpc.AioChannel, method: bytes,
+ request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction) -> None:
+ super().__init__(channel.call(method, deadline, credentials))
self._request = request
- self._channel = channel
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = self._loop.create_task(
self._send_unary_request())
- self._message_aiter = self._fetch_stream_responses()
+ self._message_aiter = None
- if credentials is not None:
- grpc_credentials = credentials._credentials
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._send_unary_request_task.cancel()
+ return True
else:
- grpc_credentials = None
-
- self._cython_call = self._channel.call(method, deadline,
- grpc_credentials)
-
- def __del__(self) -> None:
- if not self._status.done():
- self._cancel(
- cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
- _GC_CANCELLATION_DETAILS, None, None))
+ return False
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
- await self._cython_call.unary_stream(serialized_request,
- self._set_initial_metadata,
- self._set_status)
+ await self._cython_call.initiate_unary_stream(
+ serialized_request, self._set_initial_metadata,
+ self._set_status)
except asyncio.CancelledError:
- if self._code != grpc.StatusCode.CANCELLED:
+ if not self.cancelled():
self.cancel()
raise
async def _fetch_stream_responses(self) -> ResponseType:
- await self._send_unary_request_task
message = await self._read()
- while message:
+ while message is not cygrpc.EOF:
yield message
message = await self._read()
- def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
- """Forwards the application cancellation reasoning.
-
- Async generator will receive an exception. The cancellation will go
- deep down into Core, and then propagates backup as the
- `cygrpc.AioRpcStatus` exception.
-
- So, under race condition, e.g. the server sent out final state headers
- and the client calling "cancel" at the same time, this method respects
- the winner in Core.
- """
- if not self._status.done():
- self._set_status(status)
- self._cython_call.cancel(status)
-
- if not self._send_unary_request_task.done():
- # Injects CancelledError to the Task. The exception will
- # propagate to _fetch_stream_responses as well, if the sending
- # is not done.
- self._send_unary_request_task.cancel()
- return True
- else:
- return False
-
- def cancel(self) -> bool:
- return self._cancel(
- cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
- _LOCAL_CANCELLATION_DETAILS, None, None))
-
def __aiter__(self) -> AsyncIterable[ResponseType]:
+ if self._message_aiter is None:
+ self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
async def _read(self) -> ResponseType:
@@ -441,12 +400,12 @@
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
- if self._code != grpc.StatusCode.CANCELLED:
+ if not self.cancelled():
self.cancel()
- raise
+ await self._raise_for_status()
- if raw_response is None:
- return None
+ if raw_response is cygrpc.EOF:
+ return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer)
@@ -454,14 +413,288 @@
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_for_status()
- raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
+ return cygrpc.EOF
response_message = await self._read()
- if response_message is None:
+ if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
- # If no exception raised, there is something wrong internally.
- assert False, 'Read operation failed with StatusCode.OK'
+ return response_message
+
+
+# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
+# pylint: disable=abstract-method
+class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
+ """Object for managing stream-unary RPC calls.
+
+ Returned when an instance of `StreamUnaryMultiCallable` object is called.
+ """
+ _metadata: MetadataType
+ _request_serializer: SerializingFunction
+ _response_deserializer: DeserializingFunction
+
+ _metadata_sent: asyncio.Event
+ _done_writing: bool
+ _call_finisher: asyncio.Task
+ _async_request_poller: asyncio.Task
+
+ # pylint: disable=too-many-arguments
+ def __init__(self,
+ request_async_iterator: Optional[AsyncIterable[RequestType]],
+ deadline: Optional[float],
+ credentials: Optional[grpc.CallCredentials],
+ channel: cygrpc.AioChannel, method: bytes,
+ request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction) -> None:
+ super().__init__(channel.call(method, deadline, credentials))
+ self._metadata = _EMPTY_METADATA
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ self._metadata_sent = asyncio.Event(loop=self._loop)
+ self._done_writing = False
+
+ self._call_finisher = self._loop.create_task(self._conduct_rpc())
+
+ # If user passes in an async iterator, create a consumer Task.
+ if request_async_iterator is not None:
+ self._async_request_poller = self._loop.create_task(
+ self._consume_request_iterator(request_async_iterator))
else:
- return response_message
+ self._async_request_poller = None
+
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._call_finisher.cancel()
+ if self._async_request_poller is not None:
+ self._async_request_poller.cancel()
+ return True
+ else:
+ return False
+
+ def _metadata_sent_observer(self):
+ self._metadata_sent.set()
+
+ async def _conduct_rpc(self) -> ResponseType:
+ try:
+ serialized_response = await self._cython_call.stream_unary(
+ self._metadata,
+ self._metadata_sent_observer,
+ self._set_initial_metadata,
+ self._set_status,
+ )
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+
+ # Raises RpcError if the RPC failed or cancelled
+ await self._raise_for_status()
+
+ return _common.deserialize(serialized_response,
+ self._response_deserializer)
+
+ async def _consume_request_iterator(
+ self, request_async_iterator: AsyncIterable[RequestType]) -> None:
+ async for request in request_async_iterator:
+ await self.write(request)
+ await self.done_writing()
+
+ def __await__(self) -> ResponseType:
+ """Wait till the ongoing RPC request finishes."""
+ try:
+ response = yield from self._call_finisher
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ raise
+ return response
+
+ async def write(self, request: RequestType) -> None:
+ if self._status.done():
+ raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
+ if self._done_writing:
+ raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
+ if not self._metadata_sent.is_set():
+ await self._metadata_sent.wait()
+
+ serialized_request = _common.serialize(request,
+ self._request_serializer)
+
+ try:
+ await self._cython_call.send_serialized_message(serialized_request)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ await self._raise_for_status()
+
+ async def done_writing(self) -> None:
+ """Implementation of done_writing is idempotent."""
+ if self._status.done():
+ # If the RPC is finished, do nothing.
+ return
+ if not self._done_writing:
+ # If the done writing is not sent before, try to send it.
+ self._done_writing = True
+ try:
+ await self._cython_call.send_receive_close()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ await self._raise_for_status()
+
+
+# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
+# pylint: disable=abstract-method
+class StreamStreamCall(Call, _base_call.StreamStreamCall):
+ """Object for managing stream-stream RPC calls.
+
+ Returned when an instance of `StreamStreamMultiCallable` object is called.
+ """
+ _metadata: MetadataType
+ _request_serializer: SerializingFunction
+ _response_deserializer: DeserializingFunction
+
+ _metadata_sent: asyncio.Event
+ _done_writing: bool
+ _initializer: asyncio.Task
+ _async_request_poller: asyncio.Task
+ _message_aiter: AsyncIterable[ResponseType]
+
+ # pylint: disable=too-many-arguments
+ def __init__(self,
+ request_async_iterator: Optional[AsyncIterable[RequestType]],
+ deadline: Optional[float],
+ credentials: Optional[grpc.CallCredentials],
+ channel: cygrpc.AioChannel, method: bytes,
+ request_serializer: SerializingFunction,
+ response_deserializer: DeserializingFunction) -> None:
+ super().__init__(channel.call(method, deadline, credentials))
+ self._metadata = _EMPTY_METADATA
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ self._metadata_sent = asyncio.Event(loop=self._loop)
+ self._done_writing = False
+
+ self._initializer = self._loop.create_task(self._prepare_rpc())
+
+ # If user passes in an async iterator, create a consumer coroutine.
+ if request_async_iterator is not None:
+ self._async_request_poller = self._loop.create_task(
+ self._consume_request_iterator(request_async_iterator))
+ else:
+ self._async_request_poller = None
+ self._message_aiter = None
+
+ def cancel(self) -> bool:
+ if super().cancel():
+ self._initializer.cancel()
+ if self._async_request_poller is not None:
+ self._async_request_poller.cancel()
+ return True
+ else:
+ return False
+
+ def _metadata_sent_observer(self):
+ self._metadata_sent.set()
+
+ async def _prepare_rpc(self):
+ """This method prepares the RPC for receiving/sending messages.
+
+ All other operations around the stream should only happen after the
+ completion of this method.
+ """
+ try:
+ await self._cython_call.initiate_stream_stream(
+ self._metadata,
+ self._metadata_sent_observer,
+ self._set_initial_metadata,
+ self._set_status,
+ )
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ # No need to raise RpcError here, because no one will `await` this task.
+
+ async def _consume_request_iterator(
+ self, request_async_iterator: Optional[AsyncIterable[RequestType]]
+ ) -> None:
+ async for request in request_async_iterator:
+ await self.write(request)
+ await self.done_writing()
+
+ async def write(self, request: RequestType) -> None:
+ if self._status.done():
+ raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
+ if self._done_writing:
+ raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
+ if not self._metadata_sent.is_set():
+ await self._metadata_sent.wait()
+
+ serialized_request = _common.serialize(request,
+ self._request_serializer)
+
+ try:
+ await self._cython_call.send_serialized_message(serialized_request)
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ await self._raise_for_status()
+
+ async def done_writing(self) -> None:
+ """Implementation of done_writing is idempotent."""
+ if self._status.done():
+ # If the RPC is finished, do nothing.
+ return
+ if not self._done_writing:
+ # If the done writing is not sent before, try to send it.
+ self._done_writing = True
+ try:
+ await self._cython_call.send_receive_close()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ await self._raise_for_status()
+
+ async def _fetch_stream_responses(self) -> ResponseType:
+ """The async generator that yields responses from peer."""
+ message = await self._read()
+ while message is not cygrpc.EOF:
+ yield message
+ message = await self._read()
+
+ def __aiter__(self) -> AsyncIterable[ResponseType]:
+ if self._message_aiter is None:
+ self._message_aiter = self._fetch_stream_responses()
+ return self._message_aiter
+
+ async def _read(self) -> ResponseType:
+ # Wait for the setup
+ await self._initializer
+
+ # Reads response message from Core
+ try:
+ raw_response = await self._cython_call.receive_serialized_message()
+ except asyncio.CancelledError:
+ if not self.cancelled():
+ self.cancel()
+ await self._raise_for_status()
+
+ if raw_response is cygrpc.EOF:
+ return cygrpc.EOF
+ else:
+ return _common.deserialize(raw_response,
+ self._response_deserializer)
+
+ async def read(self) -> ResponseType:
+ if self._status.done():
+ await self._raise_for_status()
+ return cygrpc.EOF
+
+ response_message = await self._read()
+
+ if response_message is cygrpc.EOF:
+ # If the read operation failed, Core should explain why.
+ await self._raise_for_status()
+ return response_message
diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py
index 2562f0f..6d4fe91 100644
--- a/src/python/grpcio/grpc/experimental/aio/_channel.py
+++ b/src/python/grpcio/grpc/experimental/aio/_channel.py
@@ -13,14 +13,15 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
-from typing import Any, Optional, Sequence, Text
+from typing import Any, AsyncIterable, Optional, Sequence, Text
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
-from ._call import UnaryStreamCall, UnaryUnaryCall
+from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
+ UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
@@ -28,8 +29,16 @@
from ._utils import _timeout_to_deadline
-class UnaryUnaryMultiCallable:
- """Factory an asynchronous unary-unary RPC stub call from client-side."""
+class _BaseMultiCallable:
+ """Base class of all multi callable objects.
+
+ Handles the initialization logic and stores common attributes.
+ """
+ _loop: asyncio.AbstractEventLoop
+ _channel: cygrpc.AioChannel
+ _method: bytes
+ _request_serializer: SerializingFunction
+ _response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel
_method: bytes
@@ -50,6 +59,10 @@
self._response_deserializer = response_deserializer
self._interceptors = interceptors
+
+class UnaryUnaryMultiCallable(_BaseMultiCallable):
+ """Factory an asynchronous unary-unary RPC stub call from client-side."""
+
def __call__(self,
request: Any,
*,
@@ -114,17 +127,8 @@
)
-class UnaryStreamMultiCallable:
- """Afford invoking a unary-stream RPC from client-side in an asynchronous way."""
-
- def __init__(self, channel: cygrpc.AioChannel, method: bytes,
- request_serializer: SerializingFunction,
- response_deserializer: DeserializingFunction) -> None:
- self._channel = channel
- self._method = method
- self._request_serializer = request_serializer
- self._response_deserializer = response_deserializer
- self._loop = asyncio.get_event_loop()
+class UnaryStreamMultiCallable(_BaseMultiCallable):
+ """Affords invoking a unary-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request: Any,
@@ -176,6 +180,122 @@
)
+class StreamUnaryMultiCallable(_BaseMultiCallable):
+ """Affords invoking a stream-unary RPC from client-side in an asynchronous way."""
+
+ def __call__(self,
+ request_async_iterator: Optional[AsyncIterable[Any]] = None,
+ timeout: Optional[float] = None,
+ metadata: Optional[MetadataType] = None,
+ credentials: Optional[grpc.CallCredentials] = None,
+ wait_for_ready: Optional[bool] = None,
+ compression: Optional[grpc.Compression] = None
+ ) -> _base_call.StreamUnaryCall:
+ """Asynchronously invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional duration of time in seconds to allow
+ for the RPC.
+ metadata: Optional :term:`metadata` to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC. Only valid for
+ secure Channel.
+ wait_for_ready: This is an EXPERIMENTAL argument. An optional
+ flag to enable wait for ready mechanism
+ compression: An element of grpc.compression, e.g.
+ grpc.compression.Gzip. This is an EXPERIMENTAL option.
+
+ Returns:
+ A Call object instance which is an awaitable object.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+
+ if metadata:
+ raise NotImplementedError("TODO: metadata not implemented yet")
+
+ if wait_for_ready:
+ raise NotImplementedError(
+ "TODO: wait_for_ready not implemented yet")
+
+ if compression:
+ raise NotImplementedError("TODO: compression not implemented yet")
+
+ deadline = _timeout_to_deadline(timeout)
+
+ return StreamUnaryCall(
+ request_async_iterator,
+ deadline,
+ credentials,
+ self._channel,
+ self._method,
+ self._request_serializer,
+ self._response_deserializer,
+ )
+
+
+class StreamStreamMultiCallable(_BaseMultiCallable):
+ """Affords invoking a stream-stream RPC from client-side in an asynchronous way."""
+
+ def __call__(self,
+ request_async_iterator: Optional[AsyncIterable[Any]] = None,
+ timeout: Optional[float] = None,
+ metadata: Optional[MetadataType] = None,
+ credentials: Optional[grpc.CallCredentials] = None,
+ wait_for_ready: Optional[bool] = None,
+ compression: Optional[grpc.Compression] = None
+ ) -> _base_call.StreamStreamCall:
+ """Asynchronously invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional duration of time in seconds to allow
+ for the RPC.
+ metadata: Optional :term:`metadata` to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC. Only valid for
+ secure Channel.
+ wait_for_ready: This is an EXPERIMENTAL argument. An optional
+ flag to enable wait for ready mechanism
+ compression: An element of grpc.compression, e.g.
+ grpc.compression.Gzip. This is an EXPERIMENTAL option.
+
+ Returns:
+ A Call object instance which is an awaitable object.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+
+ if metadata:
+ raise NotImplementedError("TODO: metadata not implemented yet")
+
+ if wait_for_ready:
+ raise NotImplementedError(
+ "TODO: wait_for_ready not implemented yet")
+
+ if compression:
+ raise NotImplementedError("TODO: compression not implemented yet")
+
+ deadline = _timeout_to_deadline(timeout)
+
+ return StreamStreamCall(
+ request_async_iterator,
+ deadline,
+ credentials,
+ self._channel,
+ self._method,
+ self._request_serializer,
+ self._response_deserializer,
+ )
+
+
class Channel:
"""Asynchronous Channel implementation.
@@ -301,21 +421,27 @@
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
- response_deserializer)
+ response_deserializer, None)
def stream_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None):
- """Placeholder method for stream-unary calls."""
+ response_deserializer: Optional[DeserializingFunction] = None
+ ) -> StreamUnaryMultiCallable:
+ return StreamUnaryMultiCallable(self._channel, _common.encode(method),
+ request_serializer,
+ response_deserializer, None)
def stream_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
- response_deserializer: Optional[DeserializingFunction] = None):
- """Placeholder method for stream-stream calls."""
+ response_deserializer: Optional[DeserializingFunction] = None
+ ) -> StreamStreamMultiCallable:
+ return StreamStreamMultiCallable(self._channel, _common.encode(method),
+ request_serializer,
+ response_deserializer, None)
async def _close(self):
# TODO: Send cancellation status
diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py
index 7ff8938..6428fb7 100644
--- a/src/python/grpcio/grpc/experimental/aio/_typing.py
+++ b/src/python/grpcio/grpc/experimental/aio/_typing.py
@@ -14,6 +14,7 @@
"""Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Text, Tuple, TypeVar
+from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
@@ -21,3 +22,4 @@
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[Text, AnyStr]]
ChannelArgumentType = Sequence[Tuple[Text, Any]]
+EOFType = type(EOF)
diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json
index 024d5c8..605a088 100644
--- a/src/python/grpcio_tests/tests_aio/tests.json
+++ b/src/python/grpcio_tests/tests_aio/tests.json
@@ -2,6 +2,8 @@
"_sanity._sanity_test.AioSanityTest",
"unit.abort_test.TestAbort",
"unit.aio_rpc_error_test.TestAioRpcError",
+ "unit.call_test.TestStreamStreamCall",
+ "unit.call_test.TestStreamUnaryCall",
"unit.call_test.TestUnaryStreamCall",
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument",
diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py
index d99c46f..ccb9f45 100644
--- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py
+++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py
@@ -26,11 +26,12 @@
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
- async def UnaryCall(self, request, context):
+ async def UnaryCall(self, unused_request, unused_context):
return messages_pb2.SimpleResponse()
async def StreamingOutputCall(
- self, request: messages_pb2.StreamingOutputCallRequest, context):
+ self, request: messages_pb2.StreamingOutputCallRequest,
+ unused_context):
for response_parameters in request.response_parameters:
if response_parameters.interval_us != 0:
await asyncio.sleep(
@@ -44,11 +45,30 @@
# Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by
# the proto file.
-
async def UnaryCallWithSleep(self, request, context):
await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE)
return messages_pb2.SimpleResponse()
+ async def StreamingInputCall(self, request_async_iterator, unused_context):
+ aggregate_size = 0
+ async for request in request_async_iterator:
+ if request.payload is not None and request.payload.body:
+ aggregate_size += len(request.payload.body)
+ return messages_pb2.StreamingInputCallResponse(
+ aggregated_payload_size=aggregate_size)
+
+ async def FullDuplexCall(self, request_async_iterator, unused_context):
+ async for request in request_async_iterator:
+ for response_parameters in request.response_parameters:
+ if response_parameters.interval_us != 0:
+ await asyncio.sleep(
+ datetime.timedelta(microseconds=response_parameters.
+ interval_us).total_seconds())
+ yield messages_pb2.StreamingOutputCallResponse(
+ payload=messages_pb2.Payload(type=request.payload.type,
+ body=b'\x00' *
+ response_parameters.size))
+
async def start_test_server(secure=False):
server = aio.server(options=(('grpc.so_reuseport', 0),))
diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py
index 59db225..209643e 100644
--- a/src/python/grpcio_tests/tests_aio/unit/call_test.py
+++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py
@@ -30,10 +30,10 @@
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
+_REQUEST_PAYLOAD_SIZE = 7
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111'
-
_INFINITE_INTERVAL_US = 2**31 - 1
@@ -286,7 +286,7 @@
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self):
- """Test cancellation after received all messages."""
+ """Test calling read after received all messages fails."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
@@ -306,13 +306,14 @@
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
+ self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
- with self.assertRaises(asyncio.InvalidStateError):
- await call.read()
+ self.assertIs(await call.read(), aio.EOF)
async def test_unary_stream_async_generator(self):
+ """Sunny day test case for unary_stream."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
@@ -426,6 +427,307 @@
self.loop.run_until_complete(coro())
+class TestStreamUnaryCall(AioTestBase):
+
+ async def setUp(self):
+ self._server_target, self._server = await start_test_server()
+ self._channel = aio.insecure_channel(self._server_target)
+ self._stub = test_pb2_grpc.TestServiceStub(self._channel)
+
+ async def tearDown(self):
+ await self._channel.close()
+ await self._server.stop(None)
+
+ async def test_cancel_stream_unary(self):
+ call = self._stub.StreamingInputCall()
+
+ # Prepares the request
+ payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
+ request = messages_pb2.StreamingInputCallRequest(payload=payload)
+
+ # Sends out requests
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await call.write(request)
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+
+ await call.done_writing()
+
+ with self.assertRaises(asyncio.CancelledError):
+ await call
+
+ async def test_early_cancel_stream_unary(self):
+ call = self._stub.StreamingInputCall()
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+
+ with self.assertRaises(asyncio.InvalidStateError):
+ await call.write(messages_pb2.StreamingInputCallRequest())
+
+ # Should be no-op
+ await call.done_writing()
+
+ with self.assertRaises(asyncio.CancelledError):
+ await call
+
+ async def test_write_after_done_writing(self):
+ call = self._stub.StreamingInputCall()
+
+ # Prepares the request
+ payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
+ request = messages_pb2.StreamingInputCallRequest(payload=payload)
+
+ # Sends out requests
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await call.write(request)
+
+ # Should be no-op
+ await call.done_writing()
+
+ with self.assertRaises(asyncio.InvalidStateError):
+ await call.write(messages_pb2.StreamingInputCallRequest())
+
+ response = await call
+ self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
+ self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
+ response.aggregated_payload_size)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_error_in_async_generator(self):
+ # Server will pause between responses
+ request = messages_pb2.StreamingOutputCallRequest()
+ request.response_parameters.append(
+ messages_pb2.ResponseParameters(
+ size=_RESPONSE_PAYLOAD_SIZE,
+ interval_us=_RESPONSE_INTERVAL_US,
+ ))
+
+ # We expect the request iterator to receive the exception
+ request_iterator_received_the_exception = asyncio.Event()
+
+ async def request_iterator():
+ with self.assertRaises(asyncio.CancelledError):
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield request
+ await asyncio.sleep(test_constants.SHORT_TIMEOUT)
+ request_iterator_received_the_exception.set()
+
+ call = self._stub.StreamingInputCall(request_iterator())
+
+ # Cancel the RPC after at least one response
+ async def cancel_later():
+ await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
+ call.cancel()
+
+ cancel_later_task = self.loop.create_task(cancel_later())
+
+ # No exceptions here
+ with self.assertRaises(asyncio.CancelledError):
+ await call
+
+ await request_iterator_received_the_exception.wait()
+
+ # No failures in the cancel later task!
+ await cancel_later_task
+
+
+# Prepares the request that stream in a ping-pong manner.
+_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
+_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
+ messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
+
+
+class TestStreamStreamCall(AioTestBase):
+
+ async def setUp(self):
+ self._server_target, self._server = await start_test_server()
+ self._channel = aio.insecure_channel(self._server_target)
+ self._stub = test_pb2_grpc.TestServiceStub(self._channel)
+
+ async def tearDown(self):
+ await self._channel.close()
+ await self._server.stop(None)
+
+ async def test_cancel(self):
+ # Invokes the actual RPC
+ call = self._stub.FullDuplexCall()
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
+ response = await call.read()
+ self.assertIsInstance(response,
+ messages_pb2.StreamingOutputCallResponse)
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+
+ async def test_cancel_with_pending_read(self):
+ call = self._stub.FullDuplexCall()
+
+ await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+
+ async def test_cancel_with_ongoing_read(self):
+ call = self._stub.FullDuplexCall()
+ coro_started = asyncio.Event()
+
+ async def read_coro():
+ coro_started.set()
+ await call.read()
+
+ read_task = self.loop.create_task(read_coro())
+ await coro_started.wait()
+ self.assertFalse(read_task.done())
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+
+ async def test_early_cancel(self):
+ call = self._stub.FullDuplexCall()
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+
+ async def test_cancel_after_done_writing(self):
+ call = self._stub.FullDuplexCall()
+ await call.done_writing()
+
+ # Cancels the RPC
+ self.assertFalse(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertTrue(call.cancel())
+ self.assertTrue(call.cancelled())
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+
+ async def test_late_cancel(self):
+ call = self._stub.FullDuplexCall()
+ await call.done_writing()
+ self.assertEqual(grpc.StatusCode.OK, await call.code())
+
+ # Cancels the RPC
+ self.assertTrue(call.done())
+ self.assertFalse(call.cancelled())
+ self.assertFalse(call.cancel())
+ self.assertFalse(call.cancelled())
+
+ # Status is still OK
+ self.assertEqual(grpc.StatusCode.OK, await call.code())
+
+ async def test_async_generator(self):
+
+ async def request_generator():
+ yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
+ yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
+
+ call = self._stub.FullDuplexCall(request_generator())
+ async for response in call:
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_too_many_reads(self):
+
+ async def request_generator():
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
+
+ call = self._stub.FullDuplexCall(request_generator())
+ for _ in range(_NUM_STREAM_RESPONSES):
+ response = await call.read()
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+ self.assertIs(await call.read(), aio.EOF)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ # After the RPC finished, the read should also produce EOF
+ self.assertIs(await call.read(), aio.EOF)
+
+ async def test_read_write_after_done_writing(self):
+ call = self._stub.FullDuplexCall()
+
+ # Writes two requests, and pending two requests
+ await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
+ await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
+ await call.done_writing()
+
+ # Further write should fail
+ with self.assertRaises(asyncio.InvalidStateError):
+ await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
+
+ # But read should be unaffected
+ response = await call.read()
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+ response = await call.read()
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_error_in_async_generator(self):
+ # Server will pause between responses
+ request = messages_pb2.StreamingOutputCallRequest()
+ request.response_parameters.append(
+ messages_pb2.ResponseParameters(
+ size=_RESPONSE_PAYLOAD_SIZE,
+ interval_us=_RESPONSE_INTERVAL_US,
+ ))
+
+ # We expect the request iterator to receive the exception
+ request_iterator_received_the_exception = asyncio.Event()
+
+ async def request_iterator():
+ with self.assertRaises(asyncio.CancelledError):
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield request
+ await asyncio.sleep(test_constants.SHORT_TIMEOUT)
+ request_iterator_received_the_exception.set()
+
+ call = self._stub.FullDuplexCall(request_iterator())
+
+ # Cancel the RPC after at least one response
+ async def cancel_later():
+ await asyncio.sleep(test_constants.SHORT_TIMEOUT * 2)
+ call.cancel()
+
+ cancel_later_task = self.loop.create_task(cancel_later())
+
+ # No exceptions here
+ async for response in call:
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ await request_iterator_received_the_exception.wait()
+
+ self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
+ # No failures in the cancel later task!
+ await cancel_later_task
+
+
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py
index 1ab372a..6267862 100644
--- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py
+++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py
@@ -32,6 +32,7 @@
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5
+_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
@@ -121,7 +122,104 @@
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
+ async def test_stream_unary_using_write(self):
+ channel = aio.insecure_channel(self._server_target)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+
+ # Invokes the actual RPC
+ call = stub.StreamingInputCall()
+
+ # Prepares the request
+ payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
+ request = messages_pb2.StreamingInputCallRequest(payload=payload)
+
+ # Sends out requests
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await call.write(request)
+ await call.done_writing()
+
+ # Validates the responses
+ response = await call
+ self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
+ self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
+ response.aggregated_payload_size)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ await channel.close()
+
+ async def test_stream_unary_using_async_gen(self):
+ channel = aio.insecure_channel(self._server_target)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+
+ # Prepares the request
+ payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
+ request = messages_pb2.StreamingInputCallRequest(payload=payload)
+
+ async def gen():
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield request
+
+ # Invokes the actual RPC
+ call = stub.StreamingInputCall(gen())
+
+ # Validates the responses
+ response = await call
+ self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
+ self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
+ response.aggregated_payload_size)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ await channel.close()
+
+ async def test_stream_stream_using_read_write(self):
+ channel = aio.insecure_channel(self._server_target)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+
+ # Invokes the actual RPC
+ call = stub.FullDuplexCall()
+
+ # Prepares the request
+ request = messages_pb2.StreamingOutputCallRequest()
+ request.response_parameters.append(
+ messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await call.write(request)
+ response = await call.read()
+ self.assertIsInstance(response,
+ messages_pb2.StreamingOutputCallResponse)
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ await call.done_writing()
+
+ self.assertEqual(grpc.StatusCode.OK, await call.code())
+ await channel.close()
+
+ async def test_stream_stream_using_async_gen(self):
+ channel = aio.insecure_channel(self._server_target)
+ stub = test_pb2_grpc.TestServiceStub(channel)
+
+ # Prepares the request
+ request = messages_pb2.StreamingOutputCallRequest()
+ request.response_parameters.append(
+ messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
+
+ async def gen():
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield request
+
+ # Invokes the actual RPC
+ call = stub.FullDuplexCall(gen())
+
+ async for response in call:
+ self.assertIsInstance(response,
+ messages_pb2.StreamingOutputCallResponse)
+ self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
+
+ self.assertEqual(grpc.StatusCode.OK, await call.code())
+ await channel.close()
+
if __name__ == '__main__':
- logging.basicConfig(level=logging.WARN)
+ logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py
index 265744b..fff944f 100644
--- a/src/python/grpcio_tests/tests_aio/unit/server_test.py
+++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py
@@ -13,15 +13,16 @@
# limitations under the License.
import asyncio
-import logging
-import unittest
-import time
import gc
+import logging
+import time
+import unittest
import grpc
from grpc.experimental import aio
-from tests_aio.unit._test_base import AioTestBase
+
from tests.unit.framework.common import test_constants
+from tests_aio.unit._test_base import AioTestBase
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
@@ -29,9 +30,16 @@
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'
+_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
+_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
+_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'
+_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
+_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
+_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
+_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5
@@ -39,6 +47,41 @@
def __init__(self):
self._called = asyncio.get_event_loop().create_future()
+ self._routing_table = {
+ _SIMPLE_UNARY_UNARY:
+ grpc.unary_unary_rpc_method_handler(self._unary_unary),
+ _BLOCK_FOREVER:
+ grpc.unary_unary_rpc_method_handler(self._block_forever),
+ _BLOCK_BRIEFLY:
+ grpc.unary_unary_rpc_method_handler(self._block_briefly),
+ _UNARY_STREAM_ASYNC_GEN:
+ grpc.unary_stream_rpc_method_handler(
+ self._unary_stream_async_gen),
+ _UNARY_STREAM_READER_WRITER:
+ grpc.unary_stream_rpc_method_handler(
+ self._unary_stream_reader_writer),
+ _UNARY_STREAM_EVILLY_MIXED:
+ grpc.unary_stream_rpc_method_handler(
+ self._unary_stream_evilly_mixed),
+ _STREAM_UNARY_ASYNC_GEN:
+ grpc.stream_unary_rpc_method_handler(
+ self._stream_unary_async_gen),
+ _STREAM_UNARY_READER_WRITER:
+ grpc.stream_unary_rpc_method_handler(
+ self._stream_unary_reader_writer),
+ _STREAM_UNARY_EVILLY_MIXED:
+ grpc.stream_unary_rpc_method_handler(
+ self._stream_unary_evilly_mixed),
+ _STREAM_STREAM_ASYNC_GEN:
+ grpc.stream_stream_rpc_method_handler(
+ self._stream_stream_async_gen),
+ _STREAM_STREAM_READER_WRITER:
+ grpc.stream_stream_rpc_method_handler(
+ self._stream_stream_reader_writer),
+ _STREAM_STREAM_EVILLY_MIXED:
+ grpc.stream_stream_rpc_method_handler(
+ self._stream_stream_evilly_mixed),
+ }
@staticmethod
async def _unary_unary(unused_request, unused_context):
@@ -64,23 +107,59 @@
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
+ async def _stream_unary_async_gen(self, request_iterator, unused_context):
+ request_count = 0
+ async for request in request_iterator:
+ assert _REQUEST == request
+ request_count += 1
+ assert _NUM_STREAM_REQUESTS == request_count
+ return _RESPONSE
+
+ async def _stream_unary_reader_writer(self, unused_request, context):
+ for _ in range(_NUM_STREAM_REQUESTS):
+ assert _REQUEST == await context.read()
+ return _RESPONSE
+
+ async def _stream_unary_evilly_mixed(self, request_iterator, context):
+ assert _REQUEST == await context.read()
+ request_count = 0
+ async for request in request_iterator:
+ assert _REQUEST == request
+ request_count += 1
+ assert _NUM_STREAM_REQUESTS - 1 == request_count
+ return _RESPONSE
+
+ async def _stream_stream_async_gen(self, request_iterator, unused_context):
+ request_count = 0
+ async for request in request_iterator:
+ assert _REQUEST == request
+ request_count += 1
+ assert _NUM_STREAM_REQUESTS == request_count
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ yield _RESPONSE
+
+ async def _stream_stream_reader_writer(self, unused_request, context):
+ for _ in range(_NUM_STREAM_REQUESTS):
+ assert _REQUEST == await context.read()
+ for _ in range(_NUM_STREAM_RESPONSES):
+ await context.write(_RESPONSE)
+
+ async def _stream_stream_evilly_mixed(self, request_iterator, context):
+ assert _REQUEST == await context.read()
+ request_count = 0
+ async for request in request_iterator:
+ assert _REQUEST == request
+ request_count += 1
+ assert _NUM_STREAM_REQUESTS - 1 == request_count
+
+ yield _RESPONSE
+ for _ in range(_NUM_STREAM_RESPONSES - 1):
+ await context.write(_RESPONSE)
+
def service(self, handler_details):
self._called.set_result(None)
- if handler_details.method == _SIMPLE_UNARY_UNARY:
- return grpc.unary_unary_rpc_method_handler(self._unary_unary)
- if handler_details.method == _BLOCK_FOREVER:
- return grpc.unary_unary_rpc_method_handler(self._block_forever)
- if handler_details.method == _BLOCK_BRIEFLY:
- return grpc.unary_unary_rpc_method_handler(self._block_briefly)
- if handler_details.method == _UNARY_STREAM_ASYNC_GEN:
- return grpc.unary_stream_rpc_method_handler(
- self._unary_stream_async_gen)
- if handler_details.method == _UNARY_STREAM_READER_WRITER:
- return grpc.unary_stream_rpc_method_handler(
- self._unary_stream_reader_writer)
- if handler_details.method == _UNARY_STREAM_EVILLY_MIXED:
- return grpc.unary_stream_rpc_method_handler(
- self._unary_stream_evilly_mixed)
+ return self._routing_table[handler_details.method]
async def wait_for_call(self):
await self._called
@@ -98,89 +177,152 @@
class TestServer(AioTestBase):
async def setUp(self):
- self._server_target, self._server, self._generic_handler = await _start_test_server(
- )
+ addr, self._server, self._generic_handler = await _start_test_server()
+ self._channel = aio.insecure_channel(addr)
async def tearDown(self):
+ await self._channel.close()
await self._server.stop(None)
async def test_unary_unary(self):
- async with aio.insecure_channel(self._server_target) as channel:
- unary_unary_call = channel.unary_unary(_SIMPLE_UNARY_UNARY)
- response = await unary_unary_call(_REQUEST)
- self.assertEqual(response, _RESPONSE)
+ unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
+ response = await unary_unary_call(_REQUEST)
+ self.assertEqual(response, _RESPONSE)
async def test_unary_stream_async_generator(self):
- async with aio.insecure_channel(self._server_target) as channel:
- unary_stream_call = channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
- call = unary_stream_call(_REQUEST)
+ unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
+ call = unary_stream_call(_REQUEST)
- # Expecting the request message to reach server before retriving
- # any responses.
- await asyncio.wait_for(self._generic_handler.wait_for_call(),
- test_constants.SHORT_TIMEOUT)
+ response_cnt = 0
+ async for response in call:
+ response_cnt += 1
+ self.assertEqual(_RESPONSE, response)
- response_cnt = 0
- async for response in call:
- response_cnt += 1
- self.assertEqual(_RESPONSE, response)
-
- self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
- self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_reader_writer(self):
- async with aio.insecure_channel(self._server_target) as channel:
- unary_stream_call = channel.unary_stream(
- _UNARY_STREAM_READER_WRITER)
- call = unary_stream_call(_REQUEST)
+ unary_stream_call = self._channel.unary_stream(
+ _UNARY_STREAM_READER_WRITER)
+ call = unary_stream_call(_REQUEST)
- # Expecting the request message to reach server before retriving
- # any responses.
- await asyncio.wait_for(self._generic_handler.wait_for_call(),
- test_constants.SHORT_TIMEOUT)
+ for _ in range(_NUM_STREAM_RESPONSES):
+ response = await call.read()
+ self.assertEqual(_RESPONSE, response)
- for _ in range(_NUM_STREAM_RESPONSES):
- response = await call.read()
- self.assertEqual(_RESPONSE, response)
-
- self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream_evilly_mixed(self):
- async with aio.insecure_channel(self._server_target) as channel:
- unary_stream_call = channel.unary_stream(_UNARY_STREAM_EVILLY_MIXED)
- call = unary_stream_call(_REQUEST)
+ unary_stream_call = self._channel.unary_stream(
+ _UNARY_STREAM_EVILLY_MIXED)
+ call = unary_stream_call(_REQUEST)
- # Expecting the request message to reach server before retriving
- # any responses.
- await asyncio.wait_for(self._generic_handler.wait_for_call(),
- test_constants.SHORT_TIMEOUT)
+ # Uses reader API
+ self.assertEqual(_RESPONSE, await call.read())
- # Uses reader API
- self.assertEqual(_RESPONSE, await call.read())
+ # Uses async generator API
+ response_cnt = 0
+ async for response in call:
+ response_cnt += 1
+ self.assertEqual(_RESPONSE, response)
- # Uses async generator API
- response_cnt = 0
- async for response in call:
- response_cnt += 1
- self.assertEqual(_RESPONSE, response)
+ self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
- self.assertEqual(_NUM_STREAM_RESPONSES - 1, response_cnt)
+ async def test_stream_unary_async_generator(self):
+ stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
+ call = stream_unary_call()
- self.assertEqual(await call.code(), grpc.StatusCode.OK)
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ response = await call
+ self.assertEqual(_RESPONSE, response)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_stream_unary_reader_writer(self):
+ stream_unary_call = self._channel.stream_unary(
+ _STREAM_UNARY_READER_WRITER)
+ call = stream_unary_call()
+
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ response = await call
+ self.assertEqual(_RESPONSE, response)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_stream_unary_evilly_mixed(self):
+ stream_unary_call = self._channel.stream_unary(
+ _STREAM_UNARY_EVILLY_MIXED)
+ call = stream_unary_call()
+
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ response = await call
+ self.assertEqual(_RESPONSE, response)
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_stream_stream_async_generator(self):
+ stream_stream_call = self._channel.stream_stream(
+ _STREAM_STREAM_ASYNC_GEN)
+ call = stream_stream_call()
+
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ response = await call.read()
+ self.assertEqual(_RESPONSE, response)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_stream_stream_reader_writer(self):
+ stream_stream_call = self._channel.stream_stream(
+ _STREAM_STREAM_READER_WRITER)
+ call = stream_stream_call()
+
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ response = await call.read()
+ self.assertEqual(_RESPONSE, response)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
+ async def test_stream_stream_evilly_mixed(self):
+ stream_stream_call = self._channel.stream_stream(
+ _STREAM_STREAM_EVILLY_MIXED)
+ call = stream_stream_call()
+
+ for _ in range(_NUM_STREAM_REQUESTS):
+ await call.write(_REQUEST)
+ await call.done_writing()
+
+ for _ in range(_NUM_STREAM_RESPONSES):
+ response = await call.read()
+ self.assertEqual(_RESPONSE, response)
+
+ self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_shutdown(self):
await self._server.stop(None)
# Ensures no SIGSEGV triggered, and ends within timeout.
async def test_shutdown_after_call(self):
- async with aio.insecure_channel(self._server_target) as channel:
- await channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
+ await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)
await self._server.stop(None)
async def test_graceful_shutdown_success(self):
- channel = aio.insecure_channel(self._server_target)
- call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
+ call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
shutdown_start_time = time.time()
@@ -190,13 +332,11 @@
test_constants.SHORT_TIMEOUT / 3)
# Validates the states.
- await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_graceful_shutdown_failed(self):
- channel = aio.insecure_channel(self._server_target)
- call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
+ call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
await self._server.stop(test_constants.SHORT_TIMEOUT)
@@ -206,11 +346,9 @@
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
- await channel.close()
async def test_concurrent_graceful_shutdown(self):
- channel = aio.insecure_channel(self._server_target)
- call = channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
+ call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects the shortest grace period to be effective.
@@ -224,13 +362,11 @@
self.assertGreater(grace_period_length,
test_constants.SHORT_TIMEOUT / 3)
- await channel.close()
self.assertEqual(_RESPONSE, await call)
self.assertTrue(call.done())
async def test_concurrent_graceful_shutdown_immediate(self):
- channel = aio.insecure_channel(self._server_target)
- call = channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
+ call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
await self._generic_handler.wait_for_call()
# Expects no grace period, due to the "server.stop(None)".
@@ -246,7 +382,6 @@
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertIn('GOAWAY', exception_context.exception.details())
- await channel.close()
@unittest.skip('https://github.com/grpc/grpc/issues/20818')
async def test_shutdown_before_call(self):