Snap for 9710098 from ecbe4f99c6651fa5ae36bb38e8fc00e02b70d1c2 to mainline-tzdata5-release
Change-Id: I6de3aaf5f76d838e1536ee6ec7cb3d6305989ce1
diff --git a/Android.bp b/Android.bp
index 4667c9a..067d089 100644
--- a/Android.bp
+++ b/Android.bp
@@ -21,11 +21,12 @@
"avatar/*.py",
"avatar/bumble_server/*.py",
"avatar/controllers/*.py",
+ "avatar/servers/*.py"
],
libs: [
- "mobly",
"pandora-python",
"libprotobuf-python",
"bumble",
+ "mobly",
],
}
diff --git a/README.md b/README.md
index 01262a0..211a20d 100644
--- a/README.md
+++ b/README.md
@@ -39,3 +39,15 @@
```
python examples/example.py -c examples/simulated_bumble_bumble.yml --verbose
```
+
+3. Lint with `pyright` and `mypy`
+```
+pyright
+mypy
+```
+
+4. Format code and sort imports
+```
+black avatar/ examples/
+isort avatar/ examples/
+```
diff --git a/avatar/__init__.py b/avatar/__init__.py
index fd86b7a..a2e2c12 100644
--- a/avatar/__init__.py
+++ b/avatar/__init__.py
@@ -19,56 +19,156 @@
__version__ = "0.0.1"
-
-import asyncio
import functools
+import importlib
+import logging
-from threading import Thread
+from avatar import pandora_server
+from avatar.aio import asynchronous
+from avatar.pandora_client import BumblePandoraClient as BumbleDevice, PandoraClient as PandoraDevice
+from avatar.pandora_server import PandoraServer
+from mobly import base_test
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Tuple, Type
+
+# public symbols
+__all__ = [
+ 'asynchronous',
+ 'parameterized',
+ 'PandoraDevices',
+ 'PandoraDevice',
+ 'BumbleDevice',
+]
-# Keep running an event loop is a separate thread,
-# which is then used to:
-# * Schedule Bumble(s) IO & gRPC server.
-# * Schedule asynchronous tests.
-loop = asyncio.new_event_loop()
+PANDORA_COMMON_SERVER_CLASSES: Dict[str, Type[pandora_server.PandoraServer[Any]]] = {
+ 'PandoraDevice': pandora_server.PandoraServer,
+ 'AndroidDevice': pandora_server.AndroidPandoraServer,
+ 'BumbleDevice': pandora_server.BumblePandoraServer,
+}
-def thread_loop():
- loop.run_forever()
- loop.run_until_complete(loop.shutdown_asyncgens())
-
-thread = Thread(target=thread_loop, daemon=True)
-thread.start()
+KEY_PANDORA_SERVER_CLASS = 'pandora_server_class'
-# run coroutine into our loop until complete
-def run_until_complete(coro):
- return asyncio.run_coroutine_threadsafe(coro, loop).result()
+class PandoraDevices(Sized, Iterable[PandoraDevice]):
+ """Utility for abstracting controller registration and Pandora setup."""
+
+ _test: base_test.BaseTestClass
+ _clients: List[PandoraDevice]
+ _servers: List[PandoraServer[Any]]
+
+ def __init__(self, test: base_test.BaseTestClass) -> None:
+ """Creates a PandoraDevices list.
+
+ It performs three steps:
+ - Register the underlying controllers to the test.
+ - Start the corresponding PandoraServer for each controller.
+ - Store a PandoraClient for each server.
+
+ The order in which the clients are returned can be determined by the
+ (optional) `order_<controller_class>` params in user_params. Controllers
+ without such a param will be set up last (order=100).
+
+ Args:
+ test: Instance of the Mobly test class.
+ """
+ self._test = test
+ self._clients = []
+ self._servers = []
+
+ user_params: Dict[str, Any] = test.user_params # type: ignore
+ controller_configs: Dict[str, Any] = test.controller_configs.copy() # type: ignore
+ sorted_controllers = sorted(
+ controller_configs.keys(), key=lambda controller: user_params.get(f'order_{controller}', 100)
+ )
+ for controller in sorted_controllers:
+ # Find the corresponding PandoraServer class for the controller.
+ if f'{KEY_PANDORA_SERVER_CLASS}_{controller}' in user_params:
+ # Try to load the server dynamically if module specified in user_params.
+ class_path = user_params[f'{KEY_PANDORA_SERVER_CLASS}_{controller}']
+ logging.info('Loading Pandora server class %s from config for %s.', class_path, controller)
+ server_cls = _load_pandora_server_class(class_path)
+ else:
+ # Search in the list of commonly-used controllers.
+ try:
+ server_cls = PANDORA_COMMON_SERVER_CLASSES[controller]
+ except KeyError as e:
+ raise RuntimeError(
+ f'PandoraServer module for {controller} not found in either the '
+ 'config or PANDORA_COMMON_SERVER_CLASSES.'
+ ) from e
+
+ # Register the controller and load its Pandora servers.
+ logging.info('Starting %s(s) for %s', server_cls.__name__, controller)
+ devices: Optional[List[Any]] = test.register_controller(server_cls.MOBLY_CONTROLLER_MODULE) # type: ignore
+ assert devices
+ for device in devices: # type: ignore
+ self._servers.append(server_cls(device))
+
+ self.start_all()
+
+ def __len__(self) -> int:
+ return len(self._clients)
+
+ def __iter__(self) -> Iterator[PandoraDevice]:
+ return iter(self._clients)
+
+ def start_all(self) -> None:
+ """Start all Pandora servers and store their clients (no-op if already started)."""
+ if len(self._clients):
+ return
+ for server in self._servers:
+ self._clients.append(server.start())
+
+ def stop_all(self) -> None:
+ """Closes all opened Pandora clients and servers."""
+ if not len(self._clients):
+ return
+ for client in self:
+ client.close()
+ for server in self._servers:
+ server.stop()
+ self._clients.clear()
-# Convert an asynchronous function to a synchronous one by
-# executing it's code within our loop
-def asynchronous(func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- return run_until_complete(func(*args, **kwargs))
- return wrapper
+def _load_pandora_server_class(class_path: str) -> Type[pandora_server.PandoraServer[Any]]:
+ """Dynamically load a PandoraServer from a user-specified module+class.
+
+ Args:
+ class_path: String in format '<module>.<class>', where the module is fully
+ importable using importlib.import_module. e.g.:
+ my.pandora.server.module.MyPandoraServer
+
+ Returns:
+ The loaded PandoraServer subclass (the class itself, not an instance).
+ """
+ # Dynamically import the module, and get the class
+ module_name, class_name = class_path.rsplit('.', 1)
+ module = importlib.import_module(module_name)
+ server_class = getattr(module, class_name)
+ # Check that the class is a subclass of PandoraServer
+ if not issubclass(server_class, pandora_server.PandoraServer):
+ raise TypeError(f'The specified class {class_path} is not a subclass of PandoraServer.')
+ return server_class # type: ignore
+
+
+class Wrapper(object):
+ func: Callable[..., Any]
+
+ def __init__(self, func: Callable[..., Any]) -> None:
+ self.func = func
# Multiply the same function from `inputs` parameters
-def parameterized(inputs):
- class wrapper(object):
- def __init__(self, func):
- self.func = func
-
- def __set_name__(self, owner, name):
+def parameterized(*inputs: Tuple[Any, ...]) -> Type[Wrapper]:
+ class wrapper(Wrapper):
+ def __set_name__(self, owner: str, name: str) -> None:
for input in inputs:
- if type(input) != tuple:
- raise ValueError(f'input type {type(input)} shall be a tuple')
- def decorate(input):
+ def decorate(input: Tuple[Any, ...]) -> Callable[..., Any]:
@functools.wraps(self.func)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
return self.func(*args, *input, **kwargs)
+
return wrapper
# we need to pass `input` here, otherwise it will be set to the value
diff --git a/avatar/aio.py b/avatar/aio.py
new file mode 100644
index 0000000..5b61713
--- /dev/null
+++ b/avatar/aio.py
@@ -0,0 +1,66 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import functools
+import threading
+
+from typing import Any, AsyncIterator, Awaitable, Callable, Iterable, Iterator, TypeVar
+
+_T = TypeVar('_T')
+
+
+class AsyncQueue(asyncio.Queue[_T], Iterable[_T]):
+ def __aiter__(self) -> AsyncIterator[_T]:
+ return self
+
+ def __iter__(self) -> Iterator[_T]:
+ return self
+
+ async def __anext__(self) -> _T:
+ return await self.get()
+
+ def __next__(self) -> _T:
+ return run_until_complete(self.__anext__())
+
+
+# Keep an event loop running in a separate thread,
+# which is then used to:
+# * Schedule Bumble(s) IO & gRPC server.
+# * Schedule asynchronous tests.
+loop = asyncio.new_event_loop()
+
+
+def thread_loop() -> None:
+ loop.run_forever()
+ loop.run_until_complete(loop.shutdown_asyncgens())
+
+
+thread = threading.Thread(target=thread_loop, daemon=True)
+thread.start()
+
+
+# Run a coroutine in our loop and block the caller until it completes.
+def run_until_complete(coro: Awaitable[_T]) -> _T:
+ return asyncio.run_coroutine_threadsafe(coro, loop).result()
+
+
+# Convert an asynchronous function to a synchronous one by
+# executing its code within our loop
+def asynchronous(func: Callable[..., Awaitable[_T]]) -> Callable[..., _T]:
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> _T:
+ return run_until_complete(func(*args, **kwargs))
+
+ return wrapper
diff --git a/avatar/android_service.py b/avatar/android_service.py
deleted file mode 100644
index d5fc3f8..0000000
--- a/avatar/android_service.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2022 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-import threading
-
-from mobly.controllers.android_device_lib.services.base_service \
- import BaseService
-
-ANDROID_SERVER_PACKAGE = 'com.android.pandora'
-ANDROID_SERVER_GRPC_PORT = 8999 # TODO: Use a dynamic port
-
-
-class AndroidService(BaseService):
-
- def __init__(self, device, configs=None):
- super().__init__(device, configs)
- self.port = configs['port']
- self._is_alive = False
-
- @property
- def is_alive(self):
- return self._is_alive
-
- def start(self):
- # Start Pandora Android gRPC server.
- self.instrumentation = threading.Thread(
- target=lambda: self._device.adb._exec_adb_cmd(
- 'shell',
- f'am instrument --no-hidden-api-checks -w {ANDROID_SERVER_PACKAGE}/.Main',
- shell=False,
- timeout=None,
- stderr=None))
-
- self.instrumentation.start()
-
- self._device.adb.forward(
- [f'tcp:{self.port}', f'tcp:{ANDROID_SERVER_GRPC_PORT}'])
-
- # Wait a few seconds for the Android gRPC server to be started.
- time.sleep(3)
-
- self._is_alive = True
-
- def stop(self):
- # Stop Pandora Android gRPC server.
- self._device.adb._exec_adb_cmd(
- 'shell',
- f'am force-stop {ANDROID_SERVER_PACKAGE}',
- shell=False,
- timeout=None,
- stderr=None)
-
- self._device.adb.forward(['--remove', f'tcp:{self.port}'])
-
- self.instrumentation.join()
-
- self._is_alive = False
diff --git a/avatar/bumble_device.py b/avatar/bumble_device.py
new file mode 100644
index 0000000..872e013
--- /dev/null
+++ b/avatar/bumble_device.py
@@ -0,0 +1,142 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generic & dependency free Bumble (reference) device."""
+
+from bumble import transport
+from bumble.core import BT_GENERIC_AUDIO_SERVICE, BT_HANDSFREE_SERVICE, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
+from bumble.device import Device, DeviceConfiguration
+from bumble.host import Host
+from bumble.sdp import (
+ SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+ DataElement,
+ ServiceAttribute,
+)
+from typing import Any, Dict, List, Optional
+
+
+class BumbleDevice:
+ """
+ Small wrapper around a Bumble device and its HCI transport.
+ Notes:
+ - The Bumble device is idle by default.
+ - `close` re-creates the Bumble device, so repeated `open`/`close` cycles result in new Bumble device instances.
+ """
+
+ # Bumble device instance & configuration.
+ device: Device
+ config: Dict[str, Any]
+
+ # HCI transport name & instance.
+ _hci_name: str
+ _hci: Optional[transport.Transport] # type: ignore[name-defined]
+
+ def __init__(self, config: Dict[str, Any]) -> None:
+ self.config = config
+ self.device = _make_device(config)
+ self._hci_name = config.get('transport', '')
+ self._hci = None
+
+ @property
+ def idle(self) -> bool:
+ return self._hci is None
+
+ async def open(self) -> None:
+ if self._hci is not None:
+ return
+
+ # open HCI transport & set device host.
+ self._hci = await transport.open_transport(self._hci_name)
+ self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
+
+ # power-on.
+ await self.device.power_on()
+
+ async def close(self) -> None:
+ if self._hci is None:
+ return
+
+ # flush & re-initialize device.
+ await self.device.host.flush()
+ self.device.host = None # type: ignore[assignment]
+ self.device = _make_device(self.config)
+
+ # close HCI transport.
+ await self._hci.close()
+ self._hci = None
+
+ async def reset(self) -> None:
+ await self.close()
+ await self.open()
+
+ def info(self) -> Optional[Dict[str, str]]:
+ return {
+ 'public_bd_address': str(self.device.public_address),
+ 'random_address': str(self.device.random_address),
+ }
+
+
+def _make_device(config: Dict[str, Any]) -> Device:
+ """Initialize an idle Bumble device instance."""
+
+ # initialize bumble device.
+ device_config = DeviceConfiguration()
+ device_config.load_from_dict(config)
+ device = Device(config=device_config, host=None)
+
+ # FIXME: add `classic_enabled` to `DeviceConfiguration` ?
+ device.classic_enabled = config.get('classic_enabled', False)
+ # Add a fake audio (Hands-Free) SDP record to avoid Android disconnecting.
+ device.sdp_service_records = _make_sdp_records(1)
+
+ return device
+
+
+# TODO(b/267540823): remove when Pandora A2dp is supported
+def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
+ return {
+ 0x00010001: [
+ ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)),
+ ServiceAttribute(
+ SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+ DataElement.sequence(
+ [DataElement.uuid(BT_HANDSFREE_SERVICE), DataElement.uuid(BT_GENERIC_AUDIO_SERVICE)]
+ ),
+ ),
+ ServiceAttribute(
+ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ DataElement.sequence(
+ [
+ DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
+ DataElement.sequence(
+ [DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), DataElement.unsigned_integer_8(rfcomm_channel)]
+ ),
+ ]
+ ),
+ ),
+ ServiceAttribute(
+ SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ DataElement.sequence(
+ [
+ DataElement.sequence(
+ [DataElement.uuid(BT_HANDSFREE_SERVICE), DataElement.unsigned_integer_16(0x0105)]
+ )
+ ]
+ ),
+ ),
+ ]
+ }
diff --git a/avatar/bumble_server/__init__.py b/avatar/bumble_server/__init__.py
index 3b71019..839d2e7 100644
--- a/avatar/bumble_server/__init__.py
+++ b/avatar/bumble_server/__init__.py
@@ -12,154 +12,131 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Pandora Bumble Server."""
+"""Bumble Pandora server."""
__version__ = "0.0.1"
import asyncio
import grpc
+import grpc.aio
import logging
import os
-import random
import sys
import traceback
+from avatar.bumble_device import BumbleDevice
+from avatar.bumble_server.asha import ASHAService
+from avatar.bumble_server.host import HostService
+from avatar.bumble_server.security import SecurityService, SecurityStorageService
from bumble.smp import PairingDelegate
-from bumble.host import Host
-from bumble.device import Device, DeviceConfiguration
-from bumble.transport import open_transport
-from bumble.sdp import (
- DataElement, ServiceAttribute,
- SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
- SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
- SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
- SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
-)
-from bumble.core import (
- BT_GENERIC_AUDIO_SERVICE, BT_HANDSFREE_SERVICE,
- BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
-)
+from dataclasses import dataclass
+from pandora.asha_grpc_aio import add_ASHAServicer_to_server
+from pandora.host_grpc_aio import add_HostServicer_to_server
+from pandora.security_grpc_aio import add_SecurityServicer_to_server, add_SecurityStorageServicer_to_server
+from typing import Callable, Coroutine, List, Optional
-from .host import HostService
-from pandora.host_grpc import add_HostServicer_to_server
+# Hooks called with each `Server` after its default Pandora servicers are added.
+_SERVICERS_HOOKS: List[Callable[['Server'], None]] = []
-from .security import SecurityService, SecurityStorageService
-from pandora.security_grpc import add_SecurityServicer_to_server, add_SecurityStorageServicer_to_server
-from pandora.asha_grpc import add_ASHAServicer_to_server
-from .asha import ASHAService
+@dataclass
+class Configuration:
+ io_capability: int
+
+
+@dataclass
+class Server:
+ port: int
+ bumble: BumbleDevice
+ server: grpc.aio.Server
+ config: Configuration
+
+ async def start(self) -> None:
+ device = self.bumble.device
+
+ # add Pandora services to the gRPC server.
+ add_HostServicer_to_server(HostService(self.server, device), self.server)
+ add_SecurityServicer_to_server(SecurityService(device, self.config.io_capability), self.server)
+ add_SecurityStorageServicer_to_server(SecurityStorageService(device), self.server)
+ add_ASHAServicer_to_server(ASHAService(device), self.server)
+
+ # call hooks if any.
+ for hook in _SERVICERS_HOOKS:
+ hook(self)
+
+ try:
+ # open device.
+ await self.bumble.open()
+ except:
+ print(traceback.format_exc(), end='', file=sys.stderr)
+ os._exit(1) # type: ignore
+
+ # Pandora requires classic devices to be connectable (but not discoverable) at startup.
+ if device.classic_enabled:
+ await device.set_discoverable(False)
+ await device.set_connectable(True)
+
+ # start the gRPC server.
+ await self.server.start()
+
+ async def serve(self) -> None:
+ try:
+ while True:
+ try:
+ # serve gRPC server.
+ await self.server.wait_for_termination()
+ except KeyboardInterrupt:
+ return
+ finally:
+ # close device.
+ await self.bumble.close()
+
+ # re-initialize the gRPC server & re-start.
+ self.server = grpc.aio.server()
+ self.port = self.server.add_insecure_port(f'localhost:{self.port}')
+ await self.start()
+ except KeyboardInterrupt:
+ return
+ finally:
+ # stop server.
+ await self.server.stop(None)
+
+
+def register_servicer_hook(hook: Callable[['Server'], None]) -> None:
+ _SERVICERS_HOOKS.append(hook)
+
+
+async def create_serve_task(
+ bumble: BumbleDevice,
+ grpc_server: Optional[grpc.aio.Server] = None,
+ port: int = 0,
+) -> Coroutine[None, None, None]:
+ # initialize a gRPC server if not provided.
+ server = grpc_server if grpc_server is not None else grpc.aio.server()
+ port = server.add_insecure_port(f'localhost:{port}')
+
+ # load IO capability from config.
+ io_capability_name: str = bumble.config.get('io_capability', 'no_output_no_input').upper()
+ io_capability: int = getattr(PairingDelegate, io_capability_name)
+
+ # create server.
+ bumble_server = Server(port, bumble, server, Configuration(io_capability))
+
+ # start bumble server & return serve task.
+ await bumble_server.start()
+ return bumble_server.serve()
+
BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300
-def make_sdp_records(rfcomm_channel):
- return {
- 0x00010001: [
- ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
- DataElement.unsigned_integer_32(0x00010001)),
- ServiceAttribute(
- SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
- DataElement.sequence([
- DataElement.uuid(BT_HANDSFREE_SERVICE),
- DataElement.uuid(BT_GENERIC_AUDIO_SERVICE)
- ])),
- ServiceAttribute(
- SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
- DataElement.sequence([
- DataElement.sequence(
- [DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
- DataElement.sequence([
- DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
- DataElement.unsigned_integer_8(rfcomm_channel)
- ])
- ])),
- ServiceAttribute(
- SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
- DataElement.sequence([
- DataElement.sequence([
- DataElement.uuid(BT_HANDSFREE_SERVICE),
- DataElement.unsigned_integer_16(0x0105)
- ])
- ]))
- ]
- }
-class BumblePandoraServer:
-
- def __init__(self, transport, config):
- self.transport = transport
- self.config = config
-
- async def start(self, grpc_server: grpc.aio.Server):
- self.hci = await open_transport(self.transport)
-
- # generate a random address
- random_address = f"{random.randint(192,255):02X}" # address is static random
- for c in random.sample(range(255), 5): random_address += f":{c:02X}"
-
- # initialize bumble device
- device_config = DeviceConfiguration()
- device_config.load_from_dict(self.config)
- host = Host(controller_source=self.hci.source, controller_sink=self.hci.sink)
- self.device = Device(config=device_config, host=host, address=random_address)
-
- # FIXME: add `classic_enabled` to `DeviceConfiguration` ?
- self.device.classic_enabled = self.config.get('classic_enabled', False)
- # Add fake a2dp service to avoid Android disconnect (TODO: remove when a2dp is supported)
- self.device.sdp_service_records = make_sdp_records(1)
- io_capability_name = self.config.get('io_capability', 'no_output_no_input').upper()
- io_capability = getattr(PairingDelegate, io_capability_name)
-
- # start bumble device
- await self.device.power_on()
-
- # add our services to the gRPC server
- add_HostServicer_to_server(await HostService(grpc_server, self.device).start(), grpc_server)
- add_SecurityServicer_to_server(SecurityService(self.device, io_capability), grpc_server)
- add_SecurityStorageServicer_to_server(SecurityStorageService(self.device), grpc_server)
- add_ASHAServicer_to_server(ASHAService(self.device), grpc_server)
-
- async def close(self):
- await self.device.host.flush()
- await self.hci.close()
-
- @classmethod
- async def serve(cls, transport, config, grpc_server, grpc_port, on_started=None):
- try:
- while True:
- try:
- server = cls(transport, config)
- await server.start(grpc_server)
- except:
- print(traceback.format_exc(), end='', file=sys.stderr)
- os._exit(1)
-
- if on_started:
- on_started(server)
-
- await grpc_server.start()
- await grpc_server.wait_for_termination()
- await server.close()
-
- # re-initialize gRPC server
- grpc_server = grpc.aio.server()
- grpc_server.add_insecure_port(f'localhost:{grpc_port}')
-
- finally:
- await server.close()
- await grpc_server.stop(None)
-
-
-async def serve():
- grpc_server = grpc.aio.Server()
- grpc_port = grpc_server.add_insecure_port(f'localhost:{BUMBLE_SERVER_GRPC_PORT}')
-
- transport = f'tcp-client:127.0.0.1:{ROOTCANAL_PORT_CUTTLEFISH}'
- config = {'classic_enabled': True}
-
- await BumblePandoraServer.serve(transport, config, grpc_server, grpc_port)
+async def amain() -> None:
+ bumble = BumbleDevice({'transport': f'tcp-client:127.0.0.1:{ROOTCANAL_PORT_CUTTLEFISH}', 'classic_enabled': True})
+ serve = await create_serve_task(bumble, port=BUMBLE_SERVER_GRPC_PORT)
+ await serve
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
- asyncio.run(serve())
+ asyncio.run(amain())
diff --git a/avatar/bumble_server/asha.py b/avatar/bumble_server/asha.py
index 6114701..ce27582 100644
--- a/avatar/bumble_server/asha.py
+++ b/avatar/bumble_server/asha.py
@@ -12,25 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import grpc
import logging
+from avatar.bumble_server.utils import BumbleServerLoggerAdapter
from bumble.device import Device
from bumble.profiles.asha_service import AshaService
-
-from pandora.asha_grpc import ASHAServicer
-
from google.protobuf.empty_pb2 import Empty
+from pandora.asha_grpc_aio import ASHAServicer
+from pandora.asha_pb2 import RegisterRequest
+from typing import Optional
class ASHAService(ASHAServicer):
- def __init__(self, device: Device):
+ device: Device
+ asha_service: Optional[AshaService]
+
+ def __init__(self, device: Device) -> None:
+ self.log = BumbleServerLoggerAdapter(logging.getLogger(), {'service_name': 'Asha', 'device': device})
self.device = device
+ self.asha_service = None
- super().__init__()
-
- async def Register(self, request, context):
- logging.info('Register')
+ async def Register(self, request: RegisterRequest, context: grpc.ServicerContext) -> Empty:
+ self.log.info('Register')
# asha service from bumble profile
- self.asha_service = AshaService(request.capability, request.hisyncid)
- self.device.add_service(self.asha_service)
+ self.asha_service = AshaService(request.capability, request.hisyncid, self.device)
+ self.device.add_service(self.asha_service) # type: ignore[no-untyped-call]
return Empty()
diff --git a/avatar/bumble_server/host.py b/avatar/bumble_server/host.py
index c1d743f..94ba3b9 100644
--- a/avatar/bumble_server/host.py
+++ b/avatar/bumble_server/host.py
@@ -13,64 +13,97 @@
# limitations under the License.
import asyncio
-import logging
+import bumble.device
import grpc
+import grpc.aio
+import logging
import struct
-from avatar.bumble_server.utils import address_from_request
-
+from avatar.bumble_server.utils import BumbleServerLoggerAdapter, address_from_request
from bumble.core import (
- BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT,
- AdvertisingData, ConnectionError
+ BT_BR_EDR_TRANSPORT,
+ BT_LE_TRANSPORT,
+ BT_PERIPHERAL_ROLE,
+ UUID,
+ AdvertisingData,
+ ConnectionError,
)
from bumble.device import (
- DEVICE_DEFAULT_SCAN_INTERVAL, DEVICE_DEFAULT_SCAN_WINDOW,
- AdvertisingType, Device
+ DEVICE_DEFAULT_SCAN_INTERVAL,
+ DEVICE_DEFAULT_SCAN_WINDOW,
+ Advertisement,
+ AdvertisingType,
+ Device,
)
+from bumble.gatt import Service
from bumble.hci import (
- HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_PAGE_TIMEOUT_ERROR,
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
- Address, HCI_Error
+ HCI_PAGE_TIMEOUT_ERROR,
+ HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
+ Address,
)
-from bumble.gatt import (
- Service
-)
-
-from google.protobuf.empty_pb2 import Empty
-from google.protobuf.any_pb2 import Any
-
-from pandora.host_grpc import HostServicer
+from google.protobuf import any_pb2, empty_pb2
+from pandora.host_grpc_aio import HostServicer
from pandora.host_pb2 import (
- DiscoverabilityMode, ConnectabilityMode,
- Connection, DataTypes
-)
-from pandora.host_pb2 import (
+ NOT_CONNECTABLE,
+ NOT_DISCOVERABLE,
+ PRIMARY_1M,
+ PRIMARY_CODED,
+ SECONDARY_1M,
+ SECONDARY_2M,
+ SECONDARY_CODED,
+ SECONDARY_NONE,
+ AdvertiseRequest,
+ AdvertiseResponse,
+ Connection,
+ ConnectLERequest,
+ ConnectLEResponse,
+ ConnectRequest,
+ ConnectResponse,
+ DataTypes,
+ DisconnectRequest,
+ InquiryResponse,
+ PrimaryPhy,
ReadLocalAddressResponse,
- ConnectResponse, GetConnectionResponse, WaitConnectionResponse,
- ConnectLEResponse, GetLEConnectionResponse, WaitLEConnectionResponse,
- StartAdvertisingResponse, ScanningResponse, InquiryResponse,
- GetRemoteNameResponse
+ ScanningResponse,
+ ScanRequest,
+ SecondaryPhy,
+ SetConnectabilityModeRequest,
+ SetDiscoverabilityModeRequest,
+ WaitConnectionRequest,
+ WaitConnectionResponse,
+ WaitDisconnectionRequest,
)
+from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
+
+PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {1: PRIMARY_1M, 3: PRIMARY_CODED}
+
+SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
+ 0: SECONDARY_NONE,
+ 1: SECONDARY_1M,
+ 2: SECONDARY_2M,
+ 3: SECONDARY_CODED,
+}
class HostService(HostServicer):
+ grpc_server: grpc.aio.Server
+ device: Device
+ scan_queue: asyncio.Queue[Advertisement]
+ inquiry_queue: asyncio.Queue[Optional[Tuple[Address, int, AdvertisingData, int]]]
+ waited_connections: Set[int]
- def __init__(self, grpc_server: grpc.aio.Server, device: Device):
+ def __init__(self, grpc_server: grpc.aio.Server, device: Device) -> None:
super().__init__()
+ self.log = BumbleServerLoggerAdapter(logging.getLogger(), {'service_name': 'Host', 'device': device})
self.grpc_server = grpc_server
self.device = device
self.scan_queue = asyncio.Queue()
self.inquiry_queue = asyncio.Queue()
+ self.waited_connections = set()
- async def start(self) -> "HostService":
- # According to `host.proto`:
- # At startup, the Host must be in BR/EDR connectable mode
- await self.device.set_discoverable(False)
- await self.device.set_connectable(True)
- return self
-
- async def FactoryReset(self, request, context):
- logging.info('FactoryReset')
+ async def FactoryReset(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> empty_pb2.Empty:
+ self.log.info('FactoryReset')
# delete all bonds
if self.device.keystore is not None:
@@ -78,167 +111,137 @@
# trigger gRCP server stop then return
asyncio.create_task(self.grpc_server.stop(None))
- return Empty()
+ return empty_pb2.Empty()
-    async def Reset(self, request, context):
-        logging.info('Reset')
+    async def Reset(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> empty_pb2.Empty:
+        self.log.info('Reset')
+        # NOTE(review): a Scan already in progress stays subscribed to the previous scan_queue object.
+        # clear per-service state: waited connections and the scan/inquiry queues.
+        self.waited_connections.clear()
+        self.scan_queue = asyncio.Queue()
+        self.inquiry_queue = asyncio.Queue()
        # (re) power device on
        await self.device.power_on()
-        return Empty()
+        return empty_pb2.Empty()
-    async def ReadLocalAddress(self, request, context):
-        logging.info('ReadLocalAddress')
-        return ReadLocalAddressResponse(
-            address=bytes(reversed(bytes(self.device.public_address))))
+    async def ReadLocalAddress(
+        self, request: empty_pb2.Empty, context: grpc.ServicerContext
+    ) -> ReadLocalAddressResponse:
+        self.log.info('ReadLocalAddress')  # reply with the public address, byte order reversed from Bumble's
+        return ReadLocalAddressResponse(address=bytes(reversed(bytes(self.device.public_address))))
-    async def Connect(self, request, context):
+    async def Connect(self, request: ConnectRequest, context: grpc.ServicerContext) -> ConnectResponse:
        # Need to reverse bytes order since Bumble Address is using MSB.
        address = Address(bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS)
-        logging.info(f"Connect: {address}")
+        self.log.info(f"Connect to {address}")
        try:
-            logging.info("Connecting...")
            connection = await self.device.connect(address, transport=BT_BR_EDR_TRANSPORT)
-            logging.info("Connected")
        except ConnectionError as e:
            if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
-                logging.warning(f"Peer not found: {e}")
-                return ConnectResponse(peer_not_found=Empty())
+                self.log.warning(f"Peer not found: {e}")
+                return ConnectResponse(peer_not_found=empty_pb2.Empty())
            if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
-                logging.warning(f"Connection already exists: {e}")
-                return ConnectResponse(connection_already_exists=Empty())
+                self.log.warning(f"Connection already exists: {e}")
+                return ConnectResponse(connection_already_exists=empty_pb2.Empty())
            raise e
-        logging.info(f"Connect: connection handle: {connection.handle}")
-        cookie = Any(value=connection.handle.to_bytes(4, 'big'))
+        self.log.info(f"Connect to {address} done (handle={connection.handle})")
+        # the connection handle is handed back to the client as a 4-byte big-endian cookie.
+        cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
        return ConnectResponse(connection=Connection(cookie=cookie))
- async def GetConnection(self, request, context):
- # Need to reverse bytes order since Bumble Address is using MSB.
- address = Address(bytes(reversed(request.address)))
- logging.info(f"GetConnection: {address}")
+    async def WaitConnection(
+        self, request: WaitConnectionRequest, context: grpc.ServicerContext
+    ) -> WaitConnectionResponse:
+        if not request.address:
+            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)  # type: ignore
+            raise ValueError('Request address field must be set')
-        connection = self.device.find_connection_by_bd_addr(
-            address, transport=BT_BR_EDR_TRANSPORT)
+        # Need to reverse bytes order since Bumble Address is using MSB.
+        address = Address(bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS)
+        if address in (Address.NIL, Address.ANY):
+            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)  # type: ignore
+            raise ValueError('Invalid address')
+
+        self.log.info(f"WaitConnection from {address}...")
+
+        connection = self.device.find_connection_by_bd_addr(address, transport=BT_BR_EDR_TRANSPORT)
+        if connection and id(connection) in self.waited_connections:
+            # this connection was already returned: wait for a new one.
+            connection = None
        if not connection:
-            return GetConnectionResponse(peer_not_found=Empty())
+            connection = await self.device.accept(address)
-        cookie = Any(value=connection.handle.to_bytes(4, 'big'))
-        return GetConnectionResponse(connection=Connection(cookie=cookie))
+        # remember this connection as waited so later calls wait for a new one, then respond.
+        self.waited_connections.add(id(connection))  # NOTE(review): id()s may be reused once objects are freed
-    async def WaitConnection(self, request, context):
-        # Need to reverse bytes order since Bumble Address is using MSB.
-        if request.address:
-            address = Address(bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS)
-            logging.info(f"WaitConnection: {address}")
+        self.log.info(f"WaitConnection from {address} done (handle={connection.handle})")
-            connection = self.device.find_connection_by_bd_addr(
-                address, transport=BT_BR_EDR_TRANSPORT)
-
-            if connection:
-                cookie = Any(value=connection.handle.to_bytes(4, 'big'))
-                return WaitConnectionResponse(connection=Connection(cookie=cookie))
-        else:
-            address = Address.ANY
-            logging.info(f"WaitConnection: {address}")
-
-        logging.info("Wait connection...")
-        connection = await self.device.accept(address)
-        logging.info("Connected")
-
-        logging.info(f"WaitConnection: connection handle: {connection.handle}")
-        cookie = Any(value=connection.handle.to_bytes(4, 'big'))
+        cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
        return WaitConnectionResponse(connection=Connection(cookie=cookie))
-    async def ConnectLE(self, request, context):
+    async def ConnectLE(self, request: ConnectLERequest, context: grpc.ServicerContext) -> ConnectLEResponse:
        address = address_from_request(request, request.WhichOneof("address"))
-        logging.info(f"ConnectLE: {address}")
+        if address in (Address.NIL, Address.ANY):  # wildcard addresses cannot be connected to
+            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)  # type: ignore
+            raise ValueError('Invalid address')
+
+        self.log.info(f"ConnectLE to {address}...")
        try:
-            logging.info("Connecting...")
-            connection = await self.device.connect(address,
-                transport=BT_LE_TRANSPORT, own_address_type=request.own_address_type)
-            logging.info("Connected")
+            connection = await self.device.connect(
+                address, transport=BT_LE_TRANSPORT, own_address_type=request.own_address_type
+            )
        except ConnectionError as e:
            if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
-                logging.warning(f"Peer not found: {e}")
-                return ConnectLEResponse(peer_not_found=Empty())
+                self.log.warning(f"Peer not found: {e}")
+                return ConnectLEResponse(peer_not_found=empty_pb2.Empty())
            if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
-                logging.warning(f"Connection already exists: {e}")
-                return ConnectLEResponse(connection_already_exists=Empty())
+                self.log.warning(f"Connection already exists: {e}")
+                return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
+            context.set_code(grpc.StatusCode.ABORTED)  # type: ignore
            raise e
-        logging.info(f"ConnectLE: connection handle: {connection.handle}")
-        cookie = Any(value=connection.handle.to_bytes(4, 'big'))
+        self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
+        # the connection handle is handed back to the client as a 4-byte big-endian cookie.
+        cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
        return ConnectLEResponse(connection=Connection(cookie=cookie))
- async def GetLEConnection(self, request, context):
- address = address_from_request(request, request.WhichOneof("address"))
- logging.info(f"GetLEConnection: {address}")
-
- connection = self.device.find_connection_by_bd_addr(
- address, transport=BT_LE_TRANSPORT, check_address_type=True)
-
- if not connection:
- return GetLEConnectionResponse(peer_not_found=Empty())
-
- cookie = Any(value=connection.handle.to_bytes(4, 'big'))
- return GetLEConnectionResponse(connection=Connection(cookie=cookie))
-
- async def WaitLEConnection(self, request, context):
- address = address_from_request(request, request.WhichOneof("address"))
- logging.info(f"WaitLEConnection: {address}")
-
- connection = self.device.find_connection_by_bd_addr(
- address, transport=BT_LE_TRANSPORT, check_address_type=True)
-
- if connection:
- cookie = Any(value=connection.handle.to_bytes(4, 'big'))
- return WaitLEConnectionResponse(connection=Connection(cookie=cookie))
-
- pending_connection = asyncio.get_running_loop().create_future()
- handler = self.device.on('connection', lambda connection:
- pending_connection.set_result(connection)
- if connection.transport == BT_LE_TRANSPORT and connection.peer_address == address else None)
- failure_handler = self.device.on('connection_failure', lambda error:
- pending_connection.set_exception(error)
- if error.transport == BT_LE_TRANSPORT and error.peer_address == address else None)
-
- try:
- connection = await pending_connection
- cookie = Any(value=connection.handle.to_bytes(4, 'big'))
- return WaitLEConnectionResponse(connection=Connection(cookie=cookie))
- finally:
- self.device.remove_listener('connection', handler)
- self.device.remove_listener('connection_failure', failure_handler)
-
-    async def Disconnect(self, request, context):
+    async def Disconnect(self, request: DisconnectRequest, context: grpc.ServicerContext) -> empty_pb2.Empty:
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        logging.info(f"Disconnect: {connection_handle}")
+        self.log.info(f"Disconnect: {connection_handle}")  # handle decoded from the big-endian cookie
-        logging.info("Disconnecting...")
-        connection = self.device.lookup_connection(connection_handle)
-        await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
-        logging.info("Disconnected")
+        self.log.info("Disconnecting...")
+        if connection := self.device.lookup_connection(connection_handle):  # unknown handle: no-op
+            await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
+            self.log.info("Disconnected")
-        return Empty()
+        return empty_pb2.Empty()
-    async def WaitDisconnection(self, request, context):
+    async def WaitDisconnection(
+        self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
+    ) -> empty_pb2.Empty:
        connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        logging.info(f"WaitDisconnection: {connection_handle}")
+        self.log.info(f"WaitDisconnection: {connection_handle}")
-        if (connection := self.device.lookup_connection(connection_handle)):
-            disconnection_future = asyncio.get_running_loop().create_future()
-            connection.on('disconnection', lambda _: disconnection_future.set_result(True))
+        if connection := self.device.lookup_connection(connection_handle):  # unknown handle: return at once
+            disconnected = asyncio.Event()  # unlike Future.set_result, Event.set is safe after RPC cancellation
+
+            def on_disconnection(_: None) -> None:
+                disconnected.set()
+
+            connection.on('disconnection', on_disconnection)
-            await disconnection_future
-            logging.info("Disconnected")
+            await disconnected.wait()
+            self.log.info("Disconnected")
-        return Empty()
+        return empty_pb2.Empty()
- # TODO: use advertising set commands
- async def StartAdvertising(self, request, context):
+ async def Advertise(
+ self, request: AdvertiseRequest, context: grpc.ServicerContext
+ ) -> AsyncGenerator[AdvertiseResponse, None]:
# TODO: add support for extended advertising in Bumble
# TODO: add support for `request.interval`
# TODO: add support for `request.interval_range`
@@ -250,30 +253,43 @@
assert not request.primary_phy
assert not request.secondary_phy
- logging.info('StartAdvertising')
+ if self.device.is_advertising:
+ # TODO: add support for advertising sets.
+ context.set_code(grpc.StatusCode.ABORTED) # type: ignore
+ raise RuntimeError('Advertising sets are not yet supported, only one `Advertise` is possible at a time')
if data := request.data:
self.device.advertising_data = bytes(self.unpack_data_types(data))
- # Retrieve services data
- for service in self.device.gatt_server.attributes:
- if isinstance(service, Service) and (data := service.get_advertising_data()) and (
- service.uuid.to_hex_str() in request.data.incomplete_service_class_uuids16 or
- service.uuid.to_hex_str() in request.data.complete_service_class_uuids16 or
- service.uuid.to_hex_str() in request.data.incomplete_service_class_uuids32 or
- service.uuid.to_hex_str() in request.data.complete_service_class_uuids32 or
- service.uuid.to_hex_str() in request.data.incomplete_service_class_uuids128 or
- service.uuid.to_hex_str() in request.data.complete_service_class_uuids128
- ):
- self.device.advertising_data += data
-
if scan_response_data := request.scan_response_data:
- self.device.scan_response_data = bytes(
- self.unpack_data_types(scan_response_data))
+ self.device.scan_response_data = bytes(self.unpack_data_types(scan_response_data))
scannable = True
else:
scannable = False
+ # Retrieve services data
+ for service in self.device.gatt_server.attributes:
+ if isinstance(service, Service) and (service_data := service.get_advertising_data()):
+ service_uuid = service.uuid.to_hex_str()
+ if (
+ service_uuid in request.data.incomplete_service_class_uuids16
+ or service_uuid in request.data.complete_service_class_uuids16
+ or service_uuid in request.data.incomplete_service_class_uuids32
+ or service_uuid in request.data.complete_service_class_uuids32
+ or service_uuid in request.data.incomplete_service_class_uuids128
+ or service_uuid in request.data.complete_service_class_uuids128
+ ):
+ self.device.advertising_data += service_data
+ if (
+ service_uuid in scan_response_data.incomplete_service_class_uuids16
+ or service_uuid in scan_response_data.complete_service_class_uuids16
+ or service_uuid in scan_response_data.incomplete_service_class_uuids32
+ or service_uuid in scan_response_data.complete_service_class_uuids32
+ or service_uuid in scan_response_data.incomplete_service_class_uuids128
+ or service_uuid in scan_response_data.complete_service_class_uuids128
+ ):
+ self.device.scan_response_data += service_data
+
target = None
if request.connectable and scannable:
advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE
@@ -281,86 +297,123 @@
advertising_type = AdvertisingType.UNDIRECTED_SCANNABLE
else:
advertising_type = AdvertisingType.UNDIRECTED
+ else:
+ target = None
+ advertising_type = AdvertisingType.UNDIRECTED
- # Need to reverse bytes order since Bumble Address is using MSB.
- if request.WhichOneof("target") == "public":
- target = Address(bytes(reversed(request.public)), Address.PUBLIC_DEVICE_ADDRESS)
- advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY # FIXME: HIGH_DUTY ?
- elif request.WhichOneof("target") == "random":
- target = Address(bytes(reversed(request.random)), Address.RANDOM_DEVICE_ADDRESS)
- advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY # FIXME: HIGH_DUTY ?
+ if request.target:
+ # Need to reverse bytes order since Bumble Address is using MSB.
+ target_bytes = bytes(reversed(request.target))
+ if request.target_variant() == "public":
+ target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
+ advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY # FIXME: HIGH_DUTY ?
+ else:
+ target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
+ advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY # FIXME: HIGH_DUTY ?
- await self.device.start_advertising(
- target = target,
- advertising_type = advertising_type,
- own_address_type = request.own_address_type
- )
+        if request.connectable:
-        # FIXME: wait for advertising sets to have a correct set, use `None` for now
-        return StartAdvertisingResponse(set=None)
+            # buffer connections: one may arrive while the loop below is not awaiting it.
+            pending_connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
+
+            def on_connection(connection: bumble.device.Connection) -> None:
+                if connection.transport == BT_LE_TRANSPORT and connection.role == BT_PERIPHERAL_ROLE:
+                    pending_connections.put_nowait(connection)
-    # TODO: use advertising set commands
-    async def StopAdvertising(self, request, context):
-        logging.info('StopAdvertising')
-        await self.device.stop_advertising()
-        return Empty()
+
+            self.device.on('connection', on_connection)
-    async def Scan(self, request, context):
+        try:
+            while True:
+                if not self.device.is_advertising:
+                    self.log.info('Advertise')
+                    await self.device.start_advertising(
+                        target=target, advertising_type=advertising_type, own_address_type=request.own_address_type
+                    )
+
+                if not request.connectable:
+                    await asyncio.sleep(1)
+                    continue
+
+                self.log.info('Wait for LE connection...')
+                connection = await pending_connections.get()
+
+                self.log.info(f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})")
+
+                cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
+                yield AdvertiseResponse(connection=Connection(cookie=cookie))
+
+                # wait a small delay before restarting the advertisement.
+                await asyncio.sleep(1)
+        finally:
+            if request.connectable:
+                self.device.remove_listener('connection', on_connection)  # type: ignore
+
+            self.log.info('Stop advertising')
+            await self.device.abort_on('flush', self.device.stop_advertising())
+
+    async def Scan(
+        self, request: ScanRequest, context: grpc.ServicerContext
+    ) -> AsyncGenerator[ScanningResponse, None]:  # streams one ScanningResponse per Bumble advertisement
        # TODO: add support for `request.phys`
+        # TODO: modify `start_scanning` to accept floats instead of int for ms values
        assert not request.phys
-        logging.info('Scan')
+        self.log.info('Scan')
        handler = self.device.on('advertisement', self.scan_queue.put_nowait)
        await self.device.start_scanning(
-            legacy = request.legacy,
-            active = not request.passive,
-            own_address_type = request.own_address_type,
-            scan_interval = request.interval if request.interval else DEVICE_DEFAULT_SCAN_INTERVAL,
-            scan_window = request.window if request.window else DEVICE_DEFAULT_SCAN_WINDOW
+            legacy=request.legacy,
+            active=not request.passive,
+            own_address_type=request.own_address_type,
+            scan_interval=int(request.interval) if request.interval else DEVICE_DEFAULT_SCAN_INTERVAL,
+            scan_window=int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW,
        )
        try:
            # TODO: add support for `direct_address` in Bumble
            # TODO: add support for `periodic_advertising_interval` in Bumble
            while adv := await self.scan_queue.get():
-                kwargs = {
-                    'legacy': adv.is_legacy,
-                    'connectable': adv.is_connectable,
-                    'scannable': adv.is_scannable,
-                    'truncated': adv.is_truncated,
-                    'sid': adv.sid,
-                    'primary_phy': adv.primary_phy,
-                    'secondary_phy': adv.secondary_phy,
-                    'tx_power': adv.tx_power,
-                    'rssi': adv.rssi,
-                    'data': self.pack_data_types(adv.data)
-                }
+                sr = ScanningResponse(
+                    legacy=adv.is_legacy,
+                    connectable=adv.is_connectable,
+                    scannable=adv.is_scannable,
+                    truncated=adv.is_truncated,
+                    sid=adv.sid,
+                    primary_phy=PRIMARY_PHY_MAP[adv.primary_phy],  # KeyError on an unmapped PHY value
+                    secondary_phy=SECONDARY_PHY_MAP[adv.secondary_phy],
+                    tx_power=adv.tx_power,
+                    rssi=adv.rssi,
+                    data=self.pack_data_types(adv.data),
+                )
                if adv.address.address_type == Address.PUBLIC_DEVICE_ADDRESS:
-                    kwargs['public'] = bytes(reversed(bytes(adv.address)))
+                    sr.public = bytes(reversed(bytes(adv.address)))
                elif adv.address.address_type == Address.RANDOM_DEVICE_ADDRESS:
-                    kwargs['random'] = bytes(reversed(bytes(adv.address)))
+                    sr.random = bytes(reversed(bytes(adv.address)))
                elif adv.address.address_type == Address.PUBLIC_IDENTITY_ADDRESS:
-                    kwargs['public_identity'] = bytes(reversed(bytes(adv.address)))
-                elif adv.address.address_type == Address.RANDOM_IDENTITY_ADDRESS:
-                    kwargs['random_static_identity'] = bytes(reversed(bytes(adv.address)))
+                    sr.public_identity = bytes(reversed(bytes(adv.address)))
+                else:  # RANDOM_IDENTITY_ADDRESS and any other address type
+                    sr.random_static_identity = bytes(reversed(bytes(adv.address)))
-                yield ScanningResponse(**kwargs)
+                yield sr
        finally:
-            self.device.remove_listener('advertisement', handler)
+            self.device.remove_listener('advertisement', handler)  # type: ignore
            self.scan_queue = asyncio.Queue()
            await self.device.abort_on('flush', self.device.stop_scanning())
- async def Inquiry(self, request, context):
- logging.info('Inquiry')
+ async def Inquiry(
+ self, request: empty_pb2.Empty, context: grpc.ServicerContext
+ ) -> AsyncGenerator[InquiryResponse, None]:
+ self.log.info('Inquiry')
complete_handler = self.device.on('inquiry_complete', lambda: self.inquiry_queue.put_nowait(None))
- result_handler = self.device.on(
+ result_handler = self.device.on( # type: ignore
'inquiry_result',
- lambda address, class_of_device, eir_data, rssi:
- self.inquiry_queue.put_nowait((address, class_of_device, eir_data, rssi))
+ lambda address, class_of_device, eir_data, rssi: self.inquiry_queue.put_nowait( # type: ignore
+ (address, class_of_device, eir_data, rssi) # type: ignore
+ ),
)
await self.device.start_discovery(auto_restart=False)
@@ -372,246 +425,224 @@
address=bytes(reversed(bytes(address))),
class_of_device=class_of_device,
rssi=rssi,
- data=self.pack_data_types(eir_data)
+ data=self.pack_data_types(eir_data),
)
finally:
- self.device.remove_listener('inquiry_complete', complete_handler)
- self.device.remove_listener('inquiry_result', result_handler)
+ self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
+ self.device.remove_listener('inquiry_result', result_handler) # type: ignore
self.inquiry_queue = asyncio.Queue()
await self.device.abort_on('flush', self.device.stop_discovery())
-    async def SetDiscoverabilityMode(self, request, context):
-        logging.info("SetDiscoverabilityMode")
-        await self.device.set_discoverable(request.mode != DiscoverabilityMode.NOT_DISCOVERABLE)
-        return Empty()
+    async def SetDiscoverabilityMode(
+        self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
+    ) -> empty_pb2.Empty:
+        self.log.info("SetDiscoverabilityMode")
+        await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)  # any other mode enables it
+        return empty_pb2.Empty()
-    async def SetConnectabilityMode(self, request, context):
-        logging.info("SetConnectabilityMode")
-        await self.device.set_connectable(request.mode != ConnectabilityMode.NOT_CONNECTABLE)
-        return Empty()
+    async def SetConnectabilityMode(
+        self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
+    ) -> empty_pb2.Empty:
+        self.log.info("SetConnectabilityMode")
+        await self.device.set_connectable(request.mode != NOT_CONNECTABLE)  # any other mode enables it
+        return empty_pb2.Empty()
- async def GetRemoteName(self, request, context):
- if request.WhichOneof('remote') == 'connection':
- connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
- logging.info(f"GetRemoteName: {connection_handle}")
+    def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:  # Pandora DataTypes -> Bumble AD structures
+        ad_structures: List[Tuple[int, bytes]] = []
-            remote = self.device.lookup_connection(connection_handle)
-        else:
-            # Need to reverse bytes order since Bumble Address is using MSB.
-            remote = Address(bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS)
-            logging.info(f"GetRemoteName: {remote}")
+        uuids: List[str]  # UUIDs as hex strings, reversed below to on-air byte order
+        datas: Dict[str, bytes]
-        try:
-            remote_name = await self.device.request_remote_name(remote)
-            return GetRemoteNameResponse(name=remote_name)
-        except HCI_Error as e:
-            if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
-                logging.warning(f"Peer not found: {e}")
-                return GetRemoteNameResponse(remote_not_found=Empty())
-            raise e
-
-
-    def unpack_data_types(self, datas) -> AdvertisingData:
-        res = AdvertisingData()
-        if data := datas.incomplete_service_class_uuids16:
-            res.ad_structures.append((
-                AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if data := datas.complete_service_class_uuids16:
-            res.ad_structures.append((
-                AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if data := datas.incomplete_service_class_uuids32:
-            res.ad_structures.append((
-                AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if data := datas.complete_service_class_uuids32:
-            res.ad_structures.append((
-                AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if data := datas.incomplete_service_class_uuids128:
-            res.ad_structures.append((
-                AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if data := datas.complete_service_class_uuids128:
-            res.ad_structures.append((
-                AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
-                b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in data])
-            ))
-        if datas.HasField('include_shortened_local_name'):
-            res.ad_structures.append((
-                AdvertisingData.SHORTENED_LOCAL_NAME,
-                bytes(self.device.name[:8], 'utf-8')
-            ))
-        elif data := datas.shortened_local_name:
-            res.ad_structures.append((
-                AdvertisingData.SHORTENED_LOCAL_NAME,
-                bytes(data, 'utf-8')
-            ))
-        if datas.HasField('include_complete_local_name'):
-            res.ad_structures.append((
-                AdvertisingData.COMPLETE_LOCAL_NAME,
-                bytes(self.device.name, 'utf-8')
-            ))
-        elif data := datas.complete_local_name:
-            res.ad_structures.append((
-                AdvertisingData.COMPLETE_LOCAL_NAME,
-                bytes(data, 'utf-8')
-            ))
-        if datas.HasField('include_tx_power_level'):
+        if uuids := dt.incomplete_service_class_uuids16:
+            ad_structures.append(
+                (
+                    AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.complete_service_class_uuids16:
+            ad_structures.append(
+                (
+                    AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.incomplete_service_class_uuids32:
+            ad_structures.append(
+                (
+                    AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.complete_service_class_uuids32:
+            ad_structures.append(
+                (
+                    AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.incomplete_service_class_uuids128:
+            ad_structures.append(
+                (
+                    AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.complete_service_class_uuids128:
+            ad_structures.append(
+                (
+                    AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if dt.HasField('include_shortened_local_name'):
+            ad_structures.append((AdvertisingData.SHORTENED_LOCAL_NAME, bytes(self.device.name[:8], 'utf-8')))
+        elif dt.shortened_local_name:
+            ad_structures.append((AdvertisingData.SHORTENED_LOCAL_NAME, bytes(dt.shortened_local_name, 'utf-8')))
+        if dt.HasField('include_complete_local_name'):
+            ad_structures.append((AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.device.name, 'utf-8')))
+        elif dt.complete_local_name:
+            ad_structures.append((AdvertisingData.COMPLETE_LOCAL_NAME, bytes(dt.complete_local_name, 'utf-8')))
+        if dt.HasField('include_tx_power_level'):
            raise ValueError('unsupported data type')
-        elif data := datas.tx_power_level:
-            res.ad_structures.append((
-                AdvertisingData.TX_POWER_LEVEL,
-                bytes(struct.pack('<I', data)[:1])
-            ))
-        if datas.HasField('include_class_of_device'):
-            res.ad_structures.append((
-                AdvertisingData.CLASS_OF_DEVICE,
-                bytes(struct.pack('<I', self.device.class_of_device)[:-1])
-            ))
-        elif data := datas.class_of_device:
-            res.ad_structures.append((
-                AdvertisingData.CLASS_OF_DEVICE,
-                bytes(struct.pack('<I', data)[:-1])
-            ))
-        if data := datas.peripheral_connection_interval_min:
-            res.ad_structures.append((
-                AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE,
-                bytes([
-                    *struct.pack('<H', data),
-                    *struct.pack('<H', datas.peripheral_connection_interval_max \
-                        if datas.peripheral_connection_interval_max else data)
-                ])
-            ))
-        if data := datas.service_solicitation_uuids16:
-            res.ad_structures.append((
-                AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
-                bytes([reversed(bytes.fromhex(uuid)) for uuid in data])
-            ))
-        if data := datas.service_solicitation_uuids32:
-            res.ad_structures.append((
-                AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
-                bytes([reversed(bytes.fromhex(uuid)) for uuid in data])
-            ))
-        if data := datas.service_solicitation_uuids128:
-            res.ad_structures.append((
-                AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
-                bytes([reversed(bytes.fromhex(uuid)) for uuid in data])
-            ))
-        # TODO: use `bytes.fromhex(uuid) + (data)` instead of `.extend`.
-        # we may also need to remove all the `reverse`
-        if data := datas.service_data_uuid16:
-            res.ad_structures.extend([(
-                AdvertisingData.SERVICE_DATA_16_BIT_UUID,
-                bytes.fromhex(uuid).extend(data)
-            ) for uuid, data in data.items()])
-        if data := datas.service_data_uuid32:
-            res.ad_structures.extend([(
-                AdvertisingData.SERVICE_DATA_32_BIT_UUID,
-                bytes.fromhex(uuid).extend(data)
-            ) for uuid, data in data.items()])
-        if data := datas.service_data_uuid128:
-            res.ad_structures.extend([(
-                AdvertisingData.SERVICE_DATA_128_BIT_UUID,
-                bytes.fromhex(uuid).extend(data)
-            ) for uuid, data in data.items()])
-        if data := datas.appearance:
-            res.ad_structures.append((
-                AdvertisingData.APPEARANCE,
-                struct.pack('<H', data)
-            ))
-        if data := datas.advertising_interval:
-            res.ad_structures.append((
-                AdvertisingData.ADVERTISING_INTERVAL,
-                struct.pack('<H', data)
-            ))
-        if data := datas.uri:
-            res.ad_structures.append((
-                AdvertisingData.URI,
-                bytes(data, 'utf-8')
-            ))
-        if data := datas.le_supported_features:
-            res.ad_structures.append((
-                AdvertisingData.LE_SUPPORTED_FEATURES,
-                data
-            ))
-        if data := datas.manufacturer_specific_data:
-            res.ad_structures.append((
-                AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
-                data
-            ))
-        return res
+        elif dt.tx_power_level:
+            ad_structures.append((AdvertisingData.TX_POWER_LEVEL, bytes(struct.pack('<I', dt.tx_power_level)[:1])))
+        if dt.HasField('include_class_of_device'):
+            ad_structures.append(
+                (AdvertisingData.CLASS_OF_DEVICE, bytes(struct.pack('<I', self.device.class_of_device)[:-1]))
+            )
+        elif dt.class_of_device:
+            ad_structures.append((AdvertisingData.CLASS_OF_DEVICE, bytes(struct.pack('<I', dt.class_of_device)[:-1])))
+        if dt.peripheral_connection_interval_min:
+            ad_structures.append(
+                (
+                    AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE,
+                    bytes(
+                        [
+                            *struct.pack('<H', dt.peripheral_connection_interval_min),
+                            *struct.pack(
+                                '<H',
+                                dt.peripheral_connection_interval_max
+                                if dt.peripheral_connection_interval_max
+                                else dt.peripheral_connection_interval_min,
+                            ),
+                        ]
+                    ),
+                )
+            )
+        if uuids := dt.service_solicitation_uuids16:
+            ad_structures.append(
+                (
+                    AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.service_solicitation_uuids32:
+            ad_structures.append(
+                (
+                    AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if uuids := dt.service_solicitation_uuids128:
+            ad_structures.append(
+                (
+                    AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
+                    b''.join([bytes(reversed(bytes.fromhex(uuid))) for uuid in uuids]),
+                )
+            )
+        if datas := dt.service_data_uuid16:  # UUID reversed to on-air order, like the UUID lists above
+            ad_structures.extend(
+                [
+                    (AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(reversed(bytes.fromhex(uuid))) + data)
+                    for uuid, data in datas.items()
+                ]
+            )
+        if datas := dt.service_data_uuid32:
+            ad_structures.extend(
+                [
+                    (AdvertisingData.SERVICE_DATA_32_BIT_UUID, bytes(reversed(bytes.fromhex(uuid))) + data)
+                    for uuid, data in datas.items()
+                ]
+            )
+        if datas := dt.service_data_uuid128:
+            ad_structures.extend(
+                [
+                    (AdvertisingData.SERVICE_DATA_128_BIT_UUID, bytes(reversed(bytes.fromhex(uuid))) + data)
+                    for uuid, data in datas.items()
+                ]
+            )
+        if dt.appearance:
+            ad_structures.append((AdvertisingData.APPEARANCE, struct.pack('<H', dt.appearance)))
+        if dt.advertising_interval:
+            ad_structures.append((AdvertisingData.ADVERTISING_INTERVAL, struct.pack('<H', dt.advertising_interval)))
+        if dt.uri:
+            ad_structures.append((AdvertisingData.URI, bytes(dt.uri, 'utf-8')))
+        if dt.le_supported_features:
+            ad_structures.append((AdvertisingData.LE_SUPPORTED_FEATURES, dt.le_supported_features))
+        if dt.manufacturer_specific_data:
+            ad_structures.append((AdvertisingData.MANUFACTURER_SPECIFIC_DATA, dt.manufacturer_specific_data))
+        return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
- kwargs = {
- 'service_data_uuid16': {},
- 'service_data_uuid32': {},
- 'service_data_uuid128': {}
- }
+ dt = DataTypes()
+ uuids: List[UUID]
+ s: str
+ i: int
+ ij: Tuple[int, int]
+ uuid_data: Tuple[UUID, bytes]
+ data: bytes
- if data := ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS):
- kwargs['incomplete_service_class_uuids16'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS):
- kwargs['complete_service_class_uuids16'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS):
- kwargs['incomplete_service_class_uuids32'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS):
- kwargs['complete_service_class_uuids32'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS):
- kwargs['incomplete_service_class_uuids128'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS):
- kwargs['complete_service_class_uuids128'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.SHORTENED_LOCAL_NAME):
- kwargs['shortened_local_name'] = data
- if data := ad.get(AdvertisingData.COMPLETE_LOCAL_NAME):
- kwargs['complete_local_name'] = data
- if data := ad.get(AdvertisingData.TX_POWER_LEVEL):
- kwargs['tx_power_level'] = data
- if data := ad.get(AdvertisingData.CLASS_OF_DEVICE):
- kwargs['class_of_device'] = data
- if data := ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE):
- kwargs['peripheral_connection_interval_min'] = data[0]
- kwargs['peripheral_connection_interval_max'] = data[1]
- if data := ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS):
- kwargs['service_solicitation_uuids16'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS):
- kwargs['service_solicitation_uuids32'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS):
- kwargs['service_solicitation_uuids128'] = list(map(lambda x: x.to_hex_str(), data))
- if data := ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID):
- kwargs['service_data_uuid16'][data[0].to_hex_str()] = data[1]
- if data := ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID):
- kwargs['service_data_uuid32'][data[0].to_hex_str()] = data[1]
- if data := ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID):
- kwargs['service_data_uuid128'][data[0].to_hex_str()] = data[1]
- if data := ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True):
- kwargs['public_target_addresses'] = [data[i*6::i*6+6] for i in range(len(data) / 6)]
- if data := ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True):
- kwargs['random_target_addresses'] = [data[i*6::i*6+6] for i in range(len(data) / 6)]
- if data := ad.get(AdvertisingData.APPEARANCE):
- kwargs['appearance'] = data
- if data := ad.get(AdvertisingData.ADVERTISING_INTERVAL):
- kwargs['advertising_interval'] = data
- if data := ad.get(AdvertisingData.URI):
- kwargs['uri'] = data
- if data := ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True):
- kwargs['le_supported_features'] = data
- if data := ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True):
- kwargs['manufacturer_specific_data'] = data
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS)):
+ dt.incomplete_service_class_uuids16.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS)):
+ dt.complete_service_class_uuids16.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS)):
+ dt.incomplete_service_class_uuids32.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS)):
+ dt.complete_service_class_uuids32.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS)):
+ dt.incomplete_service_class_uuids128.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS)):
+ dt.complete_service_class_uuids128.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if s := cast(str, ad.get(AdvertisingData.SHORTENED_LOCAL_NAME)):
+ dt.shortened_local_name = s
+ if s := cast(str, ad.get(AdvertisingData.COMPLETE_LOCAL_NAME)):
+ dt.complete_local_name = s
+ if i := cast(int, ad.get(AdvertisingData.TX_POWER_LEVEL)):
+ dt.tx_power_level = i
+ if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
+ dt.class_of_device = i
+ if ij := cast(Tuple[int, int], ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE)):
+ dt.peripheral_connection_interval_min = ij[0]
+ dt.peripheral_connection_interval_max = ij[1]
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS)):
+ dt.service_solicitation_uuids16.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS)):
+ dt.service_solicitation_uuids32.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuids := cast(List[UUID], ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS)):
+ dt.service_solicitation_uuids128.extend(list(map(lambda x: x.to_hex_str(), uuids)))
+ if uuid_data := cast(Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)):
+ dt.service_data_uuid16[uuid_data[0].to_hex_str()] = uuid_data[1]
+ if uuid_data := cast(Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)):
+ dt.service_data_uuid32[uuid_data[0].to_hex_str()] = uuid_data[1]
+ if uuid_data := cast(Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)):
+ dt.service_data_uuid128[uuid_data[0].to_hex_str()] = uuid_data[1]
+ if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
+ dt.public_target_addresses.extend([data[i * 6 : i * 6 + 6] for i in range(len(data) // 6)])  # one 6-byte slice per address (start:end, not a step slice)
+ if data := cast(bytes, ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True)):
+ dt.random_target_addresses.extend([data[i * 6 : i * 6 + 6] for i in range(len(data) // 6)])  # one 6-byte slice per address (start:end, not a step slice)
+ if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
+ dt.appearance = i
+ if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
+ dt.advertising_interval = i
+ if s := cast(str, ad.get(AdvertisingData.URI)):
+ dt.uri = s
+ if data := cast(bytes, ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True)):
+ dt.le_supported_features = data
+ if data := cast(bytes, ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True)):
+ dt.manufacturer_specific_data = data
- if not len(kwargs['service_data_uuid16']):
- del kwargs['service_data_uuid16']
- if not len(kwargs['service_data_uuid32']):
- del kwargs['service_data_uuid32']
- if not len(kwargs['service_data_uuid128']):
- del kwargs['service_data_uuid128']
-
- return DataTypes(**kwargs)
+ return dt
diff --git a/avatar/bumble_server/security.py b/avatar/bumble_server/security.py
index c00cad4..8db39da 100644
--- a/avatar/bumble_server/security.py
+++ b/avatar/bumble_server/security.py
@@ -13,121 +13,172 @@
# limitations under the License.
import asyncio
+import grpc
import logging
-from contextlib import suppress
-from typing import AsyncIterator, Optional
-
-from avatar.bumble_server.utils import address_from_request
-
+from avatar.bumble_server.utils import BumbleServerLoggerAdapter, address_from_request
from bumble import hci
-from bumble.core import BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
+from bumble.core import BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, ProtocolError
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
-from bumble.smp import PairingConfig, PairingDelegate as BasePairingDelegate, Session
-
-from google.protobuf.any_pb2 import Any
-from google.protobuf.empty_pb2 import Empty
+from bumble.smp import PairingConfig, PairingDelegate as BasePairingDelegate
+from contextlib import suppress
+from google.protobuf import any_pb2, empty_pb2, wrappers_pb2
from google.protobuf.wrappers_pb2 import BoolValue
-
from pandora.host_pb2 import Connection
-
-from pandora.security_grpc import SecurityServicer, SecurityStorageServicer
+from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
- SecurityLevel, LESecurityLevel,
- PairingEvent, PairingEventAnswer,
- SecureResponse, WaitSecurityResponse
+ LE_LEVEL1,
+ LE_LEVEL2,
+ LE_LEVEL3,
+ LE_LEVEL4,
+ LEVEL0,
+ LEVEL1,
+ LEVEL2,
+ LEVEL3,
+ LEVEL4,
+ DeleteBondRequest,
+ IsBondedRequest,
+ LESecurityLevel,
+ PairingEvent,
+ PairingEventAnswer,
+ SecureRequest,
+ SecureResponse,
+ SecurityLevel,
+ WaitSecurityRequest,
+ WaitSecurityResponse,
)
+from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union, cast
class PairingDelegate(BasePairingDelegate):
-
- def __init__(self,
+ def __init__(
+ self,
connection: BumbleConnection,
service: "SecurityService",
- **kwargs
- ):
+ io_capability: int = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
+ local_initiator_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
+ local_responder_key_distribution: int = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
+ ) -> None:
+ self.log = BumbleServerLoggerAdapter(
+ logging.getLogger(), {'service_name': 'Security', 'device': connection.device}
+ )
self.connection = connection
self.service = service
- super().__init__(**kwargs)
+ super().__init__(io_capability, local_initiator_key_distribution, local_responder_key_distribution)
- async def accept(self):
+ async def accept(self) -> bool:
return True
- def build_event(self, **kwargs):
- if self.connection.handle is not None:
- kwargs['connection'] = Connection(cookie=Any(value=self.connection.handle.to_bytes(4, 'big')))
+ def add_origin(self, ev: PairingEvent) -> PairingEvent:
+ if not self.connection.is_incomplete:
+ assert ev.connection
+ ev.connection.CopyFrom(Connection(cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))))
else:
# In BR/EDR, connection may not be complete,
# use address instead
assert self.connection.transport == BT_BR_EDR_TRANSPORT
- kwargs['address'] = bytes(reversed(bytes(self.connection.peer_address)))
+ ev.address = bytes(reversed(bytes(self.connection.peer_address)))
- return PairingEvent(**kwargs)
+ return ev
- async def confirm(self):
- logging.info(f"Pairing event: `just_works` (io_capability: {self.io_capability})")
+ async def confirm(self) -> bool:
+ self.log.info(f"Pairing event: `just_works` (io_capability: {self.io_capability})")
- if not self.service.event_queue:
+ if not self.service.event_queue or not self.service.event_answer:
return True
- self.service.event_queue.put_nowait((event := self.build_event(just_works=Empty())))
+ event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
+ self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer)
assert answer.event == event
-
+ assert answer.confirm is not None  # only require that the answer is set; a False answer must be returned, not raise
return answer.confirm
- async def compare_numbers(self, number, digits=6):
- logging.info(f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})")
+ async def compare_numbers(self, number: int, digits: int = 6) -> bool:
+ self.log.info(f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})")
- if not self.service.event_queue:
+ if not self.service.event_queue or not self.service.event_answer:
raise RuntimeError('security: unhandled number comparison request')
- self.service.event_queue.put_nowait((event := self.build_event(numeric_comparison=number)))
+ event = self.add_origin(PairingEvent(numeric_comparison=number))
+ self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer)
assert answer.event == event
-
+ assert answer.confirm is not None  # only require that the answer is set; a False answer must be returned, not raise
return answer.confirm
- async def get_number(self):
- logging.info(f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})")
+ async def get_number(self) -> int:
+ self.log.info(f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})")
- if not self.service.event_queue:
+ if not self.service.event_queue or not self.service.event_answer:
raise RuntimeError('security: unhandled number request')
- self.service.event_queue.put_nowait((event := self.build_event(passkey_entry_request=Empty())))
+ event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
+ self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer)
assert answer.event == event
-
+ assert answer.passkey is not None
return answer.passkey
- async def display_number(self, number, digits=6):
- logging.info(f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})")
+ async def display_number(self, number: int, digits: int = 6) -> None:
+ self.log.info(f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})")
if not self.service.event_queue:
raise RuntimeError('security: unhandled number display request')
- self.service.event_queue.put_nowait(self.build_event(passkey_entry_notification=number))
+ event = self.add_origin(PairingEvent(passkey_entry_notification=number))
+ self.service.event_queue.put_nowait(event)
+
+
+BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
+ LEVEL0: lambda connection: True,
+ LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
+ LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
+ LEVEL3: lambda connection: connection.encryption != 0
+ and connection.authenticated
+ and connection.link_key_type
+ in (
+ hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
+ hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
+ ),
+ LEVEL4: lambda connection: connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
+ and connection.authenticated
+ and connection.link_key_type == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
+}
+
+LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
+ LE_LEVEL1: lambda connection: True,
+ LE_LEVEL2: lambda connection: connection.encryption != 0,
+ LE_LEVEL3: lambda connection: connection.encryption != 0 and connection.authenticated,
+ LE_LEVEL4: lambda connection: connection.encryption != 0 and connection.authenticated and connection.sc,
+}
class SecurityService(SecurityServicer):
-
- def __init__(self, device: Device, io_capability):
- super().__init__()
- self.event_queue: asyncio.Queue = None
+ def __init__(self, device: Device, io_capability: int) -> None:
+ self.log = BumbleServerLoggerAdapter(logging.getLogger(), {'service_name': 'Security', 'device': device})
+ self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.device = device
- self.device.io_capability = io_capability
- self.device.pairing_config_factory = lambda connection: PairingConfig(
- sc=True, mitm=True, bonding=True,
- delegate=PairingDelegate(
- connection, self,
- io_capability=self.device.io_capability,
- )
- )
- async def OnPairing(self, event_answer, context):
- logging.info('OnPairing')
+ def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
+ return PairingConfig(
+ sc=True,
+ mitm=True,
+ bonding=True,
+ delegate=PairingDelegate(
+ connection, self, io_capability=cast(int, getattr(self.device, 'io_capability'))
+ ),
+ )
+
+ setattr(device, 'io_capability', io_capability)
+ self.device.pairing_config_factory = pairing_config_factory
+
+ async def OnPairing(
+ self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
+ ) -> AsyncGenerator[PairingEvent, None]:
+ self.log.info('OnPairing')
if self.event_queue:
raise RuntimeError('already streaming pairing events')
@@ -136,7 +187,7 @@
raise RuntimeError('the `OnPairing` method shall be initiated before establishing any connections.')
self.event_queue = asyncio.Queue()
- self.event_answer = event_answer
+ self.event_answer = request
try:
while event := await self.event_queue.get():
@@ -146,117 +197,122 @@
self.event_queue = None
self.event_answer = None
- async def Secure(self, request, context):
+ async def Secure(self, request: SecureRequest, context: grpc.ServicerContext) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
- logging.info(f"Secure: {connection_handle}")
+ self.log.info(f"Secure: {connection_handle}")
- connection: BumbleConnection = self.device.lookup_connection(connection_handle)
+ connection = self.device.lookup_connection(connection_handle)
assert connection
oneof = request.WhichOneof('level')
level = getattr(request, oneof)
- assert { BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le' }[connection.transport] == oneof
+ assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == oneof
# security level already reached
if self.reached_security_level(connection, level):
- return SecureResponse(success=Empty())
+ return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
- logging.info('Pair...')
+ self.log.info('Pair...')
await connection.pair()
- logging.info('Paired')
+ self.log.info('Paired')
except asyncio.CancelledError:
- logging.warning(f"Connection died during encryption")
- return SecureResponse(connection_died=Empty())
- except HCI_Error as e:
- logging.warning(f"Pairing failure: {e}")
- return SecureResponse(pairing_failure=Empty())
+ self.log.warning(f"Connection died during encryption")
+ return SecureResponse(connection_died=empty_pb2.Empty())
+ except (HCI_Error, ProtocolError) as e:
+ self.log.warning(f"Pairing failure: {e}")
+ return SecureResponse(pairing_failure=empty_pb2.Empty())
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
- logging.info('Authenticate...')
+ self.log.info('Authenticate...')
await connection.authenticate()
- logging.info('Authenticated')
+ self.log.info('Authenticated')
except asyncio.CancelledError:
- logging.warning(f"Connection died during authentication")
- return SecureResponse(connection_died=Empty())
- except HCI_Error as e:
- logging.warning(f"Authentication failure: {e}")
- return SecureResponse(authentication_failure=Empty())
+ self.log.warning(f"Connection died during authentication")
+ return SecureResponse(connection_died=empty_pb2.Empty())
+ except (HCI_Error, ProtocolError) as e:
+ self.log.warning(f"Authentication failure: {e}")
+ return SecureResponse(authentication_failure=empty_pb2.Empty())
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
- logging.info('Encrypt...')
+ self.log.info('Encrypt...')
await connection.encrypt()
- logging.info('Encrypted')
+ self.log.info('Encrypted')
except asyncio.CancelledError:
- logging.warning(f"Connection died during encryption")
- return SecureResponse(connection_died=Empty())
- except HCI_Error as e:
- logging.warning(f"Encryption failure: {e}")
- return SecureResponse(encryption_failure=Empty())
+ self.log.warning(f"Connection died during encryption")
+ return SecureResponse(connection_died=empty_pb2.Empty())
+ except (HCI_Error, ProtocolError) as e:
+ self.log.warning(f"Encryption failure: {e}")
+ return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ?
if self.reached_security_level(connection, level):
- return SecureResponse(success=Empty())
- return SecureResponse(not_reached=Empty())
+ return SecureResponse(success=empty_pb2.Empty())
+ return SecureResponse(not_reached=empty_pb2.Empty())
-
- async def WaitSecurity(self, request, context):
+ async def WaitSecurity(self, request: WaitSecurityRequest, context: grpc.ServicerContext) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
- logging.info(f"WaitSecurity: {connection_handle}")
+ self.log.info(f"WaitSecurity: {connection_handle}")
- connection: BumbleConnection = self.device.lookup_connection(connection_handle)
+ connection = self.device.lookup_connection(connection_handle)
assert connection
- oneof = request.WhichOneof('level')
- level = getattr(request, oneof)
- assert { BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le' }[connection.transport] == oneof
+ assert request.level is not None  # NOTE(review): a 0-valued level (presumably LEVEL0) is falsy; only require that the oneof is set
+ level = request.level
+ assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[connection.transport] == request.level_variant()
- wait_for_security = asyncio.get_running_loop().create_future()
- authenticate_task = None
+ wait_for_security: asyncio.Future[str] = asyncio.get_running_loop().create_future()
+ authenticate_task: Optional[asyncio.Future[None]] = None
- async def authenticate():
+ async def authenticate() -> None:
+ assert connection
if (encryption := connection.encryption) != 0:
- logging.debug('Disable encryption...')
- try: await connection.encrypt(enable=0x00)
- except: pass
- logging.debug('Disable encryption: done')
+ self.log.debug('Disable encryption...')
+ try:
+ await connection.encrypt(enable=False)
+ except:
+ pass
+ self.log.debug('Disable encryption: done')
- logging.debug('Authenticate...')
+ self.log.debug('Authenticate...')
await connection.authenticate()
- logging.debug('Authenticate: done')
+ self.log.debug('Authenticate: done')
if encryption != 0 and connection.encryption != encryption:
- logging.debug('Re-enable encryption...')
+ self.log.debug('Re-enable encryption...')
await connection.encrypt()
- logging.debug('Re-enable encryption: done')
+ self.log.debug('Re-enable encryption: done')
- def set_failure(name):
- def wrapper(*args):
- logging.info(f'Wait for security: error `{name}`: {args}')
+ def set_failure(name: str) -> Callable[..., None]:
+ def wrapper(*args: Any) -> None:
+ self.log.info(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
+
return wrapper
- def try_set_success(*_):
+ def try_set_success(*_: Any) -> None:
+ assert connection
if self.reached_security_level(connection, level):
- logging.info(f'Wait for security: done')
+ self.log.info(f'Wait for security: done')
wait_for_security.set_result('success')
- def on_encryption_change(*_):
+ def on_encryption_change(*_: Any) -> None:
+ assert connection
if self.reached_security_level(connection, level):
- logging.info(f'Wait for security: done')
+ self.log.info(f'Wait for security: done')
wait_for_security.set_result('success')
elif connection.transport == BT_BR_EDR_TRANSPORT and self.need_authentication(connection, level):
nonlocal authenticate_task
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
- listeners = {
+ listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'),
@@ -272,91 +328,75 @@
# security level already reached
if self.reached_security_level(connection, level):
- return WaitSecurityResponse(success=Empty())
+ return WaitSecurityResponse(success=empty_pb2.Empty())
- logging.info('Wait for security...')
+ self.log.info('Wait for security...')
kwargs = {}
- kwargs[await wait_for_security] = Empty()
+ kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
- connection.remove_listener(event, listener)
+ connection.remove_listener(event, listener) # type: ignore
# wait for `authenticate` to finish if any
if authenticate_task is not None:
- try: await authenticate_task
- except: pass
+ self.log.info('Wait for authentication...')
+ try:
+ await authenticate_task # type: ignore
+ except:
+ pass
+ self.log.info('Authenticated')
return WaitSecurityResponse(**kwargs)
- def reached_security_level(self, connection: BumbleConnection, level: int):
- logging.debug(str({
- 'level': level,
- 'encryption': connection.encryption,
- 'authenticated': connection.authenticated,
- 'sc': connection.sc,
- 'link_key_type': connection.link_key_type
- }))
+ def reached_security_level(
+ self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
+ ) -> bool:
+ self.log.debug(
+ str(
+ {
+ 'level': level,
+ 'encryption': connection.encryption,
+ 'authenticated': connection.authenticated,
+ 'sc': connection.sc,
+ 'link_key_type': connection.link_key_type,
+ }
+ )
+ )
+ if isinstance(level, LESecurityLevel):
+ return LE_LEVEL_REACHED[level](connection)
+
+ return BR_LEVEL_REACHED[level](connection)
+
+ def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
- return {
- LESecurityLevel.LE_LEVEL1: lambda:
- True,
- LESecurityLevel.LE_LEVEL2: lambda:
- connection.encryption != 0,
- LESecurityLevel.LE_LEVEL3: lambda:
- connection.encryption != 0 and
- connection.authenticated,
- LESecurityLevel.LE_LEVEL4: lambda:
- connection.encryption != 0 and
- connection.authenticated and
- connection.sc,
- }[level]()
-
- return {
- SecurityLevel.LEVEL0: lambda:
- True,
- SecurityLevel.LEVEL1: lambda:
- connection.encryption == 0 or connection.authenticated,
- SecurityLevel.LEVEL2: lambda:
- connection.encryption != 0 and
- connection.authenticated,
- SecurityLevel.LEVEL3: lambda:
- connection.encryption != 0 and
- connection.authenticated and
- connection.link_key_type in (hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE),
- SecurityLevel.LEVEL4: lambda:
- connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM and
- connection.authenticated and
- connection.link_key_type == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
- }[level]()
-
- def need_pairing(self, connection: BumbleConnection, level: int):
- if connection.transport == BT_LE_TRANSPORT:
- return level >= LESecurityLevel.LE_LEVEL3 and not connection.authenticated
+ return level >= LE_LEVEL3 and not connection.authenticated
return False
- def need_authentication(self, connection: BumbleConnection, level: int):
+ def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
return False
- if level == SecurityLevel.LEVEL2 and connection.encryption != 0:
+ if level == LEVEL2 and connection.encryption != 0:
return not connection.authenticated
- return level >= SecurityLevel.LEVEL2 and not connection.authenticated
+ return level >= LEVEL2 and not connection.authenticated
- def need_encryption(self, connection: BumbleConnection, level: int):
+ def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
- return level == LESecurityLevel.LE_LEVEL2 and not connection.encryption
- return level >= SecurityLevel.LEVEL2 and not connection.encryption
+ return level == LE_LEVEL2 and not connection.encryption
+ return level >= LEVEL2 and not connection.encryption
+
class SecurityStorageService(SecurityStorageServicer):
-
- def __init__(self, device: Device):
- super().__init__()
+ def __init__(self, device: Device) -> None:
+ self.log = BumbleServerLoggerAdapter(
+ logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
+ )
self.device = device
- async def IsBonded(self, request, context):
+ async def IsBonded(self, request: IsBondedRequest, context: grpc.ServicerContext) -> wrappers_pb2.BoolValue:
address = address_from_request(request, request.WhichOneof("address"))
- logging.info(f"IsBonded: {address}")
+ self.log.info(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
@@ -365,12 +405,12 @@
return BoolValue(value=is_bonded)
- async def DeleteBond(self, request, context):
+ async def DeleteBond(self, request: DeleteBondRequest, context: grpc.ServicerContext) -> empty_pb2.Empty:
address = address_from_request(request, request.WhichOneof("address"))
- logging.info(f"DeleteBond: {address}")
+ self.log.info(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):
await self.device.keystore.delete(str(address))
- return Empty()
+ return empty_pb2.Empty()
diff --git a/avatar/bumble_server/utils.py b/avatar/bumble_server/utils.py
index 30060e7..f841da0 100644
--- a/avatar/bumble_server/utils.py
+++ b/avatar/bumble_server/utils.py
@@ -12,15 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from bumble.device import Device
from bumble.hci import Address
+from google.protobuf.message import Message
+from typing import Any, MutableMapping, Optional, Tuple
-
-ADDRESS_TYPES = {
+ADDRESS_TYPES: dict[str, int] = {
"public": Address.PUBLIC_DEVICE_ADDRESS,
"random": Address.RANDOM_DEVICE_ADDRESS,
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
- "random_static_identity": Address.RANDOM_IDENTITY_ADDRESS
+ "random_static_identity": Address.RANDOM_IDENTITY_ADDRESS,
}
-def address_from_request(request, field):
- return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
\ No newline at end of file
+
+def address_from_request(request: Message, field: Optional[str]) -> Address:
+ if field is None:
+ return Address.ANY
+ return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
+
+
+class BumbleServerLoggerAdapter(logging.LoggerAdapter):  # type: ignore
+ """Prefixes Bumble server log messages with `[bumble.<service_name>:<short device address>]`."""
+
+ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]:
+ assert self.extra
+ service_name = self.extra['service_name']
+ assert isinstance(service_name, str)
+ device = self.extra['device']
+ assert isinstance(device, Device)
+ addr_bytes = bytes(reversed(bytes(device.public_address)))
+ addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]])  # only the last 2 bytes of the reversed address, to keep log prefixes short
+ return (f'[bumble.{service_name}:{addr}] {msg}', kwargs)
diff --git a/avatar/controllers/__init__.py b/avatar/controllers/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/avatar/controllers/__init__.py
+++ /dev/null
diff --git a/avatar/controllers/bumble_device.py b/avatar/controllers/bumble_device.py
new file mode 100644
index 0000000..157139c
--- /dev/null
+++ b/avatar/controllers/bumble_device.py
@@ -0,0 +1,42 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bumble device Mobly controller."""
+
+import asyncio
+import avatar.aio
+
+from avatar.bumble_device import BumbleDevice
+from typing import Any, Dict, List, Optional
+
+MOBLY_CONTROLLER_CONFIG_NAME = 'BumbleDevice'
+
+
+def create(configs: List[Dict[str, Any]]) -> List[BumbleDevice]:
+ """Create a list of `BumbleDevice` from configs."""
+ return [BumbleDevice(config) for config in configs]
+
+
+def destroy(devices: List[BumbleDevice]) -> None:
+ """Destroy each `BumbleDevice`"""
+
+ async def close_devices() -> None:
+ await asyncio.gather(*(device.close() for device in devices))
+
+ avatar.aio.run_until_complete(close_devices())
+
+
+def get_info(devices: List[BumbleDevice]) -> List[Optional[Dict[str, str]]]:
+ """Return the device info for each `BumbleDevice`."""
+ return [device.info() for device in devices]
diff --git a/avatar/controllers/pandora_device.py b/avatar/controllers/pandora_device.py
index cc4e349..298950a 100644
--- a/avatar/controllers/pandora_device.py
+++ b/avatar/controllers/pandora_device.py
@@ -1,7 +1,3 @@
-# Copyright 2022 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
@@ -12,144 +8,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import grpc
+"""Pandora device Mobly controller."""
+
import importlib
-import asyncio
-import avatar
-import contextlib
-import mobly.controllers.android_device
-import mobly.signals
-
-from contextlib import suppress
-
-from ..android_service import ANDROID_SERVER_GRPC_PORT, AndroidService
-from ..bumble_server import BumblePandoraServer
-from ..utils import Address
-
-from pandora.host_grpc import Host
-from pandora.security_grpc import Security, SecurityStorage
-from pandora.asha_grpc import ASHA
-
-from bumble.device import Device
+from avatar.pandora_client import PandoraClient
+from typing import Any, Dict, List, Optional, cast
MOBLY_CONTROLLER_CONFIG_NAME = 'PandoraDevice'
-def create(configs):
- def create_device(config):
- module_name = config.pop('module', PandoraDevice.__module__)
- class_name = config.pop('class', PandoraDevice.__name__)
+def create(configs: List[Dict[str, Any]]) -> List[PandoraClient]:
+ """Create a list of `PandoraClient` from configs."""
+
+ def create_device(config: Dict[str, Any]) -> PandoraClient:
+ module_name = config.pop('module', PandoraClient.__module__)
+ class_name = config.pop('class', PandoraClient.__name__)
module = importlib.import_module(module_name)
- return getattr(module, class_name)(**config)
+ return cast(PandoraClient, getattr(module, class_name)(**config))
return list(map(create_device, configs))
-def destroy(devices):
- [device.destroy() for device in devices]
+
+def destroy(devices: List['PandoraClient']) -> None:
+ """Destroy each `PandoraClient`"""
+ for device in devices:
+ device.close()
-class PandoraDevice:
-
- def __init__(self, target):
- self._address = Address(b'\x00\x00\x00\x00\x00\x00')
- self._target = target
- self._channel = grpc.insecure_channel(target)
- self._aio_channel = None
- self.log = PandoraDeviceLoggerAdapter(logging.getLogger(), self)
-
- def destroy(self):
- self._channel.close()
- if self._aio_channel:
- avatar.run_until_complete(self._aio_channel.close())
-
- @property
- def channel(self):
- # Force the use of the asynchronous channel when running in our event loop.
- with contextlib.suppress(RuntimeError):
- if asyncio.get_running_loop() == avatar.loop:
- if not self._aio_channel:
- self._aio_channel = grpc.aio.insecure_channel(self._target)
- return self._aio_channel
- return self._channel
-
- @property
- def address(self):
- return self._address
-
- @address.setter
- def address(self, bytes):
- self._address = Address(bytes)
-
- @property
- def host(self) -> Host:
- return Host(self.channel)
-
- @property
- def security(self) -> Security:
- return Security(self.channel)
-
- @property
- def security_storage(self) -> SecurityStorage:
- return SecurityStorage(self.channel)
-
- @property
- def asha(self) -> ASHA:
- return ASHA(self.channel)
-
-
-class PandoraDeviceLoggerAdapter(logging.LoggerAdapter):
-
- def process(self, msg, kwargs):
- msg = f'[{self.extra.__class__.__name__}|{self.extra.address}] {msg}'
- return (msg, kwargs)
-
-
-class AndroidPandoraDevice(PandoraDevice):
-
- def __init__(self, config):
- android_devices = mobly.controllers.android_device.create(config)
- if not android_devices:
- raise mobly.signals.ControllerError(
- 'Expected to get at least 1 android controller objects, got 0.')
- head, *tail = android_devices
- mobly.controllers.android_device.destroy(tail)
-
- self.android_device = head
- port = ANDROID_SERVER_GRPC_PORT
- self.android_device.services.register('pandora', AndroidService, configs={
- 'port': port
- })
- super().__init__(f'localhost:{port}')
-
- def destroy(self):
- super().destroy()
- mobly.controllers.android_device.destroy([self.android_device])
-
-
-class BumblePandoraDevice(PandoraDevice):
-
- def __init__(self, transport, **config):
- asyncio.set_event_loop(avatar.loop)
- grpc_server = grpc.aio.server()
- grpc_port = grpc_server.add_insecure_port('localhost:0')
-
- super().__init__(f'localhost:{grpc_port}')
-
- self.device: Device = None
- self.server_task = avatar.loop.create_task(
- BumblePandoraServer.serve(
- transport, config, grpc_server, grpc_port,
- on_started=lambda server: setattr(self, 'device', server.device)
- )
- )
-
- def destroy(self):
- async def server_stop():
- self.server_task.cancel()
- with suppress(asyncio.CancelledError): await self.server_task
-
- super().destroy()
- avatar.run_until_complete(server_stop())
+def get_info(devices: List['PandoraClient']) -> List[Optional[Dict[str, Any]]]:
+ """Return the device info for each `PandoraClient`."""
+ return [{'grpc_target': device.grpc_target, 'bd_addr': str(device.address)} for device in devices]
diff --git a/avatar/pandora_client.py b/avatar/pandora_client.py
new file mode 100644
index 0000000..c576883
--- /dev/null
+++ b/avatar/pandora_client.py
@@ -0,0 +1,190 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Pandora client interface for Avatar tests."""
+
+import avatar.aio
+import bumble
+import bumble.device
+import grpc
+import grpc.aio
+import logging
+
+from avatar.bumble_device import BumbleDevice
+from bumble.hci import Address as BumbleAddress
+from dataclasses import dataclass
+from pandora import asha_grpc, asha_grpc_aio, host_grpc, host_grpc_aio, security_grpc, security_grpc_aio
+from typing import Any, MutableMapping, Optional, Tuple, Union
+
+
+class Address(bytes):
+ def __new__(cls, address: Union[bytes, str, BumbleAddress]) -> 'Address':
+ if type(address) is bytes:
+ address_bytes = address
+ elif type(address) is str:
+ address_bytes = bytes.fromhex(address.replace(':', ''))
+ elif isinstance(address, BumbleAddress):
+ address_bytes = bytes(reversed(bytes(address)))
+ else:
+ raise ValueError('Invalid address format')
+
+ if len(address_bytes) != 6:
+ raise ValueError('Invalid address length')
+
+ return bytes.__new__(cls, address_bytes)
+
+ def __str__(self) -> str:
+ return ':'.join([f'{x:02X}' for x in self])
+
+
+class PandoraClient:
+ """Provides Pandora interface access to a device via gRPC."""
+
+ # public fields
+ grpc_target: str # Server address for the gRPC channel.
+ log: 'PandoraClientLoggerAdapter' # Logger adapter.
+ channel: grpc.Channel # Synchronous gRPC channel.
+
+ # private fields
+ _address: Address # Bluetooth device address
+ _aio: Optional['PandoraClient.Aio'] # Asynchronous gRPC channel.
+
+ def __init__(self, grpc_target: str, name: str = '..') -> None:
+ """Creates a PandoraClient.
+
+ Establishes a channel with the Pandora gRPC server.
+
+ Args:
+ grpc_target: Server address for the gRPC channel.
+ """
+ self.grpc_target = grpc_target
+ self.log = PandoraClientLoggerAdapter(logging.getLogger(), {'client': self, 'client_name': name})
+ self.channel = grpc.insecure_channel(grpc_target) # type: ignore
+ self._address = Address(b'\x00\x00\x00\x00\x00\x00')
+ self._aio = None
+
+ def close(self) -> None:
+ """Closes the gRPC channels."""
+ self.channel.close()
+ if self._aio:
+ avatar.aio.run_until_complete(self._aio.channel.close())
+
+ @property
+ def address(self) -> Address:
+ """Returns the BD address."""
+ return self._address
+
+ @address.setter
+ def address(self, address: Union[bytes, str, BumbleAddress]) -> None:
+ """Sets the BD address."""
+ self._address = Address(address)
+
+ async def reset(self) -> None:
+ """Factory reset the device & read it's BD address."""
+ await self.aio.host.FactoryReset()
+ # Factory reset stopped the server, close the client too.
+ assert self._aio
+ await self._aio.channel.close()
+ self._aio = None
+ # Try to connect to the new server 3 times before failing.
+ for _ in range(0, 3):
+ try:
+ self._address = Address((await self.aio.host.ReadLocalAddress(wait_for_ready=True)).address)
+ return
+ except grpc.RpcError as e:
+ assert e.code() == grpc.StatusCode.UNAVAILABLE # type: ignore
+ raise RuntimeError('unable to establish a new connection after a `FactoryReset`')
+
+ # Pandora interfaces
+
+ @property
+ def host(self) -> host_grpc.Host:
+ """Returns the Pandora Host gRPC interface."""
+ return host_grpc.Host(self.channel)
+
+ @property
+ def security(self) -> security_grpc.Security:
+ """Returns the Pandora Security gRPC interface."""
+ return security_grpc.Security(self.channel)
+
+ @property
+ def security_storage(self) -> security_grpc.SecurityStorage:
+ """Returns the Pandora SecurityStorage gRPC interface."""
+ return security_grpc.SecurityStorage(self.channel)
+
+ @property
+ def asha(self) -> asha_grpc.ASHA:
+ """Returns the Pandora ASHA gRPC interface."""
+ return asha_grpc.ASHA(self.channel)
+
+ @dataclass
+ class Aio:
+ channel: grpc.aio.Channel
+
+ @property
+ def host(self) -> host_grpc_aio.Host:
+ """Returns the Pandora Host gRPC interface."""
+ return host_grpc_aio.Host(self.channel)
+
+ @property
+ def security(self) -> security_grpc_aio.Security:
+ """Returns the Pandora Security gRPC interface."""
+ return security_grpc_aio.Security(self.channel)
+
+ @property
+ def security_storage(self) -> security_grpc_aio.SecurityStorage:
+ """Returns the Pandora SecurityStorage gRPC interface."""
+ return security_grpc_aio.SecurityStorage(self.channel)
+
+ @property
+ def asha(self) -> asha_grpc_aio.ASHA:
+ """Returns the Pandora ASHA gRPC interface."""
+ return asha_grpc_aio.ASHA(self.channel)
+
+ @property
+ def aio(self) -> 'PandoraClient.Aio':
+ if not self._aio:
+ self._aio = PandoraClient.Aio(grpc.aio.insecure_channel(self.grpc_target))
+ return self._aio
+
+
+class PandoraClientLoggerAdapter(logging.LoggerAdapter): # type: ignore
+ """Formats logs from the PandoraClient."""
+
+ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]:
+ assert self.extra
+ client = self.extra['client']
+ assert isinstance(client, PandoraClient)
+ client_name = self.extra.get('client_name', client.__class__.__name__)
+ addr = ':'.join([f'{x:02X}' for x in client.address[4:]])
+ return (f'[{client_name}:{addr}] {msg}', kwargs)
+
+
+class BumblePandoraClient(PandoraClient):
+ """Special Pandora client which also give access to a Bumble device instance."""
+
+ _bumble: BumbleDevice # Bumble device wrapper.
+
+ def __init__(self, grpc_target: str, bumble: BumbleDevice) -> None:
+ super().__init__(grpc_target, 'bumble')
+ self._bumble = bumble
+
+ @property
+ def device(self) -> bumble.device.Device:
+ return self._bumble.device
+
+ @property
+ def random_address(self) -> Address:
+ return Address(self.device.random_address)
diff --git a/avatar/pandora_server.py b/avatar/pandora_server.py
new file mode 100644
index 0000000..0e00fcf
--- /dev/null
+++ b/avatar/pandora_server.py
@@ -0,0 +1,151 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Interface for controller-specific Pandora server management."""
+
+import asyncio
+import avatar.aio
+import grpc
+import grpc.aio
+import threading
+import time
+import types
+
+from avatar.bumble_device import BumbleDevice
+from avatar.bumble_server import create_serve_task
+from avatar.controllers import bumble_device, pandora_device
+from avatar.pandora_client import BumblePandoraClient, PandoraClient
+from contextlib import suppress
+from mobly.controllers import android_device
+from mobly.controllers.android_device import AndroidDevice
+from typing import Generic, Optional, TypeVar
+
+ANDROID_SERVER_PACKAGE = 'com.android.pandora'
+ANDROID_SERVER_GRPC_PORT = 8999 # TODO: Use a dynamic port
+
+
+# Generic type for `PandoraServer`.
+TDevice = TypeVar('TDevice')
+
+
+class PandoraServer(Generic[TDevice]):
+ """Abstract interface to manage the Pandora gRPC server on the device."""
+
+ MOBLY_CONTROLLER_MODULE: types.ModuleType = pandora_device
+
+ device: TDevice
+
+ def __init__(self, device: TDevice) -> None:
+ """Creates a PandoraServer.
+
+ Args:
+ device: A Mobly controller instance.
+ """
+ self.device = device
+
+ def start(self) -> PandoraClient:
+ """Sets up and starts the Pandora server on the device."""
+ assert isinstance(self.device, PandoraClient)
+ return self.device
+
+ def stop(self) -> None:
+ """Stops and cleans up the Pandora server on the device."""
+
+
+class BumblePandoraServer(PandoraServer[BumbleDevice]):
+ """Manages the Pandora gRPC server on an BumbleDevice."""
+
+ MOBLY_CONTROLLER_MODULE = bumble_device
+
+ _task: Optional[asyncio.Task[None]] = None
+
+ def start(self) -> BumblePandoraClient:
+ """Sets up and starts the Pandora server on the Bumble device."""
+ assert self._task is None
+
+ # set the event loop to make sure the gRPC server use the avatar one.
+ asyncio.set_event_loop(avatar.aio.loop)
+
+ # create gRPC server & port.
+ server = grpc.aio.server()
+ port = server.add_insecure_port(f'localhost:{0}')
+
+ self._task = avatar.aio.loop.create_task(
+ avatar.aio.run_until_complete(
+ create_serve_task(
+ self.device,
+ grpc_server=server,
+ port=port,
+ )
+ )
+ )
+
+ return BumblePandoraClient(f'localhost:{port}', self.device)
+
+ def stop(self) -> None:
+ """Stops and cleans up the Pandora server on the Bumble device."""
+
+ async def server_stop() -> None:
+ assert self._task is not None
+ self._task.cancel()
+ with suppress(asyncio.CancelledError):
+ await self._task
+ self._task = None
+
+ avatar.aio.run_until_complete(server_stop())
+
+
+class AndroidPandoraServer(PandoraServer[AndroidDevice]):
+ """Manages the Pandora gRPC server on an AndroidDevice."""
+
+ MOBLY_CONTROLLER_MODULE = android_device
+
+ _instrumentation: Optional[threading.Thread] = None
+ _port: int = ANDROID_SERVER_GRPC_PORT
+
+ def start(self) -> PandoraClient:
+ """Sets up and starts the Pandora server on the Android device."""
+ assert self._instrumentation is None
+
+ # start Pandora Android gRPC server.
+ self._instrumentation = threading.Thread(
+ target=lambda: self.device.adb._exec_adb_cmd( # type: ignore
+ 'shell',
+ f'am instrument --no-hidden-api-checks -w {ANDROID_SERVER_PACKAGE}/.Main',
+ shell=False,
+ timeout=None,
+ stderr=None,
+ )
+ )
+
+ self._instrumentation.start()
+ self.device.adb.forward([f'tcp:{self._port}', f'tcp:{ANDROID_SERVER_GRPC_PORT}']) # type: ignore
+
+ # wait a few seconds for the Android gRPC server to be started.
+ time.sleep(3)
+
+ return PandoraClient(f'localhost:{self._port}')
+
+ def stop(self) -> None:
+ """Stops and cleans up the Pandora server on the Android device."""
+ assert self._instrumentation is not None
+
+ # Stop Pandora Android gRPC server.
+ self.device.adb._exec_adb_cmd( # type: ignore
+ 'shell', f'am force-stop {ANDROID_SERVER_PACKAGE}', shell=False, timeout=None, stderr=None
+ )
+
+ self.device.adb.forward(['--remove', f'tcp:{ANDROID_SERVER_GRPC_PORT}']) # type: ignore
+ self._instrumentation.join()
diff --git a/avatar/utils.py b/avatar/utils.py
deleted file mode 100644
index 189ff2a..0000000
--- a/avatar/utils.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2022 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import avatar
-
-from bumble.hci import Address as BumbleAddress
-
-
-class Address(bytes):
-
- def __new__(cls, address):
- if type(address) is bytes:
- address_bytes = address
- elif type(address) is str:
- address_bytes = bytes.fromhex(address.replace(':', ''))
- elif isinstance(address, BumbleAddress):
- address_bytes = bytes(reversed(bytes(address)))
- else:
- raise ValueError('Invalid address format')
-
- if len(address_bytes) != 6:
- raise ValueError('Invalid address length')
-
- return bytes.__new__(cls, address_bytes)
-
- def __str__(self):
- return ':'.join([f'{x:02X}' for x in self])
-
-
-class AsyncQueue(asyncio.Queue):
-
- def __aiter__(self):
- return self
-
- def __iter__(self):
- return self
-
- async def __anext__(self):
- return await self.get()
-
- def __next__(self):
- return avatar.run_until_complete(self.__anext__())
diff --git a/bt-test-interfaces b/bt-test-interfaces
new file mode 120000
index 0000000..4cee449
--- /dev/null
+++ b/bt-test-interfaces
@@ -0,0 +1 @@
+../bt-test-interfaces/
\ No newline at end of file
diff --git a/examples/asha_test.py b/examples/asha_test.py
index cd3b7f5..7017ba3 100644
--- a/examples/asha_test.py
+++ b/examples/asha_test.py
@@ -12,71 +12,100 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import avatar
import asyncio
import logging
-import grpc
-from mobly import test_runner, base_test
-from bumble.core import AdvertisingData
-from bumble.hci import UUID
+from avatar import BumbleDevice, PandoraDevice, PandoraDevices, asynchronous
from bumble.gatt import GATT_ASHA_SERVICE
-
-from avatar.utils import Address
-from avatar.controllers import pandora_device
-from pandora.host_pb2 import (
- DiscoverabilityMode, DataTypes, OwnAddressType, Connection,
- ConnectabilityMode, OwnAddressType
-)
+from mobly import base_test, test_runner
+from mobly.asserts import assert_equal # type: ignore
+from mobly.asserts import assert_in # type: ignore
+from pandora.host_pb2 import DataTypes
-class ASHATest(base_test.BaseTestClass):
- def setup_class(self):
- self.pandora_devices = self.register_controller(pandora_device)
- self.dut: pandora_device.PandoraDevice = self.pandora_devices[0]
- self.ref: pandora_device.PandoraDevice = self.pandora_devices[1]
+class ASHATest(base_test.BaseTestClass): # type: ignore[misc]
+ ASHA_UUID = GATT_ASHA_SERVICE.to_hex_str()
- @avatar.asynchronous
- async def setup_test(self):
- async def reset(device: pandora_device.PandoraDevice):
- await device.host.FactoryReset()
- device.address = (await device.host.ReadLocalAddress(wait_for_ready=True)).address
+ dut: PandoraDevice
+ ref: BumbleDevice
+
+ def setup_class(self) -> None:
+ dut, ref = PandoraDevices(self)
+ assert isinstance(ref, BumbleDevice)
+ self.dut, self.ref = dut, ref
+
+ @asynchronous
+ async def setup_test(self) -> None:
+ async def reset(device: PandoraDevice) -> None:
+ await device.aio.host.FactoryReset()
+ device.address = (await device.aio.host.ReadLocalAddress(wait_for_ready=True)).address # type: ignore[assignment]
await asyncio.gather(reset(self.dut), reset(self.ref))
- def test_ASHA_advertising(self):
+ def test_ASHA_advertising(self) -> None:
complete_local_name = 'Bumble'
- ASHA_UUID = GATT_ASHA_SERVICE.to_hex_str()
protocol_version = 0x01
capability = 0x00
hisyncid = [0x01, 0x02, 0x03, 0x04, 0x5, 0x6, 0x7, 0x8]
truncated_hisyncid = hisyncid[:4]
- self.ref.asha.Register(capability=capability,
- hisyncid=hisyncid)
+ self.ref.asha.Register(capability=capability, hisyncid=hisyncid)
- self.ref.host.StartAdvertising(
+ advertisement = self.ref.host.Advertise(
legacy=True,
data=DataTypes(
- complete_local_name=complete_local_name,
- incomplete_service_class_uuids16=[ASHA_UUID]
- )
+ complete_local_name=complete_local_name, incomplete_service_class_uuids16=[ASHATest.ASHA_UUID]
+ ),
)
- peers = self.dut.host.Scan()
+ scan = self.dut.host.Scan()
- scan_response = next((x for x in peers if
- x.data.complete_local_name == complete_local_name))
- logging.info(f"scan_response.data: {scan_response}")
- assert ASHA_UUID in scan_response.data.service_data_uuid16
- assert type(scan_response.data.complete_local_name) == str
- expected_advertisement_data = "{:02x}".format(protocol_version) + \
- "{:02x}".format(capability) + \
- "".join([("{:02x}".format(x)) for x in
- truncated_hisyncid])
- assert expected_advertisement_data == \
- (scan_response.data.service_data_uuid16[ASHA_UUID]).hex()
+ scan_result = next((x for x in scan if x.data.complete_local_name == complete_local_name))
+ logging.debug(f"scan_response.data: {scan_result}")
+
+ advertisement.cancel()
+ scan.cancel()
+
+ assert_in(ASHATest.ASHA_UUID, scan_result.data.service_data_uuid16)
+ assert_equal(type(scan_result.data.complete_local_name), str)
+ expected_advertisement_data = (
+ "{:02x}".format(protocol_version)
+ + "{:02x}".format(capability)
+ + "".join([("{:02x}".format(x)) for x in truncated_hisyncid])
+ )
+ assert_equal(expected_advertisement_data, (scan_result.data.service_data_uuid16[ASHATest.ASHA_UUID]).hex())
+
+ def test_ASHA_scan_response(self) -> None:
+ complete_local_name = 'Bumble'
+ protocol_version = 0x01
+ capability = 0x00
+ hisyncid = [0x01, 0x02, 0x03, 0x04, 0x5, 0x6, 0x7, 0x8]
+ truncated_hisyncid = hisyncid[:4]
+
+ self.ref.asha.Register(capability=capability, hisyncid=hisyncid)
+
+ advertisement = self.ref.host.Advertise(
+ legacy=True,
+ scan_response_data=DataTypes(
+ complete_local_name=complete_local_name, incomplete_service_class_uuids16=[ASHATest.ASHA_UUID]
+ ),
+ )
+ scan = self.dut.host.Scan()
+
+ scan_response = next((x for x in scan if x.data.complete_local_name == complete_local_name))
+ logging.debug(f"scan_response.data: {scan_response}")
+
+ advertisement.cancel()
+ scan.cancel()
+
+ assert_in(ASHATest.ASHA_UUID, scan_response.data.service_data_uuid16)
+ expected_advertisement_data = (
+ "{:02x}".format(protocol_version)
+ + "{:02x}".format(capability)
+ + "".join([("{:02x}".format(x)) for x in truncated_hisyncid])
+ )
+ assert_equal(expected_advertisement_data, (scan_response.data.service_data_uuid16[ASHATest.ASHA_UUID]).hex())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
- test_runner.main()
+ test_runner.main() # type: ignore
diff --git a/examples/example.py b/examples/example.py
index 031fa16..749bc09 100644
--- a/examples/example.py
+++ b/examples/example.py
@@ -12,340 +12,295 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import avatar
import asyncio
import grpc
import logging
+from avatar import BumbleDevice, PandoraDevice, PandoraDevices, asynchronous, parameterized
+from bumble.smp import PairingDelegate
from concurrent import futures
from contextlib import suppress
-
from mobly import base_test, test_runner
-from mobly.asserts import *
-
-from bumble.smp import PairingDelegate
-
-from avatar.utils import Address, AsyncQueue
-from avatar.controllers import pandora_device
+from mobly.asserts import assert_equal # type: ignore
+from mobly.asserts import assert_in # type: ignore
+from mobly.asserts import assert_is_none # type: ignore
+from mobly.asserts import assert_is_not_none # type: ignore
+from mobly.asserts import fail # type: ignore
from pandora.host_pb2 import (
- DiscoverabilityMode, DataTypes, OwnAddressType
+ DISCOVERABLE_GENERAL,
+ DISCOVERABLE_LIMITED,
+ NOT_DISCOVERABLE,
+ PUBLIC,
+ RANDOM,
+ DataTypes,
+ DiscoverabilityMode,
+ OwnAddressType,
)
-from pandora.security_pb2 import (
- PairingEventAnswer, SecurityLevel, LESecurityLevel
-)
+from pandora.security_pb2 import LE_LEVEL3, LEVEL2, PairingEventAnswer
+from typing import NoReturn, Optional
-class ExampleTest(base_test.BaseTestClass):
- def setup_class(self):
- self.pandora_devices = self.register_controller(pandora_device)
- self.dut: pandora_device.PandoraDevice = self.pandora_devices[0]
- self.ref: pandora_device.BumblePandoraDevice = self.pandora_devices[1]
+class ExampleTest(base_test.BaseTestClass): # type: ignore[misc]
+ devices: Optional[PandoraDevices] = None
- @avatar.asynchronous
- async def setup_test(self):
- async def reset(device: pandora_device.PandoraDevice):
- await device.host.FactoryReset()
- device.address = (await device.host.ReadLocalAddress(wait_for_ready=True)).address
+ # pandora devices.
+ dut: PandoraDevice
+ ref: BumbleDevice
- await asyncio.gather(reset(self.dut), reset(self.ref))
+ def setup_class(self) -> None:
+ self.devices = PandoraDevices(self)
+ dut, ref = self.devices
+ assert isinstance(ref, BumbleDevice)
+ self.dut, self.ref = dut, ref
- def test_print_addresses(self):
+ def teardown_class(self) -> None:
+ if self.devices:
+ self.devices.stop_all()
+
+ @asynchronous
+ async def setup_test(self) -> None:
+ await asyncio.gather(self.dut.reset(), self.ref.reset())
+
+ def test_print_addresses(self) -> None:
dut_address = self.dut.address
self.dut.log.info(f'Address: {dut_address}')
ref_address = self.ref.address
self.ref.log.info(f'Address: {ref_address}')
- def test_get_remote_name(self):
- dut_name = self.ref.host.GetRemoteName(address=self.dut.address)
- assert_equal(dut_name.WhichOneof('result'), 'name')
- self.ref.log.info(f'DUT remote name: {dut_name.name}')
- ref_name = self.dut.host.GetRemoteName(address=self.ref.address)
- assert_equal(ref_name.WhichOneof('result'), 'name')
- self.dut.log.info(f'REF remote name: {ref_name.name}')
-
- def test_classic_connect(self):
+ def test_classic_connect(self) -> None:
dut_address = self.dut.address
self.dut.log.info(f'Address: {dut_address}')
connection = self.ref.host.Connect(address=dut_address).connection
- dut_name = self.ref.host.GetRemoteName(connection=connection).name
- self.ref.log.info(f'Connected with: "{dut_name}" {dut_address}')
+ assert connection
+ self.ref.log.info(f'Connected with: {dut_address}')
self.ref.host.Disconnect(connection=connection)
# Using this decorator allow us to write one `test_le_connect`, and
# run it multiple time with different parameters.
# Here we check that no matter the address type we use for both sides
# the connection still complete.
- @avatar.parameterized([
- (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC),
- (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM),
- (OwnAddressType.RANDOM, OwnAddressType.PUBLIC),
- ])
- def test_le_connect(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType):
- self.ref.host.StartAdvertising(legacy=True, connectable=True, own_address_type=ref_address_type)
- peers = self.dut.host.Scan(own_address_type=dut_address_type)
- if ref_address_type == OwnAddressType.PUBLIC:
- scan_response = next((x for x in peers if x.public == self.ref.address))
- connection = self.dut.host.ConnectLE(public=scan_response.public, own_address_type=dut_address_type).connection
+ @parameterized(
+ (RANDOM, RANDOM),
+ (RANDOM, PUBLIC),
+ ) # type: ignore[misc]
+ def test_le_connect(self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType) -> None:
+ advertisement = self.ref.host.Advertise(legacy=True, connectable=True, own_address_type=ref_address_type)
+ scan = self.dut.host.Scan(own_address_type=dut_address_type)
+ if ref_address_type == PUBLIC:
+ scan_response = next((x for x in scan if x.public == self.ref.address))
+ dut_ref = self.dut.host.ConnectLE(
+ public=scan_response.public,
+ own_address_type=dut_address_type,
+ ).connection
else:
- scan_response = next((x for x in peers if x.random == Address(self.ref.device.random_address)))
- connection = self.dut.host.ConnectLE(random=scan_response.random, own_address_type=dut_address_type).connection
- peers.cancel()
- self.dut.host.Disconnect(connection=connection)
+ scan_response = next((x for x in scan if x.random == self.ref.random_address))
+ dut_ref = self.dut.host.ConnectLE(
+ random=scan_response.random,
+ own_address_type=dut_address_type,
+ ).connection
+ scan.cancel()
+ ref_dut = next(advertisement).connection
+ advertisement.cancel()
+ assert dut_ref and ref_dut
+ self.dut.host.Disconnect(connection=dut_ref)
- def test_not_discoverable(self):
- self.dut.host.SetDiscoverabilityMode(mode=DiscoverabilityMode.NOT_DISCOVERABLE)
- peers = self.ref.host.Inquiry(timeout=3.0)
+ def test_not_discoverable(self) -> None:
+ self.dut.host.SetDiscoverabilityMode(mode=NOT_DISCOVERABLE)
+ inquiry = self.ref.host.Inquiry(timeout=3.0)
try:
- assert_is_none(next((x for x in peers if x.address == self.dut.address), None))
+ assert_is_none(next((x for x in inquiry if x.address == self.dut.address), None))
except grpc.RpcError as e:
# No peers found; StartInquiry times out
- assert_equal(e.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ assert_equal(e.code(), grpc.StatusCode.DEADLINE_EXCEEDED) # type: ignore
finally:
- peers.cancel()
+ inquiry.cancel()
- @avatar.parameterized([
- (DiscoverabilityMode.DISCOVERABLE_LIMITED, ),
- (DiscoverabilityMode.DISCOVERABLE_GENERAL, ),
- ])
- def test_discoverable(self, mode):
+ @parameterized(
+ (DISCOVERABLE_LIMITED,),
+ (DISCOVERABLE_GENERAL,),
+ ) # type: ignore[misc]
+ def test_discoverable(self, mode: DiscoverabilityMode) -> None:
self.dut.host.SetDiscoverabilityMode(mode=mode)
- peers = self.ref.host.Inquiry(timeout=15.0)
+ inquiry = self.ref.host.Inquiry(timeout=15.0)
try:
- assert_is_not_none(next((x for x in peers if x.address == self.dut.address), None))
+ assert_is_not_none(next((x for x in inquiry if x.address == self.dut.address), None))
finally:
- peers.cancel()
+ inquiry.cancel()
- @avatar.asynchronous
- async def test_wait_connection(self):
- dut_ref = self.dut.host.WaitConnection(address=self.ref.address)
- ref_dut = await self.ref.host.Connect(address=self.dut.address)
- dut_ref = await dut_ref
+ @asynchronous
+ async def test_wait_connection(self) -> None:
+ dut_ref_co = self.dut.aio.host.WaitConnection(address=self.ref.address)
+ ref_dut = await self.ref.aio.host.Connect(address=self.dut.address)
+ dut_ref = await dut_ref_co
assert_is_not_none(ref_dut.connection)
assert_is_not_none(dut_ref.connection)
- await self.ref.host.Disconnect(connection=ref_dut.connection)
+ assert ref_dut.connection
+ await self.ref.aio.host.Disconnect(connection=ref_dut.connection)
- @avatar.asynchronous
- async def test_wait_any_connection(self):
- dut_ref = self.dut.host.WaitConnection()
- ref_dut = await self.ref.host.Connect(address=self.dut.address)
- dut_ref = await dut_ref
- assert_is_not_none(ref_dut.connection)
- assert_is_not_none(dut_ref.connection)
- await self.ref.host.Disconnect(connection=ref_dut.connection)
-
- def test_scan_response_data(self):
- self.dut.host.StartAdvertising(
+ def test_scan_response_data(self) -> None:
+ advertisement = self.dut.host.Advertise(
legacy=True,
data=DataTypes(
- include_shortened_local_name=True,
- tx_power_level=42,
- incomplete_service_class_uuids16=['FDF0']
+ complete_service_class_uuids16=['FDF0'],
),
- scan_response_data=DataTypes(include_complete_local_name=True, include_class_of_device=True)
+ scan_response_data=DataTypes(
+ include_class_of_device=True,
+ ),
)
- peers = self.ref.host.Scan()
- scan_response = next((x for x in peers if x.public == self.dut.address))
- peers.cancel()
+ scan = self.ref.host.Scan()
+ scan_response = next((x for x in scan if x.public == self.dut.address))
- assert_equal(type(scan_response.data.complete_local_name), str)
- assert_equal(type(scan_response.data.shortened_local_name), str)
+ scan.cancel()
+ advertisement.cancel()
+
assert_equal(type(scan_response.data.class_of_device), int)
- assert_equal(type(scan_response.data.incomplete_service_class_uuids16[0]), str)
- assert_equal(scan_response.data.tx_power_level, 42)
+ assert_equal(type(scan_response.data.complete_service_class_uuids16[0]), str)
- @avatar.parameterized([
- (PairingDelegate.NO_OUTPUT_NO_INPUT, ),
- (PairingDelegate.KEYBOARD_INPUT_ONLY, ),
- (PairingDelegate.DISPLAY_OUTPUT_ONLY, ),
- (PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, ),
- (PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, ),
- ])
- @avatar.asynchronous
- async def test_classic_pairing(self, ref_io_capability):
- # override reference device IO capability
- self.ref.device.io_capability = ref_io_capability
+ async def handle_pairing_events(self) -> NoReturn:
+ ref_pairing_stream = self.ref.aio.security.OnPairing()
+ dut_pairing_stream = self.dut.aio.security.OnPairing()
- await self.ref.security_storage.DeleteBond(public=self.dut.address)
+ try:
+ while True:
+ ref_pairing_event, dut_pairing_event = await asyncio.gather(
+ anext(ref_pairing_stream),
+ anext(dut_pairing_stream),
+ )
- async def handle_pairing_events():
- on_ref_pairing = self.ref.security.OnPairing((ref_answer_queue := AsyncQueue()))
- on_dut_pairing = self.dut.security.OnPairing((dut_answer_queue := AsyncQueue()))
-
- try:
- while True:
- dut_pairing_event = await anext(aiter(on_dut_pairing))
- ref_pairing_event = await anext(aiter(on_ref_pairing))
-
- if dut_pairing_event.WhichOneof('method') in ('numeric_comparison', 'just_works'):
- assert_in(ref_pairing_event.WhichOneof('method'), ('numeric_comparison', 'just_works'))
- dut_answer_queue.put_nowait(PairingEventAnswer(
+ if dut_pairing_event.method_variant() in ('numeric_comparison', 'just_works'):
+ assert_in(ref_pairing_event.method_variant(), ('numeric_comparison', 'just_works'))
+ dut_pairing_stream.send_nowait(
+ PairingEventAnswer(
event=dut_pairing_event,
confirm=True,
- ))
- ref_answer_queue.put_nowait(PairingEventAnswer(
+ )
+ )
+ ref_pairing_stream.send_nowait(
+ PairingEventAnswer(
event=ref_pairing_event,
confirm=True,
- ))
- elif dut_pairing_event.WhichOneof('method') == 'passkey_entry_notification':
- assert_equal(ref_pairing_event.WhichOneof('method'), 'passkey_entry_request')
- ref_answer_queue.put_nowait(PairingEventAnswer(
+ )
+ )
+ elif dut_pairing_event.method_variant() == 'passkey_entry_notification':
+ assert_equal(ref_pairing_event.method_variant(), 'passkey_entry_request')
+ ref_pairing_stream.send_nowait(
+ PairingEventAnswer(
event=ref_pairing_event,
passkey=dut_pairing_event.passkey_entry_notification,
- ))
- elif dut_pairing_event.WhichOneof('method') == 'passkey_entry_request':
- assert_equal(ref_pairing_event.WhichOneof('method'), 'passkey_entry_notification')
- dut_answer_queue.put_nowait(PairingEventAnswer(
+ )
+ )
+ elif dut_pairing_event.method_variant() == 'passkey_entry_request':
+ assert_equal(ref_pairing_event.method_variant(), 'passkey_entry_notification')
+ dut_pairing_stream.send_nowait(
+ PairingEventAnswer(
event=dut_pairing_event,
passkey=ref_pairing_event.passkey_entry_notification,
- ))
- else:
- fail()
+ )
+ )
+ else:
+ fail("unreachable")
- finally:
- on_ref_pairing.cancel()
- on_dut_pairing.cancel()
+ finally:
+ ref_pairing_stream.cancel()
+ dut_pairing_stream.cancel()
- pairing = asyncio.create_task(handle_pairing_events())
- (dut_ref, ref_dut) = await asyncio.gather(
- self.dut.host.WaitConnection(address=self.ref.address),
- self.ref.host.Connect(address=self.dut.address),
+ @parameterized(
+ (PairingDelegate.NO_OUTPUT_NO_INPUT,),
+ (PairingDelegate.KEYBOARD_INPUT_ONLY,),
+ (PairingDelegate.DISPLAY_OUTPUT_ONLY,),
+ (PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,),
+ (PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,),
+ ) # type: ignore[misc]
+ @asynchronous
+ async def test_classic_pairing(self, ref_io_capability: int) -> None:
+ # override reference device IO capability
+ setattr(self.ref.device, 'io_capability', ref_io_capability)
+
+ pairing = asyncio.create_task(self.handle_pairing_events())
+ (dut_ref_res, ref_dut_res) = await asyncio.gather(
+ self.dut.aio.host.WaitConnection(address=self.ref.address),
+ self.ref.aio.host.Connect(address=self.dut.address),
)
- assert_equal(ref_dut.WhichOneof('result'), 'connection')
- assert_equal(dut_ref.WhichOneof('result'), 'connection')
- ref_dut = ref_dut.connection
- dut_ref = dut_ref.connection
+ assert_equal(ref_dut_res.result_variant(), 'connection')
+ assert_equal(dut_ref_res.result_variant(), 'connection')
+ ref_dut = ref_dut_res.connection
+ dut_ref = dut_ref_res.connection
+ assert ref_dut and dut_ref
(secure, wait_security) = await asyncio.gather(
- self.ref.security.Secure(connection=ref_dut, classic=SecurityLevel.LEVEL2),
- self.dut.security.WaitSecurity(connection=dut_ref, classic=SecurityLevel.LEVEL2)
+ self.ref.aio.security.Secure(connection=ref_dut, classic=LEVEL2),
+ self.dut.aio.security.WaitSecurity(connection=dut_ref, classic=LEVEL2),
)
pairing.cancel()
with suppress(asyncio.CancelledError, futures.CancelledError):
await pairing
- assert_equal(secure.WhichOneof('result'), 'success')
- assert_equal(wait_security.WhichOneof('result'), 'success')
+ assert_equal(secure.result_variant(), 'success')
+ assert_equal(wait_security.result_variant(), 'success')
await asyncio.gather(
- self.dut.host.Disconnect(connection=dut_ref),
- self.ref.host.WaitDisconnection(connection=ref_dut)
+ self.dut.aio.host.Disconnect(connection=dut_ref),
+ self.ref.aio.host.WaitDisconnection(connection=ref_dut),
)
- @avatar.parameterized([
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.NO_OUTPUT_NO_INPUT),
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.KEYBOARD_INPUT_ONLY),
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_ONLY),
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT),
- (OwnAddressType.RANDOM, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
- (OwnAddressType.PUBLIC, OwnAddressType.RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
- (OwnAddressType.PUBLIC, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
- (OwnAddressType.RANDOM, OwnAddressType.PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
- ])
- @avatar.asynchronous
- async def test_le_pairing(self,
- dut_address_type: OwnAddressType,
- ref_address_type: OwnAddressType,
- ref_io_capability
- ):
+ @parameterized(
+ (RANDOM, RANDOM, PairingDelegate.NO_OUTPUT_NO_INPUT),
+ (RANDOM, RANDOM, PairingDelegate.KEYBOARD_INPUT_ONLY),
+ (RANDOM, RANDOM, PairingDelegate.DISPLAY_OUTPUT_ONLY),
+ (RANDOM, RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT),
+ (RANDOM, RANDOM, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
+ (RANDOM, PUBLIC, PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT),
+ ) # type: ignore[misc]
+ @asynchronous
+ async def test_le_pairing(
+ self, dut_address_type: OwnAddressType, ref_address_type: OwnAddressType, ref_io_capability: int
+ ) -> None:
# override reference device IO capability
- self.ref.device.io_capability = ref_io_capability
+ setattr(self.ref.device, 'io_capability', ref_io_capability)
- if ref_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
- ref_address = {'public': self.ref.address}
- else:
- ref_address = {'random': Address(self.ref.device.random_address)}
-
- await self.dut.security_storage.DeleteBond(**ref_address)
- await self.dut.host.StartAdvertising(
- legacy=True, connectable=True,
+ advertisement = self.dut.aio.host.Advertise(
+ legacy=True,
+ connectable=True,
own_address_type=dut_address_type,
- data=DataTypes(manufacturer_specific_data=b'pause cafe')
+ data=DataTypes(manufacturer_specific_data=b'pause cafe'),
)
- dut = None
- peers = self.ref.host.Scan(own_address_type=ref_address_type)
- async for peer in aiter(peers):
- if b'pause cafe' in peer.data.manufacturer_specific_data:
- dut = peer
- break
- assert_is_not_none(dut)
- if dut_address_type in (OwnAddressType.PUBLIC, OwnAddressType.RESOLVABLE_OR_PUBLIC):
- dut_address = {'public': Address(dut.public)}
- else:
- dut_address = {'random': Address(dut.random)}
- peers.cancel()
+ scan = self.ref.aio.host.Scan(own_address_type=ref_address_type)
+ dut = await anext((x async for x in scan if b'pause cafe' in x.data.manufacturer_specific_data))
+ scan.cancel()
+ assert dut
- async def handle_pairing_events():
- on_ref_pairing = self.ref.security.OnPairing((ref_answer_queue := AsyncQueue()))
- on_dut_pairing = self.dut.security.OnPairing((dut_answer_queue := AsyncQueue()))
-
- try:
- while True:
- dut_pairing_event = await anext(aiter(on_dut_pairing))
- ref_pairing_event = await anext(aiter(on_ref_pairing))
-
- if dut_pairing_event.WhichOneof('method') in ('numeric_comparison', 'just_works'):
- assert_in(ref_pairing_event.WhichOneof('method'), ('numeric_comparison', 'just_works'))
- dut_answer_queue.put_nowait(PairingEventAnswer(
- event=dut_pairing_event,
- confirm=True,
- ))
- ref_answer_queue.put_nowait(PairingEventAnswer(
- event=ref_pairing_event,
- confirm=True,
- ))
- elif dut_pairing_event.WhichOneof('method') == 'passkey_entry_notification':
- assert_equal(ref_pairing_event.WhichOneof('method'), 'passkey_entry_request')
- ref_answer_queue.put_nowait(PairingEventAnswer(
- event=ref_pairing_event,
- passkey=dut_pairing_event.passkey_entry_notification,
- ))
- elif dut_pairing_event.WhichOneof('method') == 'passkey_entry_request':
- assert_equal(ref_pairing_event.WhichOneof('method'), 'passkey_entry_notification')
- dut_answer_queue.put_nowait(PairingEventAnswer(
- event=dut_pairing_event,
- passkey=ref_pairing_event.passkey_entry_notification,
- ))
- else:
- fail()
-
- finally:
- on_ref_pairing.cancel()
- on_dut_pairing.cancel()
-
- pairing = asyncio.create_task(handle_pairing_events())
- (dut_ref, ref_dut) = await asyncio.gather(
- self.dut.host.WaitLEConnection(**ref_address),
- self.ref.host.ConnectLE(own_address_type=ref_address_type, **dut_address),
+ pairing = asyncio.create_task(self.handle_pairing_events())
+ (ref_dut_res, dut_ref_res) = await asyncio.gather(
+ self.ref.aio.host.ConnectLE(own_address_type=ref_address_type, **dut.address_asdict()),
+ anext(aiter(advertisement)),
)
- assert_equal(ref_dut.WhichOneof('result'), 'connection')
- assert_equal(dut_ref.WhichOneof('result'), 'connection')
- ref_dut = ref_dut.connection
- dut_ref = dut_ref.connection
+ advertisement.cancel()
+ ref_dut, dut_ref = ref_dut_res.connection, dut_ref_res.connection
+ assert ref_dut and dut_ref
(secure, wait_security) = await asyncio.gather(
- self.ref.security.Secure(connection=ref_dut, le=LESecurityLevel.LE_LEVEL4),
- self.dut.security.WaitSecurity(connection=dut_ref, le=LESecurityLevel.LE_LEVEL4)
+ self.ref.aio.security.Secure(connection=ref_dut, le=LE_LEVEL3),
+ self.dut.aio.security.WaitSecurity(connection=dut_ref, le=LE_LEVEL3),
)
pairing.cancel()
with suppress(asyncio.CancelledError, futures.CancelledError):
await pairing
- assert_equal(secure.WhichOneof('result'), 'success')
- assert_equal(wait_security.WhichOneof('result'), 'success')
+ assert_equal(secure.result_variant(), 'success')
+ assert_equal(wait_security.result_variant(), 'success')
await asyncio.gather(
- self.dut.host.Disconnect(connection=dut_ref),
- self.ref.host.WaitDisconnection(connection=ref_dut)
+ self.dut.aio.host.Disconnect(connection=dut_ref),
+ self.ref.aio.host.WaitDisconnection(connection=ref_dut),
)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
- test_runner.main()
+ test_runner.main() # type: ignore
diff --git a/examples/simulated_bumble_android.yml b/examples/simulated_bumble_android.yml
index bd43582..0e0f461 100644
--- a/examples/simulated_bumble_android.yml
+++ b/examples/simulated_bumble_android.yml
@@ -4,9 +4,6 @@
- Name: ExampleTest
Controllers:
AndroidDevice: '*'
- PandoraDevice:
- - class: AndroidPandoraDevice
- config: '*'
- - class: BumblePandoraDevice
- transport: 'tcp-client:127.0.0.1:7300'
+ BumbleDevice:
+ - transport: 'tcp-client:127.0.0.1:7300'
classic_enabled: true
diff --git a/examples/simulated_bumble_bumble.yml b/examples/simulated_bumble_bumble.yml
index eaf256c..1d79c42 100644
--- a/examples/simulated_bumble_bumble.yml
+++ b/examples/simulated_bumble_bumble.yml
@@ -1,6 +1,6 @@
---
-# BumblePandoraDevice configuration:
+# BumbleDevice configuration:
# classic_enabled: [true, false] # (false by default)
# class_of_device: 1234 # See assigned numbers
# keystore: JsonKeyStore # or empty
@@ -14,17 +14,15 @@
TestBeds:
- Name: ExampleTest
Controllers:
- PandoraDevice:
+ BumbleDevice:
# DUT device
- - class: BumblePandoraDevice
- transport: 'tcp-client:127.0.0.1:6402'
+ - transport: 'tcp-client:127.0.0.1:6402'
classic_enabled: true
class_of_device: 2360324
keystore: 'JsonKeyStore'
io_capability: display_output_only
# Reference device
- - class: BumblePandoraDevice
- transport: 'tcp-client:127.0.0.1:6402'
+ - transport: 'tcp-client:127.0.0.1:6402'
classic_enabled: true
class_of_device: 2360324
keystore: 'JsonKeyStore'
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..257985f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,67 @@
+[project]
+name = "avatar"
+authors = [{name = "Pandora", email = "pandora-core@google.com"}]
+readme = "README.md"
+dynamic = ["version", "description"]
+dependencies = [
+ "bt-test-interfaces",
+ "bumble",
+ "grpcio==1.51.1",
+ "mobly>=1.12",
+ "bitstruct>=8.12",
+]
+
+[project.optional-dependencies]
+dev = [
+ "grpcio-tools==1.51.1",
+ "black==22.10.0",
+ "pyright==1.1.294",
+ "mypy==1.0",
+ "isort==5.12.0",
+ "types-psutil>=5.9.5.6",
+ "types-setuptools>=65.7.0.3",
+ "types-protobuf>=4.21.0.3"
+]
+
+[tool.black]
+line-length = 119
+target-version = ["py38", "py39", "py310", "py311"]
+skip-string-normalization = true
+
+[tool.isort]
+profile = "black"
+line_length = 119
+no_sections = true
+lines_between_types = 1
+combine_as_imports = true
+
+[tool.mypy]
+strict = true
+warn_unused_ignores = false
+files = ["avatar", "examples"]
+mypy_path = '$MYPY_CONFIG_FILE_DIR/bt-test-interfaces/python:$MYPY_CONFIG_FILE_DIR/third-party/bumble'
+exclude = 'third-party/bumble'
+
+[[tool.mypy.overrides]]
+module = "grpc.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "mobly.*"
+ignore_missing_imports = true
+
+[tool.pyright]
+include = ["avatar", "examples"]
+exclude = ["**/__pycache__"]
+typeCheckingMode = "strict"
+useLibraryCodeForTypes = true
+verboseOutput = false
+extraPaths = [
+ 'bt-test-interfaces/python',
+ 'third-party/bumble'
+]
+reportMissingTypeStubs = false
+
+[build-system]
+requires = ["flit_core==3.7.1"]
+build-backend = "flit_core.buildapi"
diff --git a/third-party/bumble b/third-party/bumble
new file mode 120000
index 0000000..e5285f6
--- /dev/null
+++ b/third-party/bumble
@@ -0,0 +1 @@
+../../../python/bumble/
\ No newline at end of file