blob: 3a1a82bb886bb9c726b376b907448d7e22c3e9fb [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional
from pdl import ast, core
from pdl.utils import indent
def mask(width: int) -> str:
    """Return the hexadecimal literal of a mask covering `width` low bits.

    The result (e.g. '0xff' for a width of 8) is spliced verbatim into
    the generated parsing code."""
    all_ones = (1 << width) - 1
    return f"{all_ones:#x}"
def generate_prelude() -> str:
    """Return the fixed header of every generated module.

    The prelude imports the modules used by the generated code and
    declares the `Packet` base class, which provides `parse_all`, a
    placeholder `size` property, and the `show` pretty printer.

    Compared to a naive implementation, `show` tolerates packets whose
    only dataclass field is `payload` (the alignment computation falls
    back to 0 instead of raising on an empty sequence), and renders
    fields annotated `bytearray` with the hexadecimal dump formatting,
    since this is the annotation emitted for byte array fields."""
    return dedent("""\
        from dataclasses import dataclass, field, fields
        from typing import Optional, List, Tuple
        import enum
        import inspect
        import math
        @dataclass
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)
            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
                packet, remain = getattr(cls, 'parse')(span)
                if len(remain) > 0:
                    raise Exception('Unexpected parsing remainder')
                return packet
            @property
            def size(self) -> int:
                pass
            def show(self, prefix: str = ''):
                print(f'{self.__class__.__name__}')
                def print_val(p: str, pp: str, name: str, align: int, typ, val):
                    if name == 'payload':
                        pass
                    # Scalar fields.
                    elif typ is int:
                        print(f'{p}{name:{align}} = {val} (0x{val:x})')
                    # Byte fields.
                    elif typ in (bytes, bytearray):
                        print(f'{p}{name:{align}} = [', end='')
                        line = ''
                        n_pp = ''
                        for (idx, b) in enumerate(val):
                            if idx > 0 and idx % 8 == 0:
                                print(f'{n_pp}{line}')
                                line = ''
                                n_pp = pp + (' ' * (align + 4))
                            line += f' {b:02x}'
                        print(f'{n_pp}{line} ]')
                    # Enum fields.
                    elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum):
                        print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})')
                    # Struct fields.
                    elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')):
                        print(f'{p}{name:{align}} = ', end='')
                        val.show(prefix=pp)
                    # Array fields.
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        align = 5
                        for (idx, elt) in enumerate(val):
                            n_p = pp + ('├── ' if idx != last else '└── ')
                            n_pp = pp + ('│ ' if idx != last else ' ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])
                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')
                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')
                last = len(fields(self)) - 1
                # A packet may declare no field besides the payload.
                align = max((len(f.name) for f in fields(self) if f.name != 'payload'), default=0)
                for (idx, f) in enumerate(fields(self)):
                    p = prefix + ('├── ' if idx != last else '└── ')
                    pp = prefix + ('│ ' if idx != last else ' ')
                    val = getattr(self, f.name)
                    print_val(p, pp, f.name, align, f.type, val)
        """)
@dataclass
class FieldParser:
    """Accumulate the generated statements parsing the fields of a packet.

    Fields are submitted one by one to parse(); the generated lines are
    collected in `code`. Statements reading the span at a constant offset
    are first staged in `unchecked_code` and only moved to `code` once a
    guard on the span length has been generated for them.
    Call done() after the last field to flush the pending state."""

    # Byte order name inserted verbatim in generated int.from_bytes() calls.
    byteorder: str
    # Octet offset, within the current span, of the next field to parse.
    offset: int = 0
    # Number of bits accumulated in the current bit field chunk.
    shift: int = 0
    # Pending bit fields, as (shift, width, field) tuples.
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=list)
    # Generated statements still waiting for a size guard.
    unchecked_code: List[str] = field(default_factory=list)
    # Generated statements, in emission order.
    code: List[str] = field(default_factory=list)

    def unchecked_append_(self, line: str):
        """Append unchecked field parsing code.
        The function check_size_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, line: str):
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.append(line)

    def check_size_(self, size: str):
        """Generate a check of the current span size."""
        self.append_(f"if len(span) < {size}:")
        self.append_(f" raise Exception('Invalid packet size')")

    def check_code_(self):
        """Generate a size check for pending field parsing."""
        if len(self.unchecked_code) > 0:
            assert len(self.chunk) == 0
            # Detach the staged lines first: check_size_ goes through
            # append_, which requires the staging list to be empty.
            unchecked_code = self.unchecked_code
            self.unchecked_code = []
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)

    def consume_span_(self, keep: int = 0) -> None:
        """Skip consumed span bytes.
        The last `keep` consumed bytes are left at the start of the span."""
        if self.offset > 0:
            self.check_code_()
            self.append_(f'span = span[{self.offset - keep}:]')
            self.offset = 0

    def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of variable size."""
        if isinstance(field.type, ast.StructDeclaration):
            self.append_(f" element, {span} = {field.type_id}.parse({span})")
            self.append_(f" {field.id}.append(element)")
        else:
            raise Exception(f'Unexpected array element type {field.type_id} {field.width}')

    def parse_array_element_static_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of constant size."""
        if field.width is not None:
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            self.append_(f" {field.id}.append({element})")
        elif isinstance(field.type, ast.EnumDeclaration):
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            element = f"{field.type_id}({element})"
            self.append_(f" {field.id}.append({element})")
        else:
            element = f"{field.type_id}.parse_all({span})"
            self.append_(f" {field.id}.append({element})")

    def parse_byte_array_field_(self, field: ast.ArrayField):
        """Parse the selected u8 array field."""
        array_size = core.get_array_field_size(field)
        padded_size = field.padded_size

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Derive the array size.
        if isinstance(array_size, int):
            size = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size'
        elif isinstance(array_size, ast.CountField):
            size = f'{field.id}_count'
        else:
            size = None

        # Parse from the padded array if padding is present.
        if padded_size and size is not None:
            self.check_size_(padded_size)
            self.append_(f"if {size} > {padded_size}:")
            self.append_(" raise Exception('Array size is larger than the padding size')")
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{padded_size}:]")
        elif size is not None:
            self.check_size_(size)
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{size}:]")
        else:
            self.append_(f"fields['{field.id}'] = list(span)")
            self.append_(f"span = bytes()")

    def parse_array_field_(self, field: ast.ArrayField):
        """Parse the selected array field."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        padded_size = field.padded_size

        # Convert the element width from bits to octets.
        if element_width:
            if element_width % 8 != 0:
                raise Exception('Array element size is not a multiple of 8')
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier}")

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_(f"remaining_span = span[{padded_size}:]")
            self.append_(f"span = span[:{padded_size}]")

        # The element width is not known, but the array full octet size
        # is known by size field. Parse elements item by item as a vector.
        if element_width is None and size is not None:
            self.check_size_(size)
            self.append_(f"array_span = span[:{size}]")
            self.append_(f"{field.id} = []")
            self.append_("while len(array_span) > 0:")
            self.parse_array_element_dynamic_(field, 'array_span')
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{size}:]")

        # The element width is not known, but the array element count
        # is known statically or by count field.
        # Parse elements item by item as a vector.
        elif element_width is None and count is not None:
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # Neither the count nor the size is known,
        # parse elements until the end of the span.
        elif element_width is None:
            self.append_(f"{field.id} = []")
            self.append_("while len(span) > 0:")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # The element width is known, and the array element count is known
        # statically, or by count field.
        elif count is not None:
            array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}')
            self.check_size_(array_size)
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{array_size}:]")

        # The element width is known, and the array full size is known
        # by size field, or unknown (in which case it is the remaining span
        # length).
        else:
            if size is not None:
                self.check_size_(size)
            array_size = size or 'len(span)'
            if element_width != 1:
                self.append_(f"if {array_size} % {element_width} != 0:")
                self.append_(" raise Exception('Array size is not a multiple of the element size')")
                self.append_(f"{field.id}_count = int({array_size} / {element_width})")
                array_count = f'{field.id}_count'
            else:
                array_count = array_size
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({array_count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            if size is not None:
                self.append_(f"span = span[{size}:]")
            else:
                self.append_(f"span = bytes()")

        # Drop the padding
        if padded_size:
            self.append_(f"span = remaining_span")

    def parse_bit_field_(self, field: ast.Field):
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""
        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if (self.shift % 8) != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = int(self.shift / 8)
        end_offset = self.offset + size

        if size == 1:
            value = f"span[{self.offset}]"
        else:
            span = f"span[{self.offset}:{end_offset}]"
            self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')")
            value = "value_"

        for shift, width, field in self.chunk:
            # No shift or mask is needed when the chunk is a single field.
            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if isinstance(field, ast.ScalarField):
                self.unchecked_append_(f"fields['{field.id}'] = {v}")
            elif isinstance(field, ast.FixedField) and field.enum_id:
                self.unchecked_append_(f"if {v} != {field.enum_id}.{field.tag_id}:")
                self.unchecked_append_(f" raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.FixedField):
                self.unchecked_append_(f"if {v} != {hex(field.value)}:")
                self.unchecked_append_(f" raise Exception('Unexpected fixed field value')")
            elif isinstance(field, ast.TypedefField):
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}({v})")
            elif isinstance(field, ast.SizeField):
                self.unchecked_append_(f"{field.field_id}_size = {v}")
            elif isinstance(field, ast.CountField):
                self.unchecked_append_(f"{field.field_id}_count = {v}")
            elif isinstance(field, ast.ReservedField):
                pass
            else:
                raise Exception(f'Unsupported bit field type {field.kind}')

        # Reset state.
        self.offset = end_offset
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField):
        """Parse a typedef field, to the exclusion of Enum fields."""
        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        width = core.get_declaration_size(field.type)
        if width is None:
            self.consume_span_()
            self.append_(f"{field.id}, span = {field.type_id}.parse(span)")
            self.append_(f"fields['{field.id}'] = {field.id}")
        else:
            if width % 8 != 0:
                raise Exception('Typedef field type size is not a multiple of 8')
            width = int(width / 8)
            end_offset = self.offset + width
            # Checksum value field is generated alongside checksum start.
            # Deal with this field as padding.
            if not isinstance(field.type, ast.ChecksumDeclaration):
                span = f'span[{self.offset}:{end_offset}]'
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})")
            self.offset = end_offset

    def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""
        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)

        # If the payload is not byte aligned, do parse the bit fields
        # that can be extracted, but do not consume the input bytes as
        # they will also be included in the payload span.
        if self.shift != 0:
            if payload_size:
                raise Exception("Unexpected payload size for non byte aligned payload")

            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.parse_bit_field_(core.make_reserved_field(padding_bits))
            self.consume_span_(rounded_size)
        else:
            self.consume_span_()

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            if getattr(field, 'size_modifier', None):
                self.append_(f"{field.id}_size -= {field.size_modifier}")
            self.check_size_(f'{field.id}_size')
            self.append_(f"payload = span[:{field.id}_size]")
            self.append_(f"span = span[{field.id}_size:]")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_(f"payload = span")
            self.append_(f"span = bytes([])")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload = span[:-{offset_from_end}]")
            self.append_(f"span = span[-{offset_from_end}:]")
        self.append_(f"fields['payload'] = payload")

    def parse_checksum_field_(self, field: ast.ChecksumField):
        """Generate a checksum check."""
        # The checksum value field can be read starting from the current
        # offset if the fields in between are of fixed size, or from the end
        # of the span otherwise.
        self.consume_span_()
        value_field = core.get_packet_field(field.parent, field.field_id)
        offset_from_start = 0
        offset_from_end = 0
        start_index = field.parent.fields.index(field)
        value_index = field.parent.fields.index(value_field)
        value_size = int(core.get_field_size(value_field) / 8)

        # Total bit size of the fields between the checksum start
        # and the checksum value; None if any of them is variable.
        for f in field.parent.fields[start_index + 1:value_index]:
            size = core.get_field_size(f)
            if size is None:
                offset_from_start = None
                break
            else:
                offset_from_start += size

        # Total bit size of the checksum value and the fields following it;
        # None if any of them is variable.
        trailing_fields = field.parent.fields[value_index:]
        trailing_fields.reverse()
        for f in trailing_fields:
            size = core.get_field_size(f)
            if size is None:
                offset_from_end = None
                break
            else:
                offset_from_end += size

        if offset_from_start is not None:
            if offset_from_start % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_start = int(offset_from_start / 8)
            checksum_span = f'span[:{offset_from_start}]'
            if value_size > 1:
                start = offset_from_start
                end = offset_from_start + value_size
                value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[{offset_from_start}]'
            self.check_size_(offset_from_start + value_size)

        elif offset_from_end is not None:
            if offset_from_end % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_end = int(offset_from_end / 8)
            checksum_span = f'span[:-{offset_from_end}]'
            if value_size > 1:
                start = offset_from_end
                end = offset_from_end - value_size
                # When the checksum value is the last field `end` is 0,
                # and the slice bound '-0' would select an empty span:
                # leave the upper bound open in this case.
                stop = f'-{end}' if end > 0 else ''
                value = f"int.from_bytes(span[-{start}:{stop}], byteorder='{self.byteorder}')"
            else:
                value = f'span[-{offset_from_end}]'
            self.check_size_(offset_from_end)

        else:
            raise Exception('Checksum value field cannot be read at constant offset')

        self.append_(f"{value_field.id} = {value}")
        self.append_(f"fields['{value_field.id}'] = {value_field.id}")
        self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})")
        self.append_(f"if computed_{value_field.id} != {value_field.id}:")
        self.append_(" raise Exception(f'Invalid checksum computation:" +
                     f" {{computed_{value_field.id}}} != {{{value_field.id}}}')")

    def parse(self, field: ast.Field):
        """Generate the parsing code for the selected field.
        Raises an Exception for unimplemented field types."""
        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and field.width == 8:
            self.parse_byte_array_field_(field)

        elif isinstance(field, ast.ArrayField):
            self.parse_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.parse_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.parse_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')

    def done(self):
        """Flush the pending parsing code once all fields were submitted."""
        self.consume_span_()
@dataclass
class FieldSerializer:
    """Accumulate the generated statements serializing the fields of a packet.

    Fields are submitted one by one to serialize(); the generated lines,
    which all append to a local named `_span`, are collected in `code`.
    Bit fields are gathered in `value` until a byte boundary is reached
    and are then written together."""

    # Byte order name inserted verbatim in generated int.to_bytes() calls.
    byteorder: str
    # Number of bits accumulated in the current bit field chunk.
    shift: int = 0
    # Shifted value expressions of the bit fields in the current chunk.
    value: List[str] = field(default_factory=lambda: [])
    # Generated statements, in emission order.
    code: List[str] = field(default_factory=lambda: [])
    # Current nesting level; append_ prefixes each line with
    # `' ' * self.indent`.
    indent: int = 0

    def indent_(self):
        """Increase the nesting level of the following generated lines."""
        self.indent += 1

    def unindent_(self):
        """Decrease the nesting level of the following generated lines."""
        self.indent -= 1

    def append_(self, line: str):
        """Append field serializing code.
        Multi-line input is split, each line is prefixed with the
        current indentation."""
        lines = line.split('\n')
        self.code.extend([' ' * self.indent + line for line in lines])

    def extend_(self, value: str, length: int):
        """Append data to the span being constructed.
        `value` is an expression of the generated code, `length` the
        number of octets it is serialized to."""
        if length == 1:
            self.append_(f"_span.append({value})")
        else:
            self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))")

    def serialize_array_element_(self, field: ast.ArrayField):
        """Serialize a single array field element.
        The element is named `_elt` in the generated code."""
        if field.width is not None:
            length = int(field.width / 8)
            self.extend_('_elt', length)
        elif isinstance(field.type, ast.EnumDeclaration):
            length = int(field.type.width / 8)
            self.extend_('_elt', length)
        else:
            self.append_("_span.extend(_elt.serialize())")

    def serialize_array_field_(self, field: ast.ArrayField):
        """Serialize the selected array field."""
        # Record the start of the array to compute the padding length.
        if field.padded_size:
            self.append_(f"_{field.id}_start = len(_span)")

        if field.width == 8:
            self.append_(f"_span.extend(self.{field.id})")
        else:
            self.append_(f"for _elt in self.{field.id}:")
            self.indent_()
            self.serialize_array_element_(field)
            self.unindent_()

        # Fill with zeros up to the padded size.
        if field.padded_size:
            self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))")

    def serialize_bit_field_(self, field: ast.Field):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together."""
        # Add to current chunk.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, str):
            self.value.append(f"({field} << {shift})")
        elif isinstance(field, ast.ScalarField):
            # The generated code reports and truncates out of range values.
            max_value = (1 << field.width) - 1
            self.append_(f"if self.{field.id} > {max_value}:")
            self.append_(f" print(f\"Invalid value for field {field.parent.id}::{field.id}:" +
                         f" {{self.{field.id}}} > {max_value}; the value will be truncated\")")
            self.append_(f" self.{field.id} &= {max_value}")
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.FixedField) and field.enum_id:
            self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})")
        elif isinstance(field, ast.FixedField):
            self.value.append(f"({field.value} << {shift})")
        elif isinstance(field, ast.TypedefField):
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.SizeField):
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            # The size modifier is added to the serialized size value.
            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                # The generated code raises when the payload size
                # exceeds the capacity of the size field.
                self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}")
                self.append_(f"if _payload_size > {max_size}:")
                self.append_(f" print(f\"Invalid length for payload field:" +
                             f" {{_payload_size}} > {max_size}; the packet cannot be generated\")")
                self.append_(f" raise Exception(\"Invalid payload length\")")
                array_size = "_payload_size"
            elif isinstance(value_field, ast.ArrayField) and value_field.width:
                array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField) and isinstance(value_field.type, ast.EnumDeclaration):
                array_size = f"(len(self.{value_field.id}) * {int(value_field.type.width / 8)}{size_modifier})"
            elif isinstance(value_field, ast.ArrayField):
                self.append_(
                    f"_{value_field.id}_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}")
                array_size = f"_{value_field.id}_size"
            else:
                raise Exception("Unsupported field type")
            self.value.append(f"({array_size} << {shift})")

        elif isinstance(field, ast.CountField):
            # The generated code reports and truncates arrays with more
            # elements than the count field can represent.
            max_count = (1 << field.width) - 1
            self.append_(f"if len(self.{field.field_id}) > {max_count}:")
            self.append_(f" print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" +
                         f" {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")")
            self.append_(f" del self.{field.field_id}[{max_count}:]")
            self.value.append(f"(len(self.{field.field_id}) << {shift})")
        elif isinstance(field, ast.ReservedField):
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()

    def pack_bit_fields_(self):
        """Pack serialized bit fields."""
        # Should have an integral number of bytes now.
        assert (self.shift % 8) == 0

        # Generate the backing integer, and serialize it
        # using the configured endianness.
        size = int(self.shift / 8)

        if len(self.value) == 0:
            # No value was recorded for the chunk: write zeros.
            self.append_(f"_span.extend([0] * {size})")
        elif len(self.value) == 1:
            self.extend_(self.value[0], size)
        else:
            self.append_(f"_value = (")
            self.append_(" " + " |\n ".join(self.value))
            self.append_(")")
            self.extend_('_value', size)

        # Reset state.
        self.shift = 0
        self.value = []

    def serialize_typedef_field_(self, field: ast.TypedefField):
        """Serialize a typedef field, to the exclusion of Enum fields."""
        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        if isinstance(field.type, ast.ChecksumDeclaration):
            # The checksum is computed over the bytes serialized since
            # `_checksum_start`, see serialize_checksum_field_.
            size = int(field.type.width / 8)
            self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])")
            self.extend_('_checksum', size)
        else:
            self.append_(f"_span.extend(self.{field.id}.serialize())")

    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Serialize body and payload fields."""
        if self.shift != 0 and self.byteorder == 'big':
            raise Exception('Payload field does not start on an octet boundary')

        if self.shift == 0:
            self.append_(f"_span.extend(payload or self.payload or [])")
        else:
            # Supported case of packet inheritance;
            # the incomplete fields are serialized into
            # the payload, rather than separately.
            # First extract the padding bits from the payload,
            # then recombine them with the bit fields to be serialized.
            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.append_(f"_payload = payload or self.payload or bytes()")
            self.append_(f"if len(_payload) < {rounded_size}:")
            self.append_(f" raise Exception(f\"Invalid length for payload field:" +
                         f" {{len(_payload)}} < {rounded_size}\")")
            self.append_(
                f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}")
            self.value.append(f"(_padding << {self.shift})")
            self.shift += padding_bits
            self.pack_bit_fields_()
            self.append_(f"_span.extend(_payload[{rounded_size}:])")

    def serialize_checksum_field_(self, field: ast.ChecksumField):
        """Record the span position where the checksum computation starts.
        The checksum value itself is generated by serialize_typedef_field_."""
        self.append_("_checksum_start = len(_span)")

    def serialize(self, field: ast.Field):
        """Generate the serializing code for the selected field.
        Raises an Exception for unimplemented field types."""
        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.serialize_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField):
            self.serialize_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.serialize_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.serialize_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.serialize_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')
def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Return the body lines of serialize() for a Packet or Struct
    declaration without parent.
    The generated code builds the `_span` bytearray field by field
    and returns it converted to bytes."""
    emitter = FieldSerializer(byteorder=packet.file.byteorder)
    for member in packet.fields:
        emitter.serialize(member)

    body = ['_span = bytearray()']
    body.extend(emitter.code)
    body.append('return bytes(_span)')
    return body
def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Return the body lines of serialize() for a Packet or Struct
    declaration with a parent.
    The generated code serializes the declaration's own fields and
    hands the result to the parent serializer as payload.
    Raises an Exception for big-endian packets with a body shift."""
    body_shift = core.get_packet_shift(packet)
    if body_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    emitter = FieldSerializer(byteorder=packet.file.byteorder, shift=body_shift)
    for member in packet.fields:
        emitter.serialize(member)

    body = ['_span = bytearray()']
    body.extend(emitter.code)
    body.append(f'return {packet.parent.id}.serialize(self, payload = bytes(_span))')
    return body
def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    """Return the body lines of parse() for a Packet or Struct declaration.

    The generated body is made of, in order: the initialization of the
    `fields` dictionary (toplevel declarations only), the validation of
    the constraints, the parsing of the fields, the attempts at
    specializing to a child packet, and the construction of the result.
    Raises an Exception for big-endian packets with a body shift."""
    body_shift = core.get_packet_shift(packet)
    if body_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    def mismatch(constraint) -> str:
        # Expression true when the parsed value violates the constraint.
        if constraint.value is not None:
            return f"fields['{constraint.id}'] != {hex(constraint.value)}"
        constrained = core.get_packet_field(packet, constraint.id)
        return f"fields['{constraint.id}'] != {constrained.type_id}.{constraint.tag_id}"

    # Reject the input if any of the packet constraints is violated.
    validation = []
    if packet.constraints:
        violations = [mismatch(constraint) for constraint in packet.constraints]
        validation.append(f"if {' or '.join(violations)}:")
        validation.append(" raise Exception(\"Invalid constraint field values\")")

    # Generate the parsing code field after field.
    field_parser = FieldParser(byteorder=packet.file.byteorder, shift=body_shift)
    for member in packet.fields:
        field_parser.parse(member)
    field_parser.done()

    children = core.get_derived_packets(packet)
    # Only toplevel declarations create the dictionary of parsed values,
    # derived declarations receive it as parameter.
    prologue = [] if packet.parent_id else ["fields = {'payload': None}"]

    # Try the child packets one after the other, the first successfully
    # parsed is returned; the packet itself is returned when none matches.
    # TODO: order child packets by decreasing size in case no constraint
    # is given for specialization.
    specialization = []
    for _, child in children:
        specialization += [
            "try:",
            f" return {child.id}.parse(fields.copy(), payload)",
            "except Exception as exn:",
            " pass",
        ]

    epilogue = [f"return {packet.id}(**fields), span"]
    return prologue + validation + field_parser.code + specialization + epilogue
def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
    """Return the body lines of the `size` property of a declaration.

    The sizes of the fields with a known bit size are summed at
    generation time and converted to octets; an expression evaluated
    at run time is generated for each of the other fields.
    Raises an Exception for fields whose size cannot be expressed."""
    constant_width = 0
    variable_width = []

    for f in packet.fields:
        field_size = core.get_field_size(f)
        if field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("len(self.payload)")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"self.{f.id}.size")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            variable_width.append(f"sum([elt.size for elt in self.{f.id}])")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration):
            # The enum width is expressed in bits, the size in octets.
            variable_width.append(f"len(self.{f.id}) * {int(f.type.width / 8)}")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"len(self.{f.id}) * {int(f.width / 8)}")
        else:
            raise Exception("Unsupported field type")

    # Convert the constant part from bits to octets.
    constant_width = int(constant_width / 8)
    if len(variable_width) == 0:
        return [f"return {constant_width}"]
    elif len(variable_width) == 1 and constant_width:
        return [f"return {variable_width[0]} + {constant_width}"]
    elif len(variable_width) == 1:
        return [f"return {variable_width[0]}"]
    elif len(variable_width) > 1 and constant_width:
        # One generated line per variable term.
        return ([f"return {constant_width} + ("] + " +\n ".join(variable_width).split("\n") + [")"])
    elif len(variable_width) > 1:
        return (["return ("] + " +\n ".join(variable_width).split("\n") + [")"])
    else:
        assert False
def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Return the body lines of __post_init__ for a declaration.
    The generated code assigns the fields constrained by the declaration
    and its derived ancestors; the body is a single `pass` when there is
    no constraint."""
    # Walk up the inheritance chain; the constraints of the
    # toplevel declaration (without parent) are not collected.
    inherited = []
    node = decl
    while node.parent_id:
        inherited += node.constraints
        node = node.parent

    if not inherited:
        return ["pass"]

    assignments = []
    for constraint in inherited:
        if constraint.value is None:
            # Enum constraint: assign the tag of the field's enum type.
            target = core.get_packet_field(decl, constraint.id)
            assignments.append(f"self.{constraint.id} = {target.type_id}.{constraint.tag_id}")
        else:
            assignments.append(f"self.{constraint.id} = {constraint.value}")
    return assignments
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Return the source of the enum.IntEnum class implementing an enum
    declaration, with one member per tag and values in hexadecimal."""
    members = [f"{tag.id} = {hex(tag.value)}" for tag in decl.tags]
    template = dedent("""\
        class {enum_name}(enum.IntEnum):
            {tag_decls}
        """)
    return template.format(enum_name=decl.id, tag_decls=indent(members, 1))
def generate_packet_declaration(packet: ast.Declaration) -> str:
    """Return the source of the dataclass implementing a Packet or
    Struct declaration.
    The class derives from the parent declaration when there is one,
    from `Packet` otherwise. Raises an Exception for typedef fields
    of unsupported type."""
    # One dataclass attribute per scalar, typedef and array field;
    # fields of any other kind produce no attribute.
    attributes = []
    for member in packet.fields:
        if isinstance(member, ast.ScalarField):
            attributes.append(f"{member.id}: int = field(kw_only=True, default=0)")
            continue

        if isinstance(member, ast.TypedefField):
            member_type = member.type
            if isinstance(member_type, ast.EnumDeclaration):
                # Enum attributes default to the first declared tag.
                attributes.append(
                    f"{member.id}: {member.type_id} = field(kw_only=True, default={member.type_id}.{member_type.tags[0].id})")
            elif isinstance(member_type, ast.ChecksumDeclaration):
                attributes.append(f"{member.id}: int = field(kw_only=True, default=0)")
            elif isinstance(member_type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
                attributes.append(f"{member.id}: {member.type_id} = field(kw_only=True, default_factory={member.type_id})")
            else:
                raise Exception("Unsupported typedef field type")
            continue

        if not isinstance(member, ast.ArrayField):
            continue

        if member.width == 8:
            attributes.append(f"{member.id}: bytearray = field(kw_only=True, default_factory=bytearray)")
        elif member.width:
            attributes.append(f"{member.id}: List[int] = field(kw_only=True, default_factory=list)")
        elif member.type_id:
            attributes.append(f"{member.id}: List[{member.type_id}] = field(kw_only=True, default_factory=list)")

    # Derived declarations receive the fields already parsed by their
    # parent and delegate the final serialization to it.
    if packet.parent_id:
        base_class = packet.parent_id
        inherited_parameters = 'fields: dict, '
        serialize_body = generate_derived_packet_serializer(packet)
    else:
        base_class = 'Packet'
        inherited_parameters = ''
        serialize_body = generate_toplevel_packet_serializer(packet)

    parse_body = generate_packet_parser(packet)
    size_body = generate_packet_size_getter(packet)
    post_init_body = generate_packet_post_init(packet)

    template = dedent("""\
        @dataclass
        class {packet_name}({parent_name}):
            {field_decls}
            def __post_init__(self):
                {post_init}
            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}
            def serialize(self, payload: bytes = None) -> bytes:
                {serializer}
            @property
            def size(self) -> int:
                {size}
        """)
    return template.format(
        packet_name=packet.id,
        parent_name=base_class,
        parent_fields=inherited_parameters,
        field_decls=indent(attributes, 1),
        post_init=indent(post_init_body, 2),
        parser=indent(parse_body, 2),
        serializer=indent(serialize_body, 2),
        size=indent(size_body, 2),
    )
def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str:
    """Return the load-time validation code for a user custom field type.
    The generated statement raises when the imported type is missing a
    callable `parse` or `parse_all` attribute, so that the user gets an
    immediate and clear error message when the provided custom type does
    not fit the expected template.
    """
    template = dedent("""\
        if (not callable(getattr({custom_field_name}, 'parse', None)) or
            not callable(getattr({custom_field_name}, 'parse_all', None))):
            raise Exception('The custom field type {custom_field_name} does not implement the parse method')
        """)
    return template.format(custom_field_name=decl.id)
def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str:
    """Return the load-time validation code for a user checksum function.
    The generated statement raises when the imported name is not
    callable, so that the user gets an immediate and clear error message
    when the provided checksum function does not fit the expected template.
    """
    template = dedent("""\
        if not callable({checksum_name}):
            raise Exception('{checksum_name} is not callable')
        """)
    return template.format(checksum_name=decl.id)
def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str]):
    """Generate the Python backend for a PDL-JSON file.

    The declarations are loaded from `input` and the generated module is
    written to `output`. When `custom_type_location` is given, the custom
    field types and checksum functions are imported from this module by
    the generated code."""
    file = ast.File.from_json(json.load(input))
    core.desugar(file)

    # Custom field types and checksum functions are provided by the user:
    # collect their names, and the code validating their implementation.
    custom_types = []
    check_snippets = []
    for decl in file.declarations:
        if isinstance(decl, ast.CustomFieldDeclaration):
            snippet = generate_custom_field_declaration_check(decl)
        elif isinstance(decl, ast.ChecksumDeclaration):
            snippet = generate_checksum_declaration_check(decl)
        else:
            continue
        custom_types.append(decl.id)
        check_snippets.append(snippet)
    custom_type_checks = "".join(check_snippets)

    # File header.
    header = (
        f"# File generated from {input.name}, with the command:\n",
        f"# {' '.join(sys.argv)}\n",
        "# /!\\ Do not edit by hand.\n",
    )
    for line in header:
        output.write(line)

    if custom_types and custom_type_location:
        output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n")
    output.write(generate_prelude())
    output.write(custom_type_checks)

    # Enum, packet and struct implementations, in declaration order.
    for decl in file.declarations:
        if isinstance(decl, ast.EnumDeclaration):
            output.write(generate_enum_declaration(decl))
        elif isinstance(decl, (ast.PacketDeclaration, ast.StructDeclaration)):
            output.write(generate_packet_declaration(decl))
def main() -> int:
    """Generate python PDL backend."""
    # The module has no docstring of its own: take the program
    # description from this function instead of the module `__doc__`.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file')
    parser.add_argument('--custom-type-location',
                        type=str,
                        required=False,
                        help='Module of declaration of custom types')
    # The option names match the parameter names of run().
    run(**vars(parser.parse_args()))
    # Exit status reported on success.
    return 0
# Run the generator when the file is executed as a script;
# the value returned by main() is used as the process exit status.
if __name__ == '__main__':
    sys.exit(main())