blob: 5d30e23fa802169ce87f740b24e2fc3eebcc0869 [file] [log] [blame]
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text
from weakref import WeakSet
import logging
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction)
from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
_LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: AbstractSet[_base_call.RpcContext]
def __init__(self):
self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext):
try:
self._calls.remove(call)
except KeyError:
pass
@property
def calls(self) -> AbstractSet[_base_call.RpcContext]:
"""Returns the set of ongoing calls."""
return self._calls
def size(self) -> int:
"""Returns the number of ongoing calls."""
return len(self._calls)
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.add(call)
call.add_done_callback(self._remove_call)
class _BaseMultiCallable:
"""Base class of all multi callable objects.
Handles the initialization logic and stores common attributes.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_ongoing_calls: _OngoingCalls
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop
# pylint: disable=too-many-arguments
def __init__(
self,
channel: cygrpc.AioChannel,
ongoing_calls: _OngoingCalls,
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]],
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
self._channel = channel
self._ongoing_calls = ongoing_calls
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._interceptors = interceptors
class UnaryUnaryMultiCallable(_BaseMultiCallable):
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
if not self._interceptors:
call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
metadata, credentials, wait_for_ready,
self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedUnaryUnaryCall(
self._interceptors, request, timeout, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
self._ongoing_calls.trace_call(call)
return call
class UnaryStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a unary-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class StreamUnaryMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-unary RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamUnaryCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamUnaryCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class StreamStreamMultiCallable(_BaseMultiCallable):
"""Affords invoking a stream-stream RPC from client-side in an asynchronous way."""
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.StreamStreamCall:
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC. Only valid for
secure Channel.
wait_for_ready: This is an EXPERIMENTAL argument. An optional
flag to enable wait for ready mechanism
compression: An element of grpc.compression, e.g.
grpc.compression.Gzip. This is an EXPERIMENTAL option.
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(timeout)
if metadata is None:
metadata = _IMMUTABLE_EMPTY_TUPLE
call = StreamStreamCall(request_async_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
self._ongoing_calls.trace_call(call)
return call
class Channel:
"""Asynchronous Channel implementation.
A cygrpc.AioChannel-backed implementation.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_ongoing_calls: _OngoingCalls
def __init__(self, target: Text, options: Optional[ChannelArgumentType],
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
"""Constructor.
Args:
target: The target to which to connect.
options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
"""
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None:
self._unary_unary_interceptors = None
else:
self._unary_unary_interceptors = list(
filter(
lambda interceptor: isinstance(interceptor,
UnaryUnaryClientInterceptor),
interceptors))
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials, self._loop)
self._ongoing_calls = _OngoingCalls()
async def __aenter__(self):
"""Starts an asynchronous context manager.
Returns:
Channel the channel that was instantiated.
"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Finishes the asynchronous context manager by closing the channel.
Still active RPCs will be cancelled.
"""
await self._close(None)
async def _close(self, grace):
if self._channel.closed():
return
# No new calls will be accepted by the Cython channel.
self._channel.closing()
if grace:
# pylint: disable=unused-variable
_, pending = await asyncio.wait(self._ongoing_calls.calls,
timeout=grace,
loop=self._loop)
if not pending:
return
# A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = WeakSet(data=self._ongoing_calls.calls)
for call in calls:
call.cancel()
self._channel.close()
async def close(self, grace: Optional[float] = None):
"""Closes this Channel and releases all resources held by it.
This method immediately stops the channel from executing new RPCs in
all cases.
If a grace period is specified, this method wait until all active
RPCs are finshed, once the grace period is reached the ones that haven't
been terminated are cancelled. If a grace period is not specified
(by passing None for grace), all existing RPCs are cancelled immediately.
This method is idempotent.
"""
await self._close(grace)
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel.
This is an EXPERIMENTAL API.
If the channel reaches a stable connectivity state, it is guaranteed
that the return value of this function will eventually converge to that
state.
Args: try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns: A ChannelConnectivity object.
"""
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
) -> None:
"""Wait for a change in connectivity state.
This is an EXPERIMENTAL API.
The function blocks until there is a change in the channel connectivity
state from the "last_observed_state". If the state is already
different, this function will return immediately.
There is an inherent race between the invocation of
"Channel.wait_for_state_change" and "Channel.get_state". The state can
change arbitrary times during the race, so there is no way to observe
every state transition.
If there is a need to put a timeout for this function, please refer to
"asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
"""
assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None)
def unary_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryUnaryMultiCallable:
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
method: The name of the RPC method.
request_serializer: Optional behaviour for serializing the request
message. Request goes unserialized in case None is passed.
response_deserializer: Optional behaviour for deserializing the
response message. Response goes undeserialized in case None
is passed.
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer,
self._unary_unary_interceptors,
self._loop)
def unary_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
def stream_unary(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
def stream_stream(
self,
method: Text,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, self._ongoing_calls,
_common.encode(method),
request_serializer,
response_deserializer, None,
self._loop)