Add more argement hints in L2CAP
diff --git a/bumble/l2cap.py b/bumble/l2cap.py
index 76ae585..119cb13 100644
--- a/bumble/l2cap.py
+++ b/bumble/l2cap.py
@@ -22,7 +22,18 @@
from collections import deque
from pyee import EventEmitter
-from typing import Dict, Type, List, Optional, Tuple, Callable, Any, Union, Deque
+from typing import (
+ Dict,
+ Type,
+ List,
+ Optional,
+ Tuple,
+ Callable,
+ Any,
+ Union,
+ Deque,
+ Iterable,
+)
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
@@ -155,7 +166,7 @@
'''
@staticmethod
- def from_bytes(data) -> L2CAP_PDU:
+ def from_bytes(data: bytes) -> L2CAP_PDU:
# Sanity check
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
@@ -169,7 +180,7 @@
header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload
- def __init__(self, cid, payload) -> None:
+ def __init__(self, cid: int, payload: bytes) -> None:
self.cid = cid
self.payload = payload
@@ -191,7 +202,7 @@
name: str
@staticmethod
- def from_bytes(pdu) -> L2CAP_Control_Frame:
+ def from_bytes(pdu: bytes) -> L2CAP_Control_Frame:
code = pdu[0]
cls = L2CAP_Control_Frame.classes.get(code)
@@ -216,11 +227,11 @@
return self
@staticmethod
- def code_name(code) -> str:
+ def code_name(code: int) -> str:
return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code)
@staticmethod
- def decode_configuration_options(data) -> List[Tuple[int, bytes]]:
+ def decode_configuration_options(data: bytes) -> List[Tuple[int, bytes]]:
options = []
while len(data) >= 2:
value_type = data[0]
@@ -232,7 +243,7 @@
return options
@staticmethod
- def encode_configuration_options(options) -> bytes:
+ def encode_configuration_options(options: List[Tuple[int, bytes]]) -> bytes:
return b''.join(
[bytes([option[0], len(option[1])]) + option[1] for option in options]
)
@@ -258,8 +269,9 @@
def __init__(self, pdu=None, **kwargs) -> None:
self.identifier = kwargs.get('identifier', 0)
- if hasattr(self, 'fields') and kwargs:
- HCI_Object.init_from_fields(self, self.fields, kwargs)
+ if hasattr(self, 'fields'):
+ if kwargs:
+ HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
@@ -315,7 +327,7 @@
}
@staticmethod
- def reason_name(reason) -> str:
+ def reason_name(reason: int) -> str:
return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason)
@@ -343,7 +355,7 @@
'''
@staticmethod
- def parse_psm(data, offset=0) -> Tuple[int, int]:
+ def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
psm = data[offset] | data[offset + 1] << 8
@@ -355,7 +367,7 @@
return offset + psm_length, psm
@staticmethod
- def serialize_psm(psm) -> bytes:
+ def serialize_psm(psm: int) -> bytes:
serialized = struct.pack('<H', psm & 0xFFFF)
psm >>= 16
while psm:
@@ -405,7 +417,7 @@
}
@staticmethod
- def result_name(result) -> str:
+ def result_name(result: int) -> str:
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
@@ -452,7 +464,7 @@
}
@staticmethod
- def result_name(result) -> str:
+ def result_name(result: int) -> str:
return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result)
@@ -529,7 +541,7 @@
}
@staticmethod
- def info_type_name(info_type) -> str:
+ def info_type_name(info_type: int) -> str:
return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type)
@@ -556,7 +568,7 @@
RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'}
@staticmethod
- def result_name(result) -> str:
+ def result_name(result: int) -> str:
return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result)
@@ -642,7 +654,7 @@
}
@staticmethod
- def result_name(result) -> str:
+ def result_name(result: int) -> str:
return name_or_number(
L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result
)
@@ -707,9 +719,16 @@
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
+ state: int
def __init__(
- self, manager, connection, signaling_cid, psm, source_cid, mtu
+ self,
+ manager: 'ChannelManager',
+ connection,
+ signaling_cid: int,
+ psm: int,
+ source_cid: int,
+ mtu: int,
) -> None:
super().__init__()
self.manager = manager
@@ -725,7 +744,7 @@
self.disconnection_result = None
self.sink = None
- def change_state(self, new_state) -> None:
+ def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
@@ -1008,25 +1027,26 @@
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
+ state: int
@staticmethod
- def state_name(state) -> str:
+ def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
self,
- manager,
+ manager: 'ChannelManager',
connection,
- le_psm,
- source_cid,
- destination_cid,
- mtu,
- mps,
- credits, # pylint: disable=redefined-builtin
- peer_mtu,
- peer_mps,
- peer_credits,
- connected,
+ le_psm: int,
+ source_cid: int,
+ destination_cid: int,
+ mtu: int,
+ mps: int,
+ credits: int, # pylint: disable=redefined-builtin
+ peer_mtu: int,
+ peer_mps: int,
+ peer_credits: int,
+ connected: bool,
) -> None:
super().__init__()
self.manager = manager
@@ -1059,7 +1079,7 @@
else:
self.state = LeConnectionOrientedChannel.INIT
- def change_state(self, new_state) -> None:
+ def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
@@ -1228,7 +1248,7 @@
# Cleanup
self.connection_result = None
- def on_credits(self, credits) -> None: # pylint: disable=redefined-builtin
+ def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin
self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}')
@@ -1310,7 +1330,7 @@
self.drained.set()
return
- def write(self, data) -> None:
+ def write(self, data: bytes) -> None:
if self.state != self.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1360,7 +1380,9 @@
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
def __init__(
- self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
+ self,
+ extended_features: Iterable[int] = (),
+ connectionless_mtu: int = L2CAP_DEFAULT_CONNECTIONLESS_MTU,
) -> None:
self._host = None
self.identifiers = {} # Incrementing identifier values by connection
@@ -1390,20 +1412,20 @@
if host is not None:
host.on('disconnection', self.on_disconnection)
- def find_channel(self, connection_handle, cid):
+ def find_channel(self, connection_handle: int, cid: int):
if connection_channels := self.channels.get(connection_handle):
return connection_channels.get(cid)
return None
- def find_le_coc_channel(self, connection_handle, cid):
+ def find_le_coc_channel(self, connection_handle: int, cid: int):
if connection_channels := self.le_coc_channels.get(connection_handle):
return connection_channels.get(cid)
return None
@staticmethod
- def find_free_br_edr_cid(channels) -> int:
+ def find_free_br_edr_cid(channels: Iterable[int]) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1416,7 +1438,7 @@
raise RuntimeError('no free CID available')
@staticmethod
- def find_free_le_cid(channels) -> int:
+ def find_free_le_cid(channels: Iterable[int]) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1429,7 +1451,7 @@
raise RuntimeError('no free CID')
@staticmethod
- def check_le_coc_parameters(max_credits, mtu, mps) -> None:
+ def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
@@ -1448,14 +1470,16 @@
self.identifiers[connection.handle] = identifier
return identifier
- def register_fixed_channel(self, cid, handler) -> None:
+ def register_fixed_channel(
+ self, cid: int, handler: Callable[[int, bytes], Any]
+ ) -> None:
self.fixed_channels[cid] = handler
- def deregister_fixed_channel(self, cid) -> None:
+ def deregister_fixed_channel(self, cid: int) -> None:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
- def register_server(self, psm, server) -> int:
+ def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
if psm == 0:
# Find a free PSM
for candidate in range(
@@ -1489,11 +1513,11 @@
def register_le_coc_server(
self,
- psm,
- server,
- max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
- mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
- mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
+ psm: int,
+ server: Callable[[LeConnectionOrientedChannel], Any],
+ max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
+ mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
+ mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
) -> int:
self.check_le_coc_parameters(max_credits, mtu, mps)
@@ -1522,7 +1546,7 @@
return psm
- def on_disconnection(self, connection_handle, _reason) -> None:
+ def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
for _, channel in self.channels[connection_handle].items():
@@ -1535,7 +1559,7 @@
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
- def send_pdu(self, connection, cid, pdu) -> None:
+ def send_pdu(self, connection, cid: int, pdu) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1544,7 +1568,7 @@
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
- def on_pdu(self, connection, cid, pdu) -> None:
+ def on_pdu(self, connection, cid: int, pdu) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
@@ -1565,7 +1589,7 @@
channel.on_pdu(pdu)
- def send_control_frame(self, connection, cid, control_frame) -> None:
+ def send_control_frame(self, connection, cid: int, control_frame) -> None:
logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1573,7 +1597,7 @@
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
- def on_control_frame(self, connection, cid, control_frame) -> None:
+ def on_control_frame(self, connection, cid: int, control_frame) -> None:
logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1610,10 +1634,10 @@
),
)
- def on_l2cap_command_reject(self, _connection, _cid, packet) -> None:
+ def on_l2cap_command_reject(self, _connection, _cid: int, packet) -> None:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
- def on_l2cap_connection_request(self, connection, cid, request) -> None:
+ def on_l2cap_connection_request(self, connection, cid: int, request) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
@@ -1665,7 +1689,7 @@
),
)
- def on_l2cap_connection_response(self, connection, cid, response) -> None:
+ def on_l2cap_connection_response(self, connection, cid: int, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1680,7 +1704,7 @@
channel.on_connection_response(response)
- def on_l2cap_configure_request(self, connection, cid, request) -> None:
+ def on_l2cap_configure_request(self, connection, cid: int, request) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1695,7 +1719,7 @@
channel.on_configure_request(request)
- def on_l2cap_configure_response(self, connection, cid, response) -> None:
+ def on_l2cap_configure_response(self, connection, cid: int, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1710,7 +1734,7 @@
channel.on_configure_response(response)
- def on_l2cap_disconnection_request(self, connection, cid, request) -> None:
+ def on_l2cap_disconnection_request(self, connection, cid: int, request) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1725,7 +1749,7 @@
channel.on_disconnection_request(request)
- def on_l2cap_disconnection_response(self, connection, cid, response) -> None:
+ def on_l2cap_disconnection_response(self, connection, cid: int, response) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1740,7 +1764,7 @@
channel.on_disconnection_response(response)
- def on_l2cap_echo_request(self, connection, cid, request) -> None:
+ def on_l2cap_echo_request(self, connection, cid: int, request) -> None:
logger.debug(f'<<< Echo request: data={request.data.hex()}')
self.send_control_frame(
connection,
@@ -1748,11 +1772,11 @@
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
)
- def on_l2cap_echo_response(self, _connection, _cid, response) -> None:
+ def on_l2cap_echo_response(self, _connection, _cid: int, response) -> None:
logger.debug(f'<<< Echo response: data={response.data.hex()}')
# TODO notify listeners
- def on_l2cap_information_request(self, connection, cid, request) -> None:
+ def on_l2cap_information_request(self, connection, cid: int, request) -> None:
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
result = L2CAP_Information_Response.SUCCESS
data = self.connectionless_mtu.to_bytes(2, 'little')
@@ -1776,7 +1800,9 @@
),
)
- def on_l2cap_connection_parameter_update_request(self, connection, cid, request):
+ def on_l2cap_connection_parameter_update_request(
+ self, connection, cid: int, request
+ ):
if connection.role == BT_CENTRAL_ROLE:
self.send_control_frame(
connection,
@@ -1795,7 +1821,7 @@
supervision_timeout=request.timeout,
min_ce_length=0,
max_ce_length=0,
- )
+ ) # type: ignore[call-arg]
)
else:
self.send_control_frame(
@@ -1808,13 +1834,13 @@
)
def on_l2cap_connection_parameter_update_response(
- self, connection, cid, response
+ self, connection, cid: int, response
) -> None:
# TODO: check response
pass
def on_l2cap_le_credit_based_connection_request(
- self, connection, cid, request
+ self, connection, cid: int, request
) -> None:
if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
@@ -1918,7 +1944,7 @@
)
def on_l2cap_le_credit_based_connection_response(
- self, connection, _cid, response
+ self, connection, _cid: int, response
) -> None:
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
@@ -1942,7 +1968,7 @@
# Process the response
channel.on_connection_response(response)
- def on_l2cap_le_flow_control_credit(self, connection, _cid, credit) -> None:
+ def on_l2cap_le_flow_control_credit(self, connection, _cid: int, credit) -> None:
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
@@ -1950,14 +1976,14 @@
channel.on_credits(credit.credits)
- def on_channel_closed(self, channel) -> None:
+ def on_channel_closed(self, channel: Channel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
async def open_le_coc(
- self, connection, psm, max_credits, mtu, mps
+ self, connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
@@ -1999,7 +2025,7 @@
return channel
- async def connect(self, connection, psm) -> Channel:
+ async def connect(self, connection, psm: int) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel