// blob: 9ecae381d3b1adf764c1d753a05a535f93da5854 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::{
ast,
backends::{
intermediate::{ComputedValue, ComputedValueId, PacketOrStruct, Schema},
rust_no_allocation::utils::get_integer_type,
},
parser,
};
/// Normalizes the name of a child-carrying field.
///
/// The special field names `_body_` and `_payload_` both map to `_child_`;
/// any other identifier is returned unchanged.
fn standardize_child(id: &str) -> &str {
    if matches!(id, "_body_" | "_payload_") {
        "_child_"
    } else {
        id
    }
}
/// Generates the builder type and serialization code for one packet or struct.
///
/// The returned token stream contains, in this order:
///
/// - a `{id}Builder` struct with one public field per array, scalar and
///   typedef field, plus a `_child_` field of type `{id}Child` when the
///   declaration has a body or payload;
/// - an `impl Builder` whose `OwnedPacket` is `Owned{id}View`;
/// - an `impl Serializable` that writes every field in declaration order;
/// - when `parent_id` is `Some`, an `impl From<{id}Builder>` for
///   `{parent_id}Child`;
/// - when the declaration has a body or payload, the `{id}Child` enum (a
///   `RawData(Box<[u8]>)` variant plus one variant per name listed under `id`
///   in `children`) together with its `Serializable` impl.
///
/// # Panics
///
/// - `unimplemented!` if any field is a checksum, which this backend does not
///   support yet.
/// - `unreachable!` if any field is a group, or if an array field has neither
///   a `width` nor a `type_id`.
/// - When a fixed enum field names an enum that is missing from
///   `schema.enums` (the lookup uses indexing).
pub fn generate_packet_serializer(
    id: &str,
    parent_id: Option<&str>,
    fields: &[parser::ast::Field],
    schema: &Schema,
    curr_schema: &PacketOrStruct,
    children: &HashMap<&str, Vec<&str>>,
) -> TokenStream {
    let id_ident = format_ident!("{id}Builder");

    // Enums are stored in builders by value under their own name; every other
    // named type is stored as its generated `...Builder` counterpart.
    let named_type_ident = |type_id: &str| {
        if schema.enums.contains_key(type_id) {
            format_ident!("{type_id}")
        } else {
            format_ident!("{type_id}Builder")
        }
    };

    // Collected eagerly (rather than left as a lazy iterator for the final
    // `quote!`) so that unsupported field kinds are reported by this pass,
    // with its descriptive message, before the serializer pass runs.
    let builder_fields = fields
        .iter()
        .filter_map(|field| match &field.desc {
            // These fields are derived or constant, so the builder struct
            // does not expose them.
            ast::FieldDesc::Padding { .. }
            | ast::FieldDesc::Reserved { .. }
            | ast::FieldDesc::FixedScalar { .. }
            | ast::FieldDesc::FixedEnum { .. }
            | ast::FieldDesc::ElementSize { .. }
            | ast::FieldDesc::Count { .. }
            | ast::FieldDesc::Size { .. } => None,
            ast::FieldDesc::Group { .. } => unreachable!(),
            ast::FieldDesc::Checksum { .. } => {
                unimplemented!("checksums not yet supported with this backend")
            }
            ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
                let type_ident = format_ident!("{id}Child");
                Some(("_child_", quote! { #type_ident }))
            }
            ast::FieldDesc::Array { id: field_id, width, type_id, .. } => {
                let element_type = if let Some(width) = width {
                    get_integer_type(*width)
                } else if let Some(type_id) = type_id {
                    named_type_ident(type_id.as_str())
                } else {
                    unreachable!();
                };
                Some((field_id.as_str(), quote! { Box<[#element_type]> }))
            }
            ast::FieldDesc::Scalar { id: field_id, width } => {
                let id_type = get_integer_type(*width);
                Some((field_id.as_str(), quote! { #id_type }))
            }
            ast::FieldDesc::Typedef { id: field_id, type_id } => {
                let type_ident = named_type_ident(type_id.as_str());
                Some((field_id.as_str(), quote! { #type_ident }))
            }
        })
        .map(|(name, typ)| {
            let name_ident = format_ident!("{name}");
            quote! { pub #name_ident: #typ }
        })
        .collect::<Vec<_>>();

    // A `{id}Child` enum is only generated when there is a body or payload.
    let has_child = fields
        .iter()
        .any(|field| matches!(field.desc, ast::FieldDesc::Body | ast::FieldDesc::Payload { .. }));

    // One serialization snippet per field, emitted in declaration order.
    let serializer = fields
        .iter()
        .map(|field| match &field.desc {
            // Same diagnostics as the builder-field pass above, so the failure
            // is identical whichever pass meets the field first.
            ast::FieldDesc::Group { .. } => unreachable!(),
            ast::FieldDesc::Checksum { .. } => {
                unimplemented!("checksums not yet supported with this backend")
            }
            ast::FieldDesc::Padding { size, .. } => {
                // Zero-fills up to `size * 8` bits, measured against the
                // `most_recent_array_size_in_bits` local that the array arm
                // below emits for the preceding array.
                quote! {
                    if (most_recent_array_size_in_bits > #size * 8) {
                        return Err(SerializeError::NegativePadding);
                    }
                    writer.write_bits((#size * 8 - most_recent_array_size_in_bits) as usize, || Ok(0u64))?;
                }
            }
            ast::FieldDesc::Size { field_id, width } => {
                let field_id = standardize_child(field_id);
                let field_ident = format_ident!("{field_id}");
                let fixed_element_size = curr_schema
                    .computed_values
                    .get(&ComputedValueId::FieldElementSize(field_id));
                let is_countable = curr_schema
                    .computed_values
                    .contains_key(&ComputedValueId::FieldCount(field_id));

                if let Some(ComputedValue::Constant(element_width)) = fixed_element_size {
                    // The element size is fixed, so we can multiply directly.
                    // NOTE(review): `element_width` is presumably in octets -
                    // confirm against the `ComputedValueId` definition.
                    quote! {
                        writer.write_bits(
                            #width,
                            || u64::try_from(self.#field_ident.len() * #element_width).or(Err(SerializeError::IntegerConversionFailure))
                        )?;
                    }
                } else if is_countable {
                    // The field is "countable": sum the size of each element.
                    quote! {
                        writer.write_bits(#width, || {
                            let size_in_bits = self.#field_ident.iter().map(|elem| elem.size_in_bits()).fold(Ok(0), |total, next| {
                                let total: u64 = total?;
                                let next = u64::try_from(next?).or(Err(SerializeError::IntegerConversionFailure))?;
                                total.checked_add(next).ok_or(SerializeError::IntegerConversionFailure)
                            })?;
                            if size_in_bits % 8 != 0 {
                                return Err(SerializeError::AlignmentError);
                            }
                            Ok(size_in_bits / 8)
                        })?;
                    }
                } else {
                    // Otherwise, ask the field for its size directly.
                    quote! {
                        writer.write_bits(#width, || {
                            let size_in_bits: u64 = self.#field_ident.size_in_bits()?.try_into().or(Err(SerializeError::IntegerConversionFailure))?;
                            if size_in_bits % 8 != 0 {
                                return Err(SerializeError::AlignmentError);
                            }
                            Ok(size_in_bits / 8)
                        })?;
                    }
                }
            }
            ast::FieldDesc::Count { field_id, width } => {
                let field_ident = format_ident!("{field_id}");
                quote! { writer.write_bits(#width, || u64::try_from(self.#field_ident.len()).or(Err(SerializeError::IntegerConversionFailure)))?; }
            }
            ast::FieldDesc::ElementSize { field_id, width } => {
                // TODO(aryarahul) - add validation for elementsize against all the other elements
                // Only the first element is inspected; an empty array writes 0.
                let field_ident = format_ident!("{field_id}");
                quote! {
                    let get_element_size = || Ok(if let Some(field) = self.#field_ident.get(0) {
                        let size_in_bits = field.size_in_bits()?;
                        if size_in_bits % 8 == 0 {
                            (size_in_bits / 8) as u64
                        } else {
                            return Err(SerializeError::AlignmentError);
                        }
                    } else {
                        0
                    });
                    writer.write_bits(#width, || get_element_size() )?;
                }
            }
            ast::FieldDesc::Reserved { width, .. } => {
                quote! { writer.write_bits(#width, || Ok(0u64))?; }
            }
            ast::FieldDesc::Scalar { width, id: field_id } => {
                let field_ident = format_ident!("{field_id}");
                quote! { writer.write_bits(#width, || Ok(self.#field_ident))?; }
            }
            ast::FieldDesc::FixedScalar { width, value } => {
                // Cast so the constant is emitted as a `u64`-suffixed literal.
                let value = *value as u64;
                quote! { writer.write_bits(#width, || Ok(#value))?; }
            }
            ast::FieldDesc::FixedEnum { enum_id, tag_id } => {
                let width = schema.enums[enum_id.as_str()].width;
                let enum_ident = format_ident!("{}", enum_id);
                let tag_ident = format_ident!("{tag_id}");
                quote! { writer.write_bits(#width, || Ok(#enum_ident::#tag_ident.value()))?; }
            }
            ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
                quote! { self._child_.serialize(writer)?; }
            }
            ast::FieldDesc::Array { width, id: field_id, .. } => {
                // Both variants leave `most_recent_array_size_in_bits` in scope
                // for a padding field that may follow.
                let id_ident = format_ident!("{field_id}");
                if let Some(width) = width {
                    quote! {
                        for elem in self.#id_ident.iter() {
                            writer.write_bits(#width, || Ok(*elem))?;
                        }
                        let most_recent_array_size_in_bits = #width * self.#id_ident.len();
                    }
                } else {
                    quote! {
                        let mut most_recent_array_size_in_bits = 0;
                        for elem in self.#id_ident.iter() {
                            most_recent_array_size_in_bits += elem.size_in_bits()?;
                            elem.serialize(writer)?;
                        }
                    }
                }
            }
            ast::FieldDesc::Typedef { id: field_id, .. } => {
                let id_ident = format_ident!("{field_id}");
                quote! { self.#id_ident.serialize(writer)?; }
            }
        })
        .collect::<Vec<_>>();

    // Names of the declarations registered as children of `id`, if any.
    let variant_names = children.get(id).into_iter().flatten().collect::<Vec<_>>();
    let variants = variant_names.iter().map(|name| {
        let name_ident = format_ident!("{name}");
        let variant_ident = format_ident!("{name}Builder");
        quote! { #name_ident(#variant_ident) }
    });
    let variant_serializers = variant_names.iter().map(|name| {
        let name_ident = format_ident!("{name}");
        quote! {
            Self::#name_ident(x) => {
                x.serialize(writer)?;
            }
        }
    });

    // `None` interpolates as no tokens at all.
    let children_enum = has_child.then(|| {
        let enum_ident = format_ident!("{id}Child");
        quote! {
            #[derive(Debug, Clone, PartialEq, Eq)]
            pub enum #enum_ident {
                RawData(Box<[u8]>),
                #(#variants),*
            }
            impl Serializable for #enum_ident {
                fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
                    match self {
                        Self::RawData(data) => {
                            for byte in data.iter() {
                                writer.write_bits(8, || Ok(*byte as u64))?;
                            }
                        },
                        #(#variant_serializers),*
                    }
                    Ok(())
                }
            }
        }
    });

    // Lets this builder be converted into the matching variant of its
    // parent's `...Child` enum.
    let parent_type_converter = parent_id.map(|parent_id| {
        let parent_enum_ident = format_ident!("{parent_id}Child");
        let variant_ident = format_ident!("{id}");
        quote! {
            impl From<#id_ident> for #parent_enum_ident {
                fn from(x: #id_ident) -> Self {
                    Self::#variant_ident(x)
                }
            }
        }
    });

    let owned_packet_ident = format_ident!("Owned{id}View");
    quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        pub struct #id_ident {
            #(#builder_fields),*
        }
        impl Builder for #id_ident {
            type OwnedPacket = #owned_packet_ident;
        }
        impl Serializable for #id_ident {
            fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
                #(#serializer)*
                Ok(())
            }
        }
        #parent_type_converter
        #children_enum
    }
}