blob: 452e787112b62366667fdb496d0740509e085e00 [file] [log] [blame]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AsyncIO helpers for :mod:`grpc` supporting 3.6+.
Please refer to the more detailed docstrings in grpc_helpers.py when using the
following functions. This module implements the same surface with AsyncIO
semantics.
"""
import asyncio
import functools
import grpc
from grpc import aio
from google.api_core import exceptions, grpc_helpers
# TODO(lidiz) Support gRPC GCP wrapper
# Module-level flag; hard-coded to False because the gRPC GCP wrapper is not
# supported by this AsyncIO module yet (see the TODO above).
HAS_GRPC_GCP = False
# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
# automatic patching for us. But that means the overhead of creating an
# extra Python function spreads to every single send and receive.
class _WrappedCall(aio.Call):
    """Base wrapper that forwards the ``aio.Call`` interface to a real call.

    The underlying call object is attached after construction through
    :meth:`with_call`. Of the methods defined here, only
    :meth:`wait_for_connection` translates ``grpc.RpcError`` into the
    :mod:`google.api_core.exceptions` hierarchy; the rest delegate directly.
    """

    def __init__(self):
        # Filled in by with_call(); kept out of __init__ so subclasses can be
        # constructed without arguments.
        self._call = None

    def with_call(self, call):
        """Supplies the call object separately to keep __init__ clean.

        Args:
            call: The gRPC AsyncIO call object to delegate to.

        Returns:
            This wrapper instance, so construction and attachment can be chained.
        """
        self._call = call
        return self

    async def initial_metadata(self):
        delegate = self._call
        return await delegate.initial_metadata()

    async def trailing_metadata(self):
        delegate = self._call
        return await delegate.trailing_metadata()

    async def code(self):
        delegate = self._call
        return await delegate.code()

    async def details(self):
        delegate = self._call
        return await delegate.details()

    def cancelled(self):
        delegate = self._call
        return delegate.cancelled()

    def done(self):
        delegate = self._call
        return delegate.done()

    def time_remaining(self):
        delegate = self._call
        return delegate.time_remaining()

    def cancel(self):
        delegate = self._call
        return delegate.cancel()

    def add_done_callback(self, callback):
        delegate = self._call
        delegate.add_done_callback(callback)

    async def wait_for_connection(self):
        """Wait for the underlying call to connect, mapping gRPC errors."""
        delegate = self._call
        try:
            await delegate.wait_for_connection()
        except grpc.RpcError as grpc_exc:
            raise exceptions.from_grpc_error(grpc_exc) from grpc_exc
class _WrappedUnaryResponseMixin(_WrappedCall):
    """Makes a wrapped call awaitable for its single response message.

    Any ``grpc.RpcError`` raised while awaiting the underlying call is
    re-raised as the matching :mod:`google.api_core.exceptions` error, chained
    to the original error.
    """

    def __await__(self):
        try:
            return (yield from self._call.__await__())
        except grpc.RpcError as grpc_exc:
            raise exceptions.from_grpc_error(grpc_exc) from grpc_exc
class _WrappedStreamResponseMixin(_WrappedCall):
    """Mixin that maps gRPC errors raised while consuming a response stream.

    Both explicit :meth:`read` calls and ``async for`` iteration translate
    ``grpc.RpcError`` into :mod:`google.api_core.exceptions` errors.
    """

    def __init__(self):
        # Run _WrappedCall.__init__ so ``self._call`` is always defined, even
        # before with_call() attaches the real call.
        super().__init__()
        # Created lazily by __aiter__ and cached so every ``async for`` over
        # this wrapper shares one underlying iterator.
        self._wrapped_async_generator = None

    async def read(self):
        """Read the next response message, mapping gRPC errors."""
        try:
            return await self._call.read()
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    async def _wrapped_aiter(self):
        """Async generator re-yielding responses from the wrapped call."""
        try:
            # NOTE(lidiz) coverage doesn't understand the exception raised from
            # __anext__ method. It is covered by test case:
            # test_wrap_stream_errors_aiter_non_rpc_error
            async for response in self._call:  # pragma: no branch
                yield response
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    def __aiter__(self):
        # Explicit None check: the cache is "unset" only when it is None.
        if self._wrapped_async_generator is None:
            self._wrapped_async_generator = self._wrapped_aiter()
        return self._wrapped_async_generator
class _WrappedStreamRequestMixin(_WrappedCall):
    """Mixin that maps gRPC errors raised while sending a request stream."""

    async def write(self, request):
        """Send one request message on the wrapped call."""
        delegate = self._call
        try:
            await delegate.write(request)
        except grpc.RpcError as grpc_exc:
            raise exceptions.from_grpc_error(grpc_exc) from grpc_exc

    async def done_writing(self):
        """Signal that no more request messages will be sent."""
        delegate = self._call
        try:
            await delegate.done_writing()
        except grpc.RpcError as grpc_exc:
            raise exceptions.from_grpc_error(grpc_exc) from grpc_exc
# NOTE(lidiz) Each call type gets its own wrapper class so that we don't
# expose any API that should not be visible, e.g. __aiter__ on a unary-unary
# RPC, or __await__ on a stream-stream RPC.
class _WrappedUnaryUnaryCall(
    _WrappedUnaryResponseMixin,
    aio.UnaryUnaryCall,
):
    """Wrapped UnaryUnaryCall to map exceptions.

    Awaiting it yields the single response; gRPC errors surface as
    :mod:`google.api_core.exceptions` errors.
    """
class _WrappedUnaryStreamCall(
    _WrappedStreamResponseMixin,
    aio.UnaryStreamCall,
):
    """Wrapped UnaryStreamCall to map exceptions.

    Responses are consumed via ``read()`` or ``async for``; gRPC errors
    surface as :mod:`google.api_core.exceptions` errors.
    """
class _WrappedStreamUnaryCall(
    _WrappedUnaryResponseMixin,
    _WrappedStreamRequestMixin,
    aio.StreamUnaryCall,
):
    """Wrapped StreamUnaryCall to map exceptions.

    Requests are sent with ``write()``/``done_writing()`` and the single
    response is obtained by awaiting the wrapper.
    """
class _WrappedStreamStreamCall(
    _WrappedStreamRequestMixin,
    _WrappedStreamResponseMixin,
    aio.StreamStreamCall,
):
    """Wrapped StreamStreamCall to map exceptions.

    Requests are sent with ``write()``/``done_writing()`` and responses are
    consumed via ``read()`` or ``async for``.
    """
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary async callables.

    Args:
        callable_: A gRPC AsyncIO unary-unary multi-callable.

    Returns:
        A synchronous callable that starts the RPC and returns it wrapped in a
        :class:`_WrappedUnaryUnaryCall`, so awaiting it maps gRPC errors.
    """
    grpc_helpers._patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)
        wrapper = _WrappedUnaryUnaryCall()
        return wrapper.with_call(raw_call)

    return error_remapped_callable
def _wrap_stream_errors(callable_):
    """Map errors for streaming RPC async callables.

    Args:
        callable_: A gRPC AsyncIO multi-callable whose calls are unary-stream,
            stream-unary or stream-stream.

    Returns:
        A coroutine function that starts the RPC, wraps it in the matching
        ``_Wrapped*Call`` class, awaits ``wait_for_connection()`` on the
        wrapper, and returns the wrapper.

    Raises:
        TypeError: (from the returned coroutine) if the call object is not one
            of the supported streaming call types.
    """
    grpc_helpers._patch_callable_name(callable_)

    # Checked in order; the first matching call type selects the wrapper.
    dispatch = (
        (aio.UnaryStreamCall, _WrappedUnaryStreamCall),
        (aio.StreamUnaryCall, _WrappedStreamUnaryCall),
        (aio.StreamStreamCall, _WrappedStreamStreamCall),
    )

    @functools.wraps(callable_)
    async def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)
        for call_type, wrapper_cls in dispatch:
            if isinstance(raw_call, call_type):
                break
        else:
            raise TypeError("Unexpected type of call %s" % type(raw_call))
        wrapped = wrapper_cls().with_call(raw_call)
        await wrapped.wait_for_connection()
        return wrapped

    return error_remapped_callable
def wrap_errors(callable_):
    """Wrap a gRPC async callable and map :class:`grpc.RpcErrors` to
    friendly error classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The
    original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Unary-unary multi-callables get a synchronous wrapper returning an
    awaitable call; every other callable gets a coroutine-function wrapper
    for streaming RPCs.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    is_unary_unary = isinstance(callable_, aio.UnaryUnaryMultiCallable)
    wrapper_factory = _wrap_unary_errors if is_unary_unary else _wrap_stream_errors
    return wrapper_factory(callable_)
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    **kwargs
):
    """Create an AsyncIO secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. These are only used when credentials are not specified and
            are passed to :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint, e.g. "pubsub.googleapis.com".
        kwargs: Additional keyword args passed to :func:`aio.secure_channel`.

    Returns:
        aio.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    # Everything except the target and extra channel kwargs feeds the
    # composite (call + channel) credentials built by grpc_helpers.
    credential_options = {
        "credentials": credentials,
        "credentials_file": credentials_file,
        "scopes": scopes,
        "default_scopes": default_scopes,
        "ssl_credentials": ssl_credentials,
        "quota_project_id": quota_project_id,
        "default_host": default_host,
    }
    channel_credentials = grpc_helpers._create_composite_credentials(
        **credential_options
    )
    return aio.secure_channel(target, channel_credentials, **kwargs)
class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall):
    """Fake implementation for unary-unary RPCs.

    It is a dummy object for response message. Supply the intended response
    upon the initialization, and the coroutine will return the exact response
    message.

    Awaiting the fake completes immediately with ``self.response``. No event
    loop is needed at construction time, so the fake can be built in
    synchronous test setup code.
    """

    def __init__(self, response=object()):
        self.response = response

    async def _resolve_response(self):
        """Coroutine that immediately returns the configured response."""
        return self.response

    def __await__(self):
        # Delegate to a trivial coroutine rather than a pre-resolved future:
        # creating a future required asyncio.get_event_loop(), which warns
        # (Python 3.12+) or raises (3.14+) when no event loop is set.
        return self._resolve_response().__await__()
class FakeStreamUnaryCall(_WrappedStreamUnaryCall):
    """Fake implementation for stream-unary RPCs.

    It is a dummy object for response message. Supply the intended response
    upon the initialization, and the coroutine will return the exact response
    message.

    Awaiting the fake completes immediately with ``self.response``. No event
    loop is needed at construction time, so the fake can be built in
    synchronous test setup code.
    """

    def __init__(self, response=object()):
        self.response = response

    async def _resolve_response(self):
        """Coroutine that immediately returns the configured response."""
        return self.response

    def __await__(self):
        # Delegate to a trivial coroutine rather than a pre-resolved future:
        # creating a future required asyncio.get_event_loop(), which warns
        # (Python 3.12+) or raises (3.14+) when no event loop is set.
        return self._resolve_response().__await__()

    async def wait_for_connection(self):
        # The fake has no real channel, so "connecting" is a no-op.
        pass