blob: 8050261dcb30e53b894b7e282a065f33e448bfd8 [file] [log] [blame]
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Avatar is a scalable multi-platform Bluetooth testing tool capable of running
any Bluetooth test cases virtually and physically.
"""
__version__ = "0.0.1"
import contextlib
import functools
import importlib
import inspect
import logging
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Type, TypeVar

import grpc
import grpc.aio
from mobly import base_test

from avatar import pandora_server
from avatar.aio import asynchronous
from avatar.pandora_client import BumblePandoraClient as BumblePandoraDevice, PandoraClient as PandoraDevice
from avatar.pandora_server import PandoraServer
# public symbols
__all__ = [
    'asynchronous',
    'parameterized',
    'rpc_except',
    'PandoraDevices',
    'PandoraDevice',
    'BumblePandoraDevice',
]
# Default PandoraServer class for each well-known controller name. PandoraDevices
# looks a controller up here when user_params has no
# `pandora_server_class_<controller>` entry for it.
PANDORA_COMMON_SERVER_CLASSES: Dict[str, Type[pandora_server.PandoraServer[Any]]] = {
    'PandoraDevice': pandora_server.PandoraServer,
    'AndroidDevice': pandora_server.AndroidPandoraServer,
    'BumbleDevice': pandora_server.BumblePandoraServer,
}
# Prefix of the user_params key (`<prefix>_<controller>`) whose value is the
# '<module>.<class>' path of a custom PandoraServer class for that controller.
KEY_PANDORA_SERVER_CLASS = 'pandora_server_class'
class PandoraDevices(Sized, Iterable[PandoraDevice]):
    """Utility for abstracting controller registration and Pandora setup.

    Iterating over an instance yields one PandoraDevice client per started
    Pandora server, in controller order.
    """

    _test: base_test.BaseTestClass
    _clients: List[PandoraDevice]
    _servers: List[PandoraServer[Any]]

    def __init__(self, test: base_test.BaseTestClass) -> None:
        """Creates a PandoraDevices list.

        It performs three steps:
          - Register the underlying controllers to the test.
          - Start the corresponding PandoraServer for each controller.
          - Store a PandoraClient for each server.

        The order in which the clients are returned can be determined by the
        (optional) `order_<controller_class>` params in user_params. Controllers
        without such a param will be set up last (order=100).

        The PandoraServer class for a controller is loaded from the
        `pandora_server_class_<controller>` user param when present, otherwise
        it is looked up by controller name in PANDORA_COMMON_SERVER_CLASSES.

        Args:
          test: Instance of the Mobly test class.

        Raises:
          RuntimeError: If no PandoraServer class is known for a controller, or
            if registering a controller yields no devices.
        """
        self._test = test
        self._clients = []
        self._servers = []

        user_params: Dict[str, Any] = test.user_params  # type: ignore
        controller_configs: Dict[str, Any] = test.controller_configs.copy()  # type: ignore
        sorted_controllers = sorted(
            controller_configs.keys(), key=lambda controller: user_params.get(f'order_{controller}', 100)
        )

        for controller in sorted_controllers:
            # Find the corresponding PandoraServer class for the controller.
            if f'{KEY_PANDORA_SERVER_CLASS}_{controller}' in user_params:
                # Try to load the server dynamically if module specified in user_params.
                class_path = user_params[f'{KEY_PANDORA_SERVER_CLASS}_{controller}']
                logging.info('Loading Pandora server class %s from config for %s.', class_path, controller)
                server_cls = _load_pandora_server_class(class_path)
            else:
                # Search in the list of commonly-used controllers.
                try:
                    server_cls = PANDORA_COMMON_SERVER_CLASSES[controller]
                except KeyError as e:
                    raise RuntimeError(
                        f'PandoraServer module for {controller} not found in either the '
                        'config or PANDORA_COMMON_SERVER_CLASSES.'
                    ) from e

            # Register the controller and load its Pandora servers.
            logging.info('Starting %s(s) for %s', server_cls.__name__, controller)
            devices: Optional[List[Any]] = test.register_controller(server_cls.MOBLY_CONTROLLER_MODULE)  # type: ignore
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not devices:
                raise RuntimeError(f'No devices were registered for controller {controller}.')
            for device in devices:  # type: ignore
                self._servers.append(server_cls(device))

        self.start_all()

    def __len__(self) -> int:
        """Number of started clients (0 after `stop_all`)."""
        return len(self._clients)

    def __iter__(self) -> Iterator[PandoraDevice]:
        """Iterate over the started clients."""
        return iter(self._clients)

    def start_all(self) -> None:
        """Start all Pandora servers and store a client for each.

        Does nothing if clients already exist, i.e. the servers were started
        and have not been stopped since.
        """
        if self._clients:
            return
        for server in self._servers:
            self._clients.append(server.start())

    def stop_all(self) -> None:
        """Close all opened Pandora clients, then stop all Pandora servers.

        Every close/stop is attempted even if an earlier one raises, so one
        failing client or server does not leave the others running; the client
        list is always cleared. If any step raises, the last exception
        propagates with the earlier ones chained as its context.
        """
        if not self._clients:
            return
        # ExitStack runs callbacks LIFO and keeps unwinding when one raises, so
        # register them in reverse of the desired order:
        # close clients -> stop servers -> forget clients.
        with contextlib.ExitStack() as stack:
            stack.callback(self._clients.clear)
            for server in reversed(self._servers):
                stack.callback(server.stop)
            for client in reversed(self._clients):
                stack.callback(client.close)
def _load_pandora_server_class(class_path: str) -> Type[pandora_server.PandoraServer[Any]]:
    """Dynamically load a PandoraServer subclass from a user-specified module+class.

    Args:
      class_path: String in format '<module>.<class>', where the module is fully
        importable using importlib.import_module. e.g.:
        my.pandora.server.module.MyPandoraServer

    Returns:
      The loaded PandoraServer subclass itself (not an instance of it).

    Raises:
      ValueError: If `class_path` is not of the form '<module>.<class>'.
      ImportError: If the module cannot be imported.
      AttributeError: If the module has no attribute with the class name.
      TypeError: If the attribute is not a subclass of PandoraServer.
    """
    # Split on the last '.' only: everything before it is the module path.
    module_name, sep, class_name = class_path.rpartition('.')
    if not sep or not module_name or not class_name:
        raise ValueError(f'Invalid Pandora server class path {class_path!r}: expected "<module>.<class>".')

    # Dynamically import the module, and get the class
    module = importlib.import_module(module_name)
    server_class = getattr(module, class_name)

    # Check that the attribute is a class before calling issubclass, which would
    # otherwise raise its own, less helpful TypeError for non-class objects.
    if not isinstance(server_class, type) or not issubclass(server_class, pandora_server.PandoraServer):
        raise TypeError(f'The specified class {class_path} is not a subclass of PandoraServer.')
    return server_class  # type: ignore
class Wrapper(object):
func: Callable[..., Any]
def __init__(self, func: Callable[..., Any]) -> None:
self.func = func
# Multiply the same function from `inputs` parameters
def parameterized(*inputs: Tuple[Any, ...]) -> Type[Wrapper]:
class wrapper(Wrapper):
def __set_name__(self, owner: str, name: str) -> None:
for input in inputs:
def decorate(input: Tuple[Any, ...]) -> Callable[..., Any]:
@functools.wraps(self.func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return self.func(*args, *input, **kwargs)
return wrapper
# we need to pass `input` here, otherwise it will be set to the value
# from the last iteration of `inputs`
setattr(owner, f"{name}{input}".replace(' ', ''), decorate(input))
delattr(owner, name)
return wrapper
_T = TypeVar('_T')


# Decorate a test function with a wrapper that catches gRPC errors
# and calls a callback if the status `code` matches.
def rpc_except(
    excepts: Dict[grpc.StatusCode, Callable[[grpc.aio.AioRpcError], Any]],
) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
    """Build a decorator that dispatches gRPC errors to per-status handlers.

    If the decorated function raises a `grpc.RpcError` whose status code is a
    key of `excepts`, the matching handler is called with the error. The
    handler's return value is returned instead of the function's result. Any
    other error, including one whose status code has no handler, is re-raised
    with its original traceback.

    Coroutine functions get a coroutine wrapper, so errors raised while the
    coroutine is awaited are handled as well.

    Args:
      excepts: Mapping from gRPC status code to the handler for that code.

    Returns:
      A decorator applying the behavior above.
    """

    def find_handler(error: grpc.RpcError) -> Any:
        # `grpc.RpcError` itself defines no `code()`; only errors that also
        # implement `grpc.Call` have one. Without it there is nothing to match.
        code = getattr(error, 'code', None)
        return excepts.get(code()) if callable(code) else None

    def wrap(func: Callable[..., _T]) -> Callable[..., _T]:
        # `grpc.aio.AioRpcError` subclasses `grpc.RpcError`, so catching the
        # base class covers both the sync and asyncio APIs.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                try:
                    return await func(*args, **kwargs)
                except grpc.RpcError as e:
                    handler = find_handler(e)
                    if handler is None:
                        raise
                    return handler(e)

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> _T:
            try:
                return func(*args, **kwargs)
            except grpc.RpcError as e:
                handler = find_handler(e)
                if handler is None:
                    raise
                return handler(e)  # type: ignore

        return wrapper

    return wrap