blob: 3e155532d583bc2a34c208fbe2a479fd5fc72123 [file] [log] [blame]
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import ClassVar, Dict, Optional
# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
import grpc
from google.protobuf import json_format
import google.protobuf.message
# Module-level logger; GrpcClientHelper writes its RPC request log lines here.
logger = logging.getLogger(__name__)
# Type aliases
# Shorthand for the protobuf base message class, used in annotations below.
Message = google.protobuf.message.Message
class GrpcClientHelper:
    """Wraps a gRPC stub and issues unary RPCs with a deadline.

    Every call is logged in a form that mimics a method call:
    ``RPC <Service>.<method>(request=<Type>(<dict>), <call options>)``.
    """
    channel: grpc.Channel
    # Timeout (seconds) used when the caller does not supply a deadline.
    DEFAULT_RPC_DEADLINE_SEC = 90

    def __init__(self, channel: grpc.Channel, stub_class: type):
        """Instantiate ``stub_class`` on ``channel``.

        Args:
            channel: The channel the stub sends RPCs over.
            stub_class: A stub class, instantiated as ``stub_class(channel)``.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # This is purely cosmetic to make RPC logs look like method calls:
        # a stub class named "FooStub" is logged as "Foo".
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Call the stub method named ``rpc`` with ``req`` and return its result.

        The call is made with ``wait_for_ready=True`` and
        ``timeout=deadline_sec``.

        Args:
            rpc: Name of the stub attribute to invoke.
            req: Request message passed to the stub method.
            deadline_sec: Timeout in seconds. ``None`` selects
                ``DEFAULT_RPC_DEADLINE_SEC``.
            log_level: Level of the request log line. ``None`` selects
                ``logging.DEBUG``.

        Returns:
            Whatever the stub method returns (the response message).

        Raises:
            Anything the stub call raises propagates unchanged (presumably
            ``grpc.RpcError`` on RPC failure -- not visible here).
        """
        if deadline_sec is None:
            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC
        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)
        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log an outgoing RPC; a ``None`` log_level is treated as DEBUG."""
        level = logging.DEBUG if log_level is None else log_level
        # MessageToDict walks the entire request message; skip that work when
        # the record would be dropped anyway (Logger.log does this same check).
        if not logger.isEnabledFor(level):
            return
        # Join in insertion order (not via a set) so the options appear in a
        # stable order from run to run.
        call_options = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
        logger.log(level, 'RPC %s.%s(request=%s(%r), %s)',
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req), call_options)
class GrpcApp:
    """Holds lazily-created insecure gRPC channels to ``rpc_host``, one per port.

    Can be used as a context manager; all cached channels are closed on exit.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = {}

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached channel to ``rpc_host:port``.

        The channel is created with ``grpc.insecure_channel`` on first use
        and reused by later calls until ``close()`` is called.
        """
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close every cached channel and drop it from the cache.

        Removing each entry as it is closed means a later ``_make_channel()``
        opens a fresh channel instead of handing back a closed one. It also
        makes repeated ``close()`` calls (e.g. ``__exit__`` followed by
        ``__del__``) harmless.
        """
        while self.channels:
            _, channel = self.channels.popitem()
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Do not suppress exceptions raised inside the with-body.
        return False

    def __del__(self):
        # __init__ may not have run to completion (e.g. a subclass raised
        # before calling super().__init__()), leaving no channel cache.
        if hasattr(self, 'channels'):
            self.close()