blob: 555a33260cfef6ea2f1d2f4dfa89c0c9df93bab0 [file] [log] [blame]
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import asyncio
import logging
from colors import color
from .. import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger, named after this module (via __name__).
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser.
# Keyed by packet type (the first byte of a packet); each value is a tuple:
#   (length-size, length-offset, unpack-type)
# where:
#   length-size:   size in bytes of the packet's length field
#   length-offset: number of header bytes (after the type byte) that precede
#                  the length field
#   unpack-type:   struct format used to decode the length field
# HCI multi-byte fields are little-endian, so the 2-byte ACL length uses an
# explicit '<H': a bare 'H' would use the host's native byte order and
# misread lengths on big-endian machines.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
}
# -----------------------------------------------------------------------------
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink.

    Each raw packet obtained from ``reader.next_packet()`` is converted with
    ``hci.HCI_Packet.from_bytes`` and passed to ``sink.on_packet()``.
    '''

    def __init__(self, reader, sink):
        self.reader = reader
        self.sink = sink

    async def run(self):
        '''
        Pump packets until the reader reports end of stream.

        Errors while reading, decoding or delivering a single packet are
        logged and the pump keeps going. An ``asyncio.IncompleteReadError``
        (end of stream) stops the pump instead.
        '''
        while True:
            try:
                # Get a packet from the source
                packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except asyncio.IncompleteReadError as error:
                # End of stream. Retrying cannot succeed, and an
                # asyncio.StreamReader at EOF raises from readexactly()
                # without suspending, so looping would block the event loop.
                logger.debug(f'packet source reached end of stream: {error}')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
# -----------------------------------------------------------------------------
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed.

    Data may be fed in arbitrarily sized chunks through feed_data(). Each
    complete packet (type byte + header + body) is passed to the sink's
    on_packet() as a ``bytes`` object.
    '''

    # pylint: disable=attribute-defined-outside-init
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    def __init__(self, sink=None):
        self.sink = sink
        # Extra packet types, keyed by packet type byte. Values use the same
        # (length-size, length-offset, unpack-type) layout as HCI_PACKET_INFO.
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''Discard any partial packet and wait for a new packet type byte.'''
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data):
        '''
        Consume a chunk of raw bytes, emitting every packet it completes.

        Exceptions raised by the sink's on_packet() are logged, not propagated.

        Raises:
            ValueError: if a packet type byte is found in neither
                HCI_PACKET_INFO nor extended_packet_info. The parser is reset
                before raising, so later calls can resume parsing; the rest
                of ``data`` is dropped.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Reset first: otherwise bytes_needed stays 0 and the
                        # `while` condition above rejects all future data.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    # Header = bytes before the length field + the length field
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # The length field starts after the type byte (offset 1)
                    # plus length-offset header bytes.
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (this also covers
                # zero-length bodies right after the header is read)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.warning(
                                color(f'!!! Exception in on_packet: {error}', 'red')
                            )
                    self.reset()

    def set_packet_sink(self, sink):
        '''Set the object whose on_packet() receives completed packets.'''
        self.sink = sink
# -----------------------------------------------------------------------------
class PacketReader:
    '''
    Reader that pulls complete HCI packets, one at a time, from a blocking
    byte source exposing ``read(n)``.
    '''

    def __init__(self, source):
        self.source = source

    def next_packet(self):
        '''
        Return the next packet as bytes (type byte + header + body).

        Returns None if the source yields no type byte (e.g. at end of
        stream). Raises ValueError for an unknown packet type, or when the
        header or body comes back shorter than requested.
        '''
        type_byte = self.source.read(1)
        if len(type_byte) != 1:
            return None

        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise ValueError(f'invalid packet type {type_byte} found')
        length_size, length_offset, length_format = info

        # The header runs up to and including the length field
        header_size = length_offset + length_size
        header = self.source.read(header_size)
        if len(header) != header_size:
            raise ValueError('packet too short')

        (body_length,) = struct.unpack_from(length_format, header, length_offset)
        body = self.source.read(body_length)
        if len(body) != body_length:
            raise ValueError('packet too short')

        return type_byte + header + body
# -----------------------------------------------------------------------------
class AsyncPacketReader:
    '''
    Asynchronous counterpart of PacketReader. Pulls complete HCI packets from
    a source exposing an awaitable ``readexactly(n)`` (e.g. an
    asyncio.StreamReader).
    '''

    def __init__(self, source):
        self.source = source

    async def next_packet(self):
        '''
        Return the next complete packet as bytes (type byte + header + body).

        Raises ValueError for an unknown packet type. Errors raised by
        ``readexactly`` propagate unchanged.
        '''
        type_byte = await self.source.readexactly(1)

        info = HCI_PACKET_INFO.get(type_byte[0])
        if info is None:
            raise ValueError(f'invalid packet type {type_byte} found')
        length_size, length_offset, length_format = info

        # The header runs up to and including the length field
        header = await self.source.readexactly(length_offset + length_size)

        (body_length,) = struct.unpack_from(length_format, header, length_offset)
        body = await self.source.readexactly(body_length)

        return type_byte + header + body
# -----------------------------------------------------------------------------
class AsyncPipeSink:
    '''
    Sink that hands each packet to another sink on a later iteration of the
    event loop, instead of calling it synchronously.
    '''

    def __init__(self, sink):
        self.sink = sink
        # Captured once, so instances must be created while a loop is running
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
# -----------------------------------------------------------------------------
class ParserSource:
    """
    Base class for transport-specific packet sources that push raw bytes
    through a PacketParser.
    """

    def __init__(self):
        self.parser = PacketParser()
        # Resolved by subclasses when the source stops producing packets
        loop = asyncio.get_running_loop()
        self.terminated = loop.create_future()

    def set_packet_sink(self, sink):
        parser = self.parser
        parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        """Wait for the `terminated` future and return its result."""
        outcome = await self.terminated
        return outcome

    def close(self):
        # Nothing to release by default; subclasses override as needed
        return None
# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    '''
    asyncio Protocol that feeds every chunk of received bytes into the
    PacketParser inherited from ParserSource.
    '''

    def data_received(self, data):
        self.parser.feed_data(data)

    def connection_lost(self, exc):
        '''
        Resolve `terminated` when the underlying connection goes away.

        Without this, wait_for_termination() would never complete for
        stream-based sources. The result is ``exc``, which is None on a
        clean close.
        '''
        logger.debug(f'connection lost: {exc}')
        if not self.terminated.done():
            self.terminated.set_result(exc)
# -----------------------------------------------------------------------------
class StreamPacketSink:
    '''
    Sink that writes each packet, unchanged, to a transport exposing
    ``write()`` and ``close()``.
    '''

    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        write = self.transport.write
        write(packet)

    def close(self):
        transport = self.transport
        transport.close()
# -----------------------------------------------------------------------------
class Transport:
    '''
    Pairs a packet source with a packet sink.

    Can be used as an async context manager (closed on exit) and unpacked
    as ``source, sink = transport``.
    '''

    def __init__(self, source, sink):
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self):
        '''
        Close the source, then the sink.

        The sink is closed even if closing the source raises; that exception
        still propagates to the caller.
        '''
        try:
            self.source.close()
        finally:
            self.sink.close()
# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    '''
    Packet source driven by an async ``receive`` callable. A background task
    awaits ``receive()`` repeatedly and feeds each returned chunk of bytes to
    the parser.
    '''

    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self):
        '''Start the background pump task.'''

        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.exceptions.CancelledError:
                    logger.debug('source pump task done')
                    # Resolve `terminated` so wait_for_termination() does not
                    # hang forever after close() cancels the pump.
                    if not self.terminated.done():
                        self.terminated.set_result(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    if not self.terminated.done():
                        self.terminated.set_result(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        if self.pump_task:
            self.pump_task.cancel()
# -----------------------------------------------------------------------------
class PumpedPacketSink:
    '''
    Sink that queues packets and sends them, in order, from a background
    task through an async ``send`` callable.
    '''

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet):
        # Non-blocking: the pump task performs the actual send
        self.packet_queue.put_nowait(packet)

    def start(self):
        '''Start the background task that drains the queue.'''

        async def drain_queue():
            # Any failure (or cancellation) ends the pump for good
            try:
                while True:
                    outgoing = await self.packet_queue.get()
                    await self.send_function(outgoing)
            except asyncio.exceptions.CancelledError:
                logger.debug('sink pump task done')
            except Exception as error:
                logger.warning(f'exception while sending packet: {error}')

        self.pump_task = asyncio.create_task(drain_queue())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        task = self.pump_task
        if task:
            task.cancel()
# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    '''
    Transport whose source and sink are driven by pump tasks, plus an extra
    async ``close_function`` that runs after both have been closed.
    '''

    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        '''Start the source pump, then the sink pump.'''
        for endpoint in (self.source, self.sink):
            endpoint.start()

    async def close(self):
        '''Close source and sink via Transport.close, then await close_function().'''
        await super().close()
        await self.close_function()