| extern crate proc_macro; |
| |
| use proc_macro2::{Span, TokenStream}; |
| use quote::quote; |
| use syn::*; |
| |
| static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; |
| |
| #[proc_macro_derive(Arbitrary)] |
| pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { |
| let input = syn::parse_macro_input!(tokens as syn::DeriveInput); |
| let (lifetime_without_bounds, lifetime_with_bounds) = |
| build_arbitrary_lifetime(input.generics.clone()); |
| |
| let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone()); |
| let size_hint_method = gen_size_hint_method(&input); |
| let name = input.ident; |
| // Add a bound `T: Arbitrary` to every type parameter T. |
| let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone()); |
| |
| // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90) |
| let mut generics_with_lifetime = generics.clone(); |
| generics_with_lifetime |
| .params |
| .push(GenericParam::Lifetime(lifetime_with_bounds)); |
| let (impl_generics, _, _) = generics_with_lifetime.split_for_impl(); |
| |
| // Build TypeGenerics and WhereClause without a lifetime |
| let (_, ty_generics, where_clause) = generics.split_for_impl(); |
| |
| (quote! { |
| impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { |
| #arbitrary_method |
| #size_hint_method |
| } |
| }) |
| .into() |
| } |
| |
| // Returns: (lifetime without bounds, lifetime with bounds) |
| // Example: ("'arbitrary", "'arbitrary: 'a + 'b") |
| fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) { |
| let lifetime_without_bounds = |
| LifetimeDef::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site())); |
| let mut lifetime_with_bounds = lifetime_without_bounds.clone(); |
| |
| for param in generics.params.iter() { |
| if let GenericParam::Lifetime(lifetime_def) = param { |
| lifetime_with_bounds |
| .bounds |
| .push(lifetime_def.lifetime.clone()); |
| } |
| } |
| |
| (lifetime_without_bounds, lifetime_with_bounds) |
| } |
| |
| // Add a bound `T: Arbitrary` to every type parameter T. |
| fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { |
| for param in generics.params.iter_mut() { |
| if let GenericParam::Type(type_param) = param { |
| type_param |
| .bounds |
| .push(parse_quote!(arbitrary::Arbitrary<#lifetime>)); |
| } |
| } |
| generics |
| } |
| |
| fn gen_arbitrary_method(input: &DeriveInput, lifetime: LifetimeDef) -> TokenStream { |
| let ident = &input.ident; |
| let arbitrary_structlike = |fields| { |
| let arbitrary = construct(fields, |_, _| quote!(arbitrary::Arbitrary::arbitrary(u)?)); |
| let arbitrary_take_rest = construct_take_rest(fields); |
| quote! { |
| fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { |
| Ok(#ident #arbitrary) |
| } |
| |
| fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { |
| Ok(#ident #arbitrary_take_rest) |
| } |
| } |
| }; |
| match &input.data { |
| Data::Struct(data) => arbitrary_structlike(&data.fields), |
| Data::Union(data) => arbitrary_structlike(&Fields::Named(data.fields.clone())), |
| Data::Enum(data) => { |
| let variants = data.variants.iter().enumerate().map(|(i, variant)| { |
| let idx = i as u64; |
| let ctor = construct(&variant.fields, |_, _| { |
| quote!(arbitrary::Arbitrary::arbitrary(u)?) |
| }); |
| let variant_name = &variant.ident; |
| quote! { #idx => #ident::#variant_name #ctor } |
| }); |
| let variants_take_rest = data.variants.iter().enumerate().map(|(i, variant)| { |
| let idx = i as u64; |
| let ctor = construct_take_rest(&variant.fields); |
| let variant_name = &variant.ident; |
| quote! { #idx => #ident::#variant_name #ctor } |
| }); |
| let count = data.variants.len() as u64; |
| quote! { |
| fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { |
| // Use a multiply + shift to generate a ranged random number |
| // with slight bias. For details, see: |
| // https://lemire.me/blog/2016/06/30/fast-random-shuffling |
| Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 { |
| #(#variants,)* |
| _ => unreachable!() |
| }) |
| } |
| |
| fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { |
| // Use a multiply + shift to generate a ranged random number |
| // with slight bias. For details, see: |
| // https://lemire.me/blog/2016/06/30/fast-random-shuffling |
| Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 { |
| #(#variants_take_rest,)* |
| _ => unreachable!() |
| }) |
| } |
| } |
| } |
| } |
| } |
| |
| fn construct(fields: &Fields, ctor: impl Fn(usize, &Field) -> TokenStream) -> TokenStream { |
| match fields { |
| Fields::Named(names) => { |
| let names = names.named.iter().enumerate().map(|(i, f)| { |
| let name = f.ident.as_ref().unwrap(); |
| let ctor = ctor(i, f); |
| quote! { #name: #ctor } |
| }); |
| quote! { { #(#names,)* } } |
| } |
| Fields::Unnamed(names) => { |
| let names = names.unnamed.iter().enumerate().map(|(i, f)| { |
| let ctor = ctor(i, f); |
| quote! { #ctor } |
| }); |
| quote! { ( #(#names),* ) } |
| } |
| Fields::Unit => quote!(), |
| } |
| } |
| |
| fn construct_take_rest(fields: &Fields) -> TokenStream { |
| construct(fields, |idx, _| { |
| if idx + 1 == fields.len() { |
| quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? } |
| } else { |
| quote! { arbitrary::Arbitrary::arbitrary(&mut u)? } |
| } |
| }) |
| } |
| |
| fn gen_size_hint_method(input: &DeriveInput) -> TokenStream { |
| let size_hint_fields = |fields: &Fields| { |
| let tys = fields.iter().map(|f| &f.ty); |
| quote! { |
| arbitrary::size_hint::and_all(&[ |
| #( <#tys as arbitrary::Arbitrary>::size_hint(depth) ),* |
| ]) |
| } |
| }; |
| let size_hint_structlike = |fields: &Fields| { |
| let hint = size_hint_fields(fields); |
| quote! { |
| #[inline] |
| fn size_hint(depth: usize) -> (usize, Option<usize>) { |
| arbitrary::size_hint::recursion_guard(depth, |depth| #hint) |
| } |
| } |
| }; |
| match &input.data { |
| Data::Struct(data) => size_hint_structlike(&data.fields), |
| Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())), |
| Data::Enum(data) => { |
| let variants = data.variants.iter().map(|v| size_hint_fields(&v.fields)); |
| quote! { |
| #[inline] |
| fn size_hint(depth: usize) -> (usize, Option<usize>) { |
| arbitrary::size_hint::and( |
| <u32 as arbitrary::Arbitrary>::size_hint(depth), |
| arbitrary::size_hint::recursion_guard(depth, |depth| { |
| arbitrary::size_hint::or_all(&[ #( #variants ),* ]) |
| }), |
| ) |
| } |
| } |
| } |
| } |
| } |