| # Copyright 2017 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Helpers for :mod:`grpc`.""" |
| |
| import collections |
| import functools |
| |
| import grpc |
| import pkg_resources |
| |
| from google.api_core import exceptions |
| import google.auth |
| import google.auth.credentials |
| import google.auth.transport.grpc |
| import google.auth.transport.requests |
| |
| try: |
| import grpc_gcp |
| |
| HAS_GRPC_GCP = True |
| except ImportError: |
| HAS_GRPC_GCP = False |
| |
def _get_google_auth_version():
    """Return the installed google-auth version string, or ``None``.

    ``google.auth.__version__`` is only present from google-auth 1.26.0
    onwards; for older releases the version is read from the distribution
    metadata through ``pkg_resources`` instead.
    """
    try:
        return google.auth.__version__
    except AttributeError:
        pass

    try:  # fall back to pkg_resources distribution metadata
        return pkg_resources.get_distribution("google-auth").version
    except pkg_resources.DistributionNotFound:  # pragma: NO COVER
        return None


_GOOGLE_AUTH_VERSION = _get_google_auth_version()

# gRPC multi-callable interfaces whose invocation yields a stream (iterator)
# of responses rather than a single response message.
_STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable)
| |
| |
| def _patch_callable_name(callable_): |
| """Fix-up gRPC callable attributes. |
| |
| gRPC callable lack the ``__name__`` attribute which causes |
| :func:`functools.wraps` to error. This adds the attribute if needed. |
| """ |
| if not hasattr(callable_, "__name__"): |
| callable_.__name__ = callable_.__class__.__name__ |
| |
| |
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary and Stream-Unary gRPC callables.

    Args:
        callable_ (Callable): A gRPC callable returning a single response.

    Returns:
        Callable: A wrapper that forwards all arguments to ``callable_`` and
        converts any :class:`grpc.RpcError` into the matching
        :mod:`google.api_core.exceptions` type, chaining the original error.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return functools.wraps(callable_)(error_remapped_callable)
| |
| |
class _StreamingResponseIterator(grpc.Call):
    """Iterator over a streaming gRPC response that re-maps gRPC errors.

    Instances are iterators over the responses of ``wrapped`` and also
    implement the :class:`grpc.Call` / ``grpc.RpcContext`` interface by
    forwarding every such method to ``wrapped``.

    Args:
        wrapped: The object returned by invoking a streaming gRPC callable.
        prefetch_first_result (bool): Whether to pull the first response
            from ``wrapped`` while constructing this object.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped

        # This object is built inside a retry context and then handed back to
        # the caller. gRPC does not raise until the stream is consumed, so the
        # first response is fetched eagerly to make a failing call raise here,
        # where it can trigger a retry.
        try:
            if prefetch_first_result:
                self._stored_first_result = next(self._wrapped)
        except (TypeError, StopIteration):
            # TypeError: ``wrapped`` may not be an iterator (a grpc.Call, for
            # instance); in that case no first result is stored.
            # StopIteration: ignored at this point; it is left to be handled
            # outside of the retry context.
            pass

    def __iter__(self):
        """Return this object; it is its own iterator."""
        return self

    def __next__(self):
        """Get the next response from the stream.

        A prefetched first response, if any, is returned (and discarded from
        this object) before reading further from the wrapped stream.

        Returns:
            protobuf.Message: A single response from the stream.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: Mapped from any
                :class:`grpc.RpcError` raised while reading the stream.
        """
        try:
            if not hasattr(self, "_stored_first_result"):
                return next(self._wrapped)
            first_response = self._stored_first_result
            del self._stored_first_result
            return first_response
        except grpc.RpcError as rpc_error:
            # Once the stream has produced data it cannot be recovered here.
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    # grpc.Call & grpc.RpcContext interface, forwarded to the wrapped object.

    def add_callback(self, callback):
        """Forward to ``add_callback`` of the wrapped call."""
        return self._wrapped.add_callback(callback)

    def cancel(self):
        """Forward to ``cancel`` of the wrapped call."""
        return self._wrapped.cancel()

    def code(self):
        """Forward to ``code`` of the wrapped call."""
        return self._wrapped.code()

    def details(self):
        """Forward to ``details`` of the wrapped call."""
        return self._wrapped.details()

    def initial_metadata(self):
        """Forward to ``initial_metadata`` of the wrapped call."""
        return self._wrapped.initial_metadata()

    def is_active(self):
        """Forward to ``is_active`` of the wrapped call."""
        return self._wrapped.is_active()

    def time_remaining(self):
        """Forward to ``time_remaining`` of the wrapped call."""
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        """Forward to ``trailing_metadata`` of the wrapped call."""
        return self._wrapped.trailing_metadata()
| |
| |
def _wrap_stream_errors(callable_):
    """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

    Streaming callables can fail either when invoked or later while their
    responses are iterated. The returned wrapper re-maps
    :class:`grpc.RpcError` in both places: around the invocation itself and,
    through :class:`_StreamingResponseIterator`, during iteration.

    Args:
        callable_ (Callable): A streaming gRPC callable.

    Returns:
        Callable: The wrapped callable, returning a
        :class:`_StreamingResponseIterator`.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            response_stream = callable_(*args, **kwargs)
            # Eagerly fetching the first response makes the PubSub client's
            # streaming pull hang when it re-opens the stream, so the hidden
            # ``_prefetch_first_result_`` flag on the callable can turn
            # prefetching off. It is read on every invocation.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                response_stream, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return error_remapped_callable
| |
| |
def wrap_errors(callable_):
    """Wrap a gRPC callable so :class:`grpc.RpcError` maps to friendly errors.

    Errors raised by the gRPC callable are converted into the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The
    original ``grpc.RpcError`` (usually also a ``grpc.Call``) is available
    from the ``response`` property of the mapped exception, which is handy
    for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable. Streaming callables
        (unary-stream and stream-stream) also have errors raised during
        iteration of their responses re-mapped.
    """
    is_streaming = isinstance(callable_, _STREAM_WRAP_CLASSES)
    wrapper = _wrap_stream_errors if is_streaming else _wrap_unary_errors
    return wrapper(callable_)
| |
| |
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite credentials for secure channels.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified (and no ``credentials_file`` is given), then this
            function will attempt to ascertain the credentials from the
            environment using :func:`google.auth.default`.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        scopes (Sequence[str]): An optional list of user-defined scopes needed
            for this service. ``scopes`` and ``default_scopes`` are forwarded to
            whichever of :func:`google.auth.load_credentials_from_file`,
            :func:`google.auth.credentials.with_scopes_if_required`, or
            :func:`google.auth.default` is used to obtain the credentials.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
            Defaults to :func:`grpc.ssl_channel_credentials`.
        quota_project_id (str): An optional project to use for billing and quota.
            Only applied when the resolved credentials are a
            :class:`google.auth.credentials.CredentialsWithQuotaProject`.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    # Resolve the credentials from exactly one source: a file, the explicit
    # object, or the environment. Scopes are forwarded in every case.
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, scopes=scopes, default_scopes=default_scopes
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, scopes=scopes, default_scopes=default_scopes
        )
    else:
        credentials, _ = google.auth.default(
            scopes=scopes, default_scopes=default_scopes
        )

    # Only credential types that support a quota project get one attached.
    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    request = google.auth.transport.requests.Request()

    # Create the metadata plugin for inserting the authorization header.
    metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials, request, default_host=default_host,
    )

    # Create a set of grpc.CallCredentials using the metadata plugin.
    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

    # Fall back to gRPC's default SSL channel credentials.
    if ssl_credentials is None:
        ssl_credentials = grpc.ssl_channel_credentials()

    # Combine the ssl credentials and the authorization credentials.
    return grpc.composite_channel_credentials(ssl_credentials, google_auth_credentials)
| |
| |
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    **kwargs
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified (and no ``credentials_file`` is given), then this
            function will attempt to ascertain the credentials from the
            environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of user-defined scopes needed
            for this service. They are applied together with
            ``default_scopes`` whether the credentials come from
            ``credentials``, ``credentials_file``, or the environment.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        kwargs: Additional key-word args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """

    composite_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    if HAS_GRPC_GCP:
        # If grpc_gcp module is available use grpc_gcp.secure_channel,
        # otherwise, use grpc.secure_channel to create grpc channel.
        return grpc_gcp.secure_channel(target, composite_credentials, **kwargs)
    else:
        return grpc.secure_channel(target, composite_credentials, **kwargs)
| |
| |
| _MethodCall = collections.namedtuple( |
| "_MethodCall", ("request", "timeout", "metadata", "credentials") |
| ) |
| |
| _ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request")) |
| |
| |
| class _CallableStub(object): |
| """Stub for the grpc.*MultiCallable interfaces.""" |
| |
| def __init__(self, method, channel): |
| self._method = method |
| self._channel = channel |
| self.response = None |
| """Union[protobuf.Message, Callable[protobuf.Message], exception]: |
| The response to give when invoking this callable. If this is a |
| callable, it will be invoked with the request protobuf. If it's an |
| exception, the exception will be raised when this is invoked. |
| """ |
| self.responses = None |
| """Iterator[ |
| Union[protobuf.Message, Callable[protobuf.Message], exception]]: |
| An iterator of responses. If specified, self.response will be populated |
| on each invocation by calling ``next(self.responses)``.""" |
| self.requests = [] |
| """List[protobuf.Message]: All requests sent to this callable.""" |
| self.calls = [] |
| """List[Tuple]: All invocations of this callable. Each tuple is the |
| request, timeout, metadata, and credentials.""" |
| |
| def __call__(self, request, timeout=None, metadata=None, credentials=None): |
| self._channel.requests.append(_ChannelRequest(self._method, request)) |
| self.calls.append(_MethodCall(request, timeout, metadata, credentials)) |
| self.requests.append(request) |
| |
| response = self.response |
| if self.responses is not None: |
| if response is None: |
| response = next(self.responses) |
| else: |
| raise ValueError( |
| "{method}.response and {method}.responses are mutually " |
| "exclusive.".format(method=self._method) |
| ) |
| |
| if callable(response): |
| return response(request) |
| |
| if isinstance(response, Exception): |
| raise response |
| |
| if response is not None: |
| return response |
| |
| raise ValueError('Method stub for "{}" has no response.'.format(self._method)) |
| |
| |
| def _simplify_method_name(method): |
| """Simplifies a gRPC method name. |
| |
| When gRPC invokes the channel to create a callable, it gives a full |
| method name like "/google.pubsub.v1.Publisher/CreateTopic". This |
| returns just the name of the method, in this case "CreateTopic". |
| |
| Args: |
| method (str): The name of the method. |
| |
| Returns: |
| str: The simplified name of the method. |
| """ |
| return method.rsplit("/", 1).pop() |
| |
| |
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here are some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        """Create an empty channel stub.

        Args:
            responses: Accepted for backwards compatibility only; it is not
                used. Configure responses on the per-method stubs instead.
        """
        # Sequence[_ChannelRequest]: all requests made on this channel, in
        # order, as (method name, request message) tuples.
        self.requests = []
        # Maps simplified method names to their _CallableStub.
        self._method_stubs = {}

    def _stub_for_method(self, method):
        """Create, register, and return a fresh stub for ``method``.

        Any stub previously registered under the same simplified name is
        replaced.
        """
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        """Expose registered method stubs as attributes (e.g. ``stub.GetFoo``)."""
        # Read the stubs mapping through __dict__: on an instance whose
        # __init__ never ran (copy.copy, pickle), ``self._method_stubs``
        # would re-enter __getattr__ and recurse until RecursionError.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, key)
            ) from None

    # ``_registered_method`` is accepted (and ignored) by the factory methods
    # below so callers passing grpc's private keyword, as newer
    # grpcio-generated stubs do, do not fail with a TypeError.

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation (no-op)."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation (no-op)."""
        pass

    def close(self):
        """grpc.Channel.close implementation (no-op)."""
        pass