blob: 37ef6550f4a87d6f899f379a19a2ce974afdf96d [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use crate::backends::intermediate::{
ComputedOffset, ComputedOffsetId, ComputedValue, ComputedValueId,
};
/// Implemented by computed quantities (offsets and values) whose value is
/// obtained at runtime by calling a generated `try_get_<name>` accessor.
pub trait Declarable {
    /// Base name of the accessor, without the `try_get_` prefix.
    fn get_name(&self) -> String;

    /// Identifier of the generated accessor, i.e. `try_get_<name>`.
    fn get_ident(&self) -> Ident {
        let name = self.get_name();
        format_ident!("try_get_{name}")
    }

    /// Tokens that invoke the accessor on `self` and propagate its error
    /// with `?`, i.e. `self.try_get_<name>()?`.
    fn call_fn(&self) -> TokenStream {
        let accessor = self.get_ident();
        quote!(self.#accessor()?)
    }

    /// Wraps `body` into an `#[inline]` accessor method named
    /// `try_get_<name>` returning `Result<usize, ParseError>`.
    ///
    /// `body` must be a valid function body evaluating to that result type.
    fn declare_fn(&self, body: TokenStream) -> TokenStream {
        let accessor = self.get_ident();
        quote! {
            #[inline]
            fn #accessor(&self) -> Result<usize, ParseError> {
                #body
            }
        }
    }
}
impl Declarable for ComputedValueId<'_> {
    /// Accessor base name for a computed value: field-derived quantities use
    /// the field name plus a suffix, custom values use their index.
    fn get_name(&self) -> String {
        match self {
            Self::Custom(index) => format!("custom_value_{}", index),
            Self::FieldCount(field) => format!("{}_count", field),
            Self::FieldElementSize(field) => format!("{}_element_size", field),
            Self::FieldSize(field) => format!("{}_size", field),
        }
    }
}
impl Declarable for ComputedOffsetId<'_> {
    /// Accessor base name for a computed offset. Offsets tied to a field embed
    /// the field name; the fixed packet landmarks use constant names.
    fn get_name(&self) -> String {
        match self {
            Self::FieldOffset(field) => format!("{}_offset", field),
            Self::FieldEndOffset(field) => format!("{}_end_offset", field),
            Self::Custom(index) => format!("custom_offset_{}", index),
            Self::HeaderStart => String::from("header_start_offset"),
            Self::TrailerStart => String::from("trailer_start_offset"),
            Self::PacketEnd => String::from("packet_end_offset"),
        }
    }
}
/// Implemented on computed expressions (values and offsets) that know how to
/// generate the code which evaluates them on demand.
pub trait Computable {
    /// Generates the tokens that evaluate this expression.
    ///
    /// The emitted tokens may contain `let` statements, `?`, and early
    /// `return Err(..)`, so they are only valid as a function body returning a
    /// `Result<_, ParseError>`.
    // NOTE(review): presumably spliced into `Declarable::declare_fn` by the
    // caller — confirm against call sites.
    fn compute(&self) -> TokenStream;
}
impl Computable for ComputedValue<'_> {
    /// Generates the body of the accessor that evaluates this computed value.
    ///
    /// The emitted tokens form a function body evaluating to
    /// `Result<usize, ParseError>`: they may bind locals, call other accessors
    /// with `?`, and return early with an error. Arithmetic in the generated
    /// code is checked, so malformed input produces a `ParseError` instead of
    /// a panic (debug builds) or a silently wrapped value (release builds).
    fn compute(&self) -> TokenStream {
        match self {
            ComputedValue::Constant(k) => quote! { Ok(#k) },
            ComputedValue::CountStructsUpToSize { base_id, size, struct_type } => {
                let base_offset = base_id.call_fn();
                let size = size.call_fn();
                let struct_type = format_ident!("{struct_type}View");
                // Counts consecutive structs starting at `base_offset` until
                // exactly `size` bytes are consumed. Struct sizes are in bytes
                // and view offsets in bits, hence the multiplication by 8.
                quote! {
                    let mut cnt: usize = 0;
                    let mut view = self.buf.offset(#base_offset)?;
                    let mut remaining_size = #size;
                    while remaining_size > 0 {
                        let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?;
                        // A zero-sized struct can never consume the remaining
                        // bytes; without this check the loop would never end.
                        if next_struct_size == 0 || next_struct_size > remaining_size {
                            return Err(ParseError::OutOfBoundsAccess);
                        }
                        remaining_size -= next_struct_size;
                        let next_struct_bits = next_struct_size
                            .checked_mul(8)
                            .ok_or(ParseError::ArithmeticOverflow)?;
                        view = view.offset(next_struct_bits)?;
                        cnt += 1;
                    }
                    Ok(cnt)
                }
            }
            ComputedValue::SizeOfNStructs { base_id, n, struct_type } => {
                let base_offset = base_id.call_fn();
                let n = n.call_fn();
                let struct_type = format_ident!("{struct_type}View");
                // Sums the byte sizes of `n` consecutive structs starting at
                // `base_offset`; both the running total and the bit offset
                // advance are overflow-checked.
                quote! {
                    let mut view = self.buf.offset(#base_offset)?;
                    let mut size: usize = 0;
                    for _ in 0..#n {
                        let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?;
                        size = size
                            .checked_add(next_struct_size)
                            .ok_or(ParseError::ArithmeticOverflow)?;
                        let next_struct_bits = next_struct_size
                            .checked_mul(8)
                            .ok_or(ParseError::ArithmeticOverflow)?;
                        view = view.offset(next_struct_bits)?;
                    }
                    Ok(size)
                }
            }
            ComputedValue::Product(x, y) => {
                let x = x.call_fn();
                let y = y.call_fn();
                quote! { #x.checked_mul(#y).ok_or(ParseError::ArithmeticOverflow) }
            }
            ComputedValue::Divide(x, y) => {
                let x = x.call_fn();
                let y = y.call_fn();
                // Each operand's accessor is evaluated exactly once. The divisor
                // is checked before the dividend is fetched, matching the error
                // precedence of a `y == 0 || x % y != 0` test.
                quote! {
                    let divisor = #y;
                    if divisor == 0 {
                        return Err(ParseError::DivisionFailure);
                    }
                    let dividend = #x;
                    if dividend % divisor != 0 {
                        return Err(ParseError::DivisionFailure);
                    }
                    Ok(dividend / divisor)
                }
            }
            ComputedValue::Difference(x, y) => {
                let x = x.call_fn();
                let y = y.call_fn();
                // The difference is taken in bits and must be a whole number
                // of bytes.
                quote! {
                    let bit_difference = #x.checked_sub(#y).ok_or(ParseError::ArithmeticOverflow)?;
                    if bit_difference % 8 != 0 {
                        return Err(ParseError::DivisionFailure);
                    }
                    Ok(bit_difference / 8)
                }
            }
            ComputedValue::ValueAt { offset, width } => {
                let offset = offset.call_fn();
                quote! { self.buf.offset(#offset)?.slice(#width)?.try_parse() }
            }
        }
    }
}
impl Computable for ComputedOffset<'_> {
    /// Generates an expression evaluating to `Result<usize, ParseError>` that
    /// yields this offset from the accessors of the quantities it depends on.
    fn compute(&self) -> TokenStream {
        match self {
            ComputedOffset::Alias(target) => {
                let target_value = target.call_fn();
                quote! { Ok(#target_value) }
            }
            ComputedOffset::SumWithOctets(base, octets) => {
                let base_bits = base.call_fn();
                let octet_count = octets.call_fn();
                // Converting the octet count to bits is itself checked and
                // propagated with `?` before the checked sum is formed.
                let octets_as_bits = quote! {
                    #octet_count.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?
                };
                quote! {
                    #base_bits.checked_add(#octets_as_bits)
                        .ok_or(ParseError::ArithmeticOverflow)
                }
            }
            ComputedOffset::ConstantPlusOffsetInBits(base, delta) => {
                let base_bits = base.call_fn();
                quote! {
                    #base_bits.checked_add_signed(#delta as isize)
                        .ok_or(ParseError::ArithmeticOverflow)
                }
            }
        }
    }
}