# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import collections
import logging
import multiprocessing
import os
import sys
import time
from typing import Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2,
                                    stats_pb2, worker_service_pb2_grpc)
from tests.qps import histogram
from tests.unit import resources
from tests.unit.framework.common import get_socket
from tests_aio.benchmark import benchmark_client, benchmark_servicer

_NUM_CORES = multiprocessing.cpu_count()
# CPython's GIL effectively limits a single process to one busy core.
_NUM_CORE_PYTHON_CAN_USE = 1
_WORKER_ENTRY_FILE = os.path.join(
    os.path.split(os.path.abspath(__file__))[0], 'worker.py')
_LOGGER = logging.getLogger(__name__)


class _SubWorker(
        collections.namedtuple('_SubWorker',
                               ['process', 'port', 'channel', 'stub'])):
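    """Holds a sub worker process and the channel and stub connected to it."""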

    def __repr__(self):
        return f'<_SubWorker pid={self.process.pid} port={self.port}>'


def _get_server_status(start_time: float, end_time: float,
                       port: int) -> control_pb2.ServerStatus:
    """Creates a ServerStatus message covering the given time span."""
    elapsed_time = end_time - start_time
    # Wall-clock time is used as an approximation of user and system time.
    stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES)


def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]:
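    """Creates an aio.Server for the given ServerConfig."""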
    # so_reuseport lets multiple sub worker servers share a single port.
    server = aio.server(options=(('grpc.so_reuseport', 1),))
if config.server_type == control_pb2.ASYNC_SERVER:
servicer = benchmark_servicer.BenchmarkServicer()
benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
servicer, server)
elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
resp_size = config.payload_config.bytebuf_params.resp_size
servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size)
method_implementations = {
'StreamingCall':
grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
'UnaryCall':
grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
}
handler = grpc.method_handlers_generic_handler(
'grpc.testing.BenchmarkService', method_implementations)
server.add_generic_rpc_handlers((handler,))
else:
raise NotImplementedError('Unsupported server type {}'.format(
config.server_type))
if config.HasField('security_params'): # Use SSL
server_creds = grpc.ssl_server_credentials(
((resources.private_key(), resources.certificate_chain()),))
port = server.add_secure_port('[::]:{}'.format(config.port),
server_creds)
else:
port = server.add_insecure_port('[::]:{}'.format(config.port))
return server, port


def _get_client_status(start_time: float, end_time: float,
                       qps_data: histogram.Histogram
                      ) -> control_pb2.ClientStatus:
    """Creates a ClientStatus message covering the given time span."""
    latencies = qps_data.get_data()
    elapsed_time = end_time - start_time
    # Wall-clock time is used as an approximation of user and system time.
    stats = stats_pb2.ClientStats(latencies=latencies,
                                  time_elapsed=elapsed_time,
                                  time_user=elapsed_time,
                                  time_system=elapsed_time)
    return control_pb2.ClientStatus(stats=stats)


def _create_client(server: str, config: control_pb2.ClientConfig,
                   qps_data: histogram.Histogram
                  ) -> benchmark_client.BenchmarkClient:
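    """Creates a benchmark client matching the given ClientConfig."""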
if config.load_params.WhichOneof('load') != 'closed_loop':
raise NotImplementedError(
f'Unsupported load parameter {config.load_params}')
if config.client_type == control_pb2.ASYNC_CLIENT:
if config.rpc_type == control_pb2.UNARY:
client_type = benchmark_client.UnaryAsyncBenchmarkClient
elif config.rpc_type == control_pb2.STREAMING:
client_type = benchmark_client.StreamingAsyncBenchmarkClient
else:
raise NotImplementedError(
f'Unsupported rpc_type [{config.rpc_type}]')
else:
raise NotImplementedError(
f'Unsupported client type {config.client_type}')
return client_type(server, config, qps_data)


def _pick_an_unused_port() -> int:
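    """Picks a currently unused TCP port (best effort)."""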
_, port, sock = get_socket()
sock.close()
return port


async def _create_sub_worker() -> _SubWorker:
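    """Spawns a worker.py subprocess and connects a stub to it."""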
port = _pick_an_unused_port()
_LOGGER.info('Creating sub worker at port [%d]...', port)
process = await asyncio.create_subprocess_exec(sys.executable,
_WORKER_ENTRY_FILE,
'--driver_port', str(port))
_LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port,
process.pid)
channel = aio.insecure_channel(f'localhost:{port}')
_LOGGER.info('Waiting for sub worker at port [%d]', port)
await channel.channel_ready()
stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
return _SubWorker(
process=process,
port=port,
channel=channel,
stub=stub,
)


class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
"""Python Worker Server implementation."""

    def __init__(self):
self._loop = asyncio.get_event_loop()
self._quit_event = asyncio.Event()

    async def _run_single_server(self, config, request_iterator, context):
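        """Runs one in-process benchmark server, streaming status updates."""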
server, port = _create_server(config)
await server.start()
_LOGGER.info('Server started at port [%d]', port)
start_time = time.time()
await context.write(_get_server_status(start_time, start_time, port))
async for request in request_iterator:
end_time = time.time()
status = _get_server_status(start_time, end_time, port)
if request.mark.reset:
start_time = end_time
await context.write(status)
await server.stop(None)

    async def RunServer(self, request_iterator, context):
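        """Runs a benchmark server, in-process or fanned out to sub workers."""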
config_request = await context.read()
config = config_request.setup
_LOGGER.info('Received ServerConfig: %s', config)
        if config.async_server_threads <= 0:
            _LOGGER.info(
                'Invalid async_server_threads [%d]; defaulting to [%d]',
                config.async_server_threads, _NUM_CORES)
            config.async_server_threads = _NUM_CORES
if config.port == 0:
config.port = _pick_an_unused_port()
_LOGGER.info('Port picked [%d]', config.port)
if config.async_server_threads == 1:
await self._run_single_server(config, request_iterator, context)
else:
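            # Spawn one single-threaded sub worker process per requested
            # thread so the benchmark is not limited by the GIL.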
sub_workers = await asyncio.gather(*(
_create_sub_worker()
for _ in range(config.async_server_threads)))
calls = [worker.stub.RunServer() for worker in sub_workers]
config_request.setup.async_server_threads = 1
for call in calls:
await call.write(config_request)
# An empty status indicates the peer is ready
await call.read()
start_time = time.time()
await context.write(
_get_server_status(
start_time,
start_time,
config.port,
))
async for request in request_iterator:
end_time = time.time()
for call in calls:
_LOGGER.debug('Fetching status...')
await call.write(request)
                    # The reports from the sub workers don't matter here
await call.read()
status = _get_server_status(
start_time,
end_time,
config.port,
)
if request.mark.reset:
start_time = end_time
await context.write(status)
for call in calls:
await call.done_writing()
for worker in sub_workers:
await worker.stub.QuitWorker(control_pb2.Void())
await worker.channel.close()
_LOGGER.info('Waiting for [%s] to quit...', worker)
await worker.process.wait()

    async def _run_single_client(self, config, request_iterator, context):
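        """Runs in-process benchmark clients, streaming status updates."""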
running_tasks = []
qps_data = histogram.Histogram(config.histogram_params.resolution,
config.histogram_params.max_possible)
start_time = time.time()
        # Create a client for each channel; each runs as an asyncio.Task
for i in range(config.client_channels):
server = config.server_targets[i % len(config.server_targets)]
client = _create_client(server, config, qps_data)
_LOGGER.info('Client created against server [%s]', server)
running_tasks.append(self._loop.create_task(client.run()))
end_time = time.time()
await context.write(_get_client_status(start_time, end_time, qps_data))
# Respond to stat requests
async for request in request_iterator:
end_time = time.time()
status = _get_client_status(start_time, end_time, qps_data)
if request.mark.reset:
qps_data.reset()
start_time = time.time()
await context.write(status)
        # Clean up the clients
for task in running_tasks:
task.cancel()

    async def RunClient(self, request_iterator, context):
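        """Runs benchmark clients, in-process or fanned out to sub workers."""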
config_request = await context.read()
config = config_request.setup
_LOGGER.info('Received ClientConfig: %s', config)
        if config.async_client_threads <= 0:
            _LOGGER.info(
                'Invalid async_client_threads [%d]; defaulting to [%d]',
                config.async_client_threads, _NUM_CORES)
            config.async_client_threads = _NUM_CORES
if config.async_client_threads == 1:
await self._run_single_client(config, request_iterator, context)
else:
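            # Spawn one single-threaded sub worker process per requested
            # thread so the benchmark is not limited by the GIL.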
sub_workers = await asyncio.gather(*(
_create_sub_worker()
for _ in range(config.async_client_threads)))
calls = [worker.stub.RunClient() for worker in sub_workers]
config_request.setup.async_client_threads = 1
for call in calls:
await call.write(config_request)
# An empty status indicates the peer is ready
await call.read()
start_time = time.time()
result = histogram.Histogram(config.histogram_params.resolution,
config.histogram_params.max_possible)
end_time = time.time()
await context.write(_get_client_status(start_time, end_time,
result))
async for request in request_iterator:
end_time = time.time()
for call in calls:
_LOGGER.debug('Fetching status...')
await call.write(request)
sub_status = await call.read()
result.merge(sub_status.stats.latencies)
_LOGGER.debug('Update from sub worker count=[%d]',
sub_status.stats.latencies.count)
status = _get_client_status(start_time, end_time, result)
if request.mark.reset:
result.reset()
start_time = time.time()
_LOGGER.debug('Reporting count=[%d]',
status.stats.latencies.count)
await context.write(status)
for call in calls:
await call.done_writing()
for worker in sub_workers:
await worker.stub.QuitWorker(control_pb2.Void())
await worker.channel.close()
_LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
await worker.process.wait()
_LOGGER.info('Sub worker [%s] quit', worker)

    @staticmethod
async def CoreCount(unused_request, unused_context):
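        """Returns the number of CPU cores on this machine."""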
return control_pb2.CoreResponse(cores=_NUM_CORES)

    async def QuitWorker(self, unused_request, unused_context):
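        """Handles the request to quit by signaling wait_for_quit."""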
_LOGGER.info('QuitWorker command received.')
self._quit_event.set()
return control_pb2.Void()

    async def wait_for_quit(self):
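        """Blocks until a QuitWorker command has been received."""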
await self._quit_event.wait()