| use crate::{ |
| ast::{Container, Data, Field, Variant}, |
| attr::WithAttr, |
| }; |
| use std::collections::BTreeSet; |
| use syn::{punctuated::Punctuated, Ident}; |
| |
| // This logic is heavily based on serde_derive: |
| // https://github.com/serde-rs/serde/blob/a1ddb18c92f32d64b2ccaf31ddd776e56be34ba2/serde_derive/src/bound.rs#L91 |
| |
| pub fn find_trait_bounds<'a>(orig_generics: &'a syn::Generics, cont: &mut Container<'a>) { |
| if orig_generics.params.is_empty() { |
| return; |
| } |
| |
| let all_type_params = orig_generics |
| .type_params() |
| .map(|param| ¶m.ident) |
| .collect(); |
| |
| assert!(cont.rename_type_params.is_subset(&all_type_params)); |
| |
| let mut visitor = FindTyParams { |
| all_type_params, |
| relevant_type_params: cont.rename_type_params.clone(), |
| type_params_for_bound: cont.rename_type_params.clone(), |
| }; |
| |
| let mut field_explicit_bounds = Vec::new(); |
| |
| if visitor.all_type_params.len() > visitor.relevant_type_params.len() { |
| match &cont.data { |
| Data::Enum(variants) => { |
| for variant in variants { |
| let relevant_fields = variant |
| .fields |
| .iter() |
| .filter(|field| needs_jsonschema_bound(field, Some(variant))); |
| |
| for field in relevant_fields { |
| field_explicit_bounds.extend(field.serde_attrs.de_bound()); |
| visitor.visit_field(field); |
| } |
| } |
| } |
| Data::Struct(_, fields) => { |
| let relevant_fields = fields |
| .iter() |
| .filter(|field| needs_jsonschema_bound(field, None)); |
| |
| for field in relevant_fields { |
| field_explicit_bounds.extend(field.serde_attrs.de_bound()); |
| visitor.visit_field(field); |
| } |
| } |
| } |
| } |
| |
| cont.relevant_type_params = visitor.relevant_type_params; |
| |
| let where_clause = cont.generics.make_where_clause(); |
| |
| if let Some(bounds) = cont.serde_attrs.de_bound() { |
| where_clause.predicates.extend(bounds.iter().cloned()); |
| } else { |
| where_clause |
| .predicates |
| .extend(visitor.type_params_for_bound.into_iter().map(|ty| { |
| syn::WherePredicate::Type(syn::PredicateType { |
| lifetimes: None, |
| bounded_ty: syn::Type::Path(syn::TypePath { |
| qself: None, |
| path: syn::Path { |
| leading_colon: None, |
| segments: Punctuated::from_iter([syn::PathSegment { |
| ident: (*ty).clone(), |
| arguments: syn::PathArguments::None, |
| }]), |
| }, |
| }), |
| colon_token: <Token![:]>::default(), |
| bounds: Punctuated::from_iter([syn::TypeParamBound::Trait(syn::TraitBound { |
| paren_token: None, |
| modifier: syn::TraitBoundModifier::None, |
| lifetimes: None, |
| path: parse_quote!(schemars::JsonSchema), |
| })]), |
| }) |
| })); |
| } |
| |
| where_clause |
| .predicates |
| .extend(field_explicit_bounds.into_iter().flatten().cloned()); |
| } |
| |
| fn needs_jsonschema_bound(field: &Field, variant: Option<&Variant>) -> bool { |
| if let Some(variant) = variant { |
| if variant.serde_attrs.skip_deserializing() && variant.serde_attrs.skip_serializing() { |
| return false; |
| } |
| } |
| |
| if field.serde_attrs.skip_deserializing() && field.serde_attrs.skip_serializing() { |
| return false; |
| } |
| |
| true |
| } |
| |
| struct FindTyParams<'ast> { |
| all_type_params: BTreeSet<&'ast Ident>, |
| relevant_type_params: BTreeSet<&'ast Ident>, |
| type_params_for_bound: BTreeSet<&'ast Ident>, |
| } |
| |
| #[allow(clippy::match_same_arms)] |
| impl FindTyParams<'_> { |
| fn visit_field(&mut self, field: &Field) { |
| match &field.attrs.with { |
| Some(WithAttr::Type(ty)) => self.visit_type(field, ty), |
| Some(WithAttr::Function(_)) => { |
| // `schema_with` function type params may or may not implement `JsonSchema` |
| } |
| None => self.visit_type(field, &field.original.ty), |
| } |
| } |
| |
| fn visit_path(&mut self, field: &Field, path: &syn::Path) { |
| if let Some(seg) = path.segments.last() { |
| if seg.ident == "PhantomData" { |
| // Hardcoded exception, because PhantomData<T> implements |
| // JsonSchema whether or not T implements it. |
| return; |
| } |
| } |
| |
| if path.leading_colon.is_none() { |
| if let Some(first_segment) = path.segments.first() { |
| let id = &first_segment.ident; |
| if let Some(id) = self.all_type_params.get(id) { |
| self.relevant_type_params.insert(id); |
| if field.serde_attrs.de_bound().is_none() { |
| self.type_params_for_bound.insert(id); |
| } |
| } |
| } |
| } |
| |
| for segment in &path.segments { |
| self.visit_path_segment(field, segment); |
| } |
| } |
| |
| fn visit_type(&mut self, field: &Field, ty: &syn::Type) { |
| match ty { |
| syn::Type::Array(ty) => self.visit_type(field, &ty.elem), |
| syn::Type::BareFn(ty) => { |
| for arg in &ty.inputs { |
| self.visit_type(field, &arg.ty); |
| } |
| self.visit_return_type(field, &ty.output); |
| } |
| syn::Type::Group(ty) => self.visit_type(field, &ty.elem), |
| syn::Type::ImplTrait(ty) => { |
| for bound in &ty.bounds { |
| self.visit_type_param_bound(field, bound); |
| } |
| } |
| syn::Type::Macro(ty) => self.visit_macro(field, &ty.mac), |
| syn::Type::Paren(ty) => self.visit_type(field, &ty.elem), |
| syn::Type::Path(ty) => { |
| if let Some(qself) = &ty.qself { |
| self.visit_type(field, &qself.ty); |
| } |
| self.visit_path(field, &ty.path); |
| } |
| syn::Type::Ptr(ty) => self.visit_type(field, &ty.elem), |
| syn::Type::Reference(ty) => { |
| self.visit_type(field, &ty.elem); |
| } |
| syn::Type::Slice(ty) => self.visit_type(field, &ty.elem), |
| syn::Type::TraitObject(ty) => { |
| for bound in &ty.bounds { |
| self.visit_type_param_bound(field, bound); |
| } |
| } |
| syn::Type::Tuple(ty) => { |
| for elem in &ty.elems { |
| self.visit_type(field, elem); |
| } |
| } |
| |
| syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {} |
| |
| _ => {} |
| } |
| } |
| |
| fn visit_path_segment(&mut self, field: &Field, segment: &syn::PathSegment) { |
| self.visit_path_arguments(field, &segment.arguments); |
| } |
| |
| fn visit_path_arguments(&mut self, field: &Field, arguments: &syn::PathArguments) { |
| match arguments { |
| syn::PathArguments::None => {} |
| syn::PathArguments::AngleBracketed(arguments) => { |
| for arg in &arguments.args { |
| match arg { |
| syn::GenericArgument::Type(arg) => self.visit_type(field, arg), |
| syn::GenericArgument::AssocType(arg) => self.visit_type(field, &arg.ty), |
| syn::GenericArgument::Lifetime(_) |
| | syn::GenericArgument::Const(_) |
| | syn::GenericArgument::AssocConst(_) |
| | syn::GenericArgument::Constraint(_) => {} |
| _ => {} |
| } |
| } |
| } |
| syn::PathArguments::Parenthesized(arguments) => { |
| for argument in &arguments.inputs { |
| self.visit_type(field, argument); |
| } |
| self.visit_return_type(field, &arguments.output); |
| } |
| } |
| } |
| |
| fn visit_return_type(&mut self, field: &Field, return_type: &syn::ReturnType) { |
| match return_type { |
| syn::ReturnType::Default => {} |
| syn::ReturnType::Type(_, output) => self.visit_type(field, output), |
| } |
| } |
| |
| fn visit_type_param_bound(&mut self, field: &Field, bound: &syn::TypeParamBound) { |
| match bound { |
| syn::TypeParamBound::Trait(bound) => self.visit_path(field, &bound.path), |
| syn::TypeParamBound::Lifetime(_) |
| | syn::TypeParamBound::PreciseCapture(_) |
| | syn::TypeParamBound::Verbatim(_) => {} |
| _ => {} |
| } |
| } |
| |
| // Type parameter should not be considered used by a macro path. |
| // |
| // struct TypeMacro<T> { |
| // mac: T!(), |
| // marker: PhantomData<T>, |
| // } |
| #[allow(clippy::unused_self)] |
| fn visit_macro(&mut self, _field: &Field, _mac: &syn::Macro) {} |
| } |
| |
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Every type parameter except `Z` is expected to receive a
    /// `schemars::JsonSchema` bound and to be reported as relevant.
    #[test]
    fn test_enum_bounds() {
        // `Z` only occurs in places the analysis ignores: a macro type, a
        // `PhantomData`, a skipped field and a skipped variant. `Y` is mentioned
        // by the `rename` attribute only.
        let input = parse_quote! {
            #[schemars(rename = "MyEnum<{T}, {U}, {V}, {W}, {X}, {Y}, {{Z}}>")]
            pub enum MyEnum<'a, const LEN: usize, T, U, V, W, X, Y, Z>
            where
                X: Trait,
                Z: OtherTrait
            {
                A,
                B(),
                C(T),
                D(U, (i8, V, bool)),
                E {
                    a: W,
                    b: [&'a Option<Box<<X as Trait>::AssocType::Z>>; LEN],
                    c: Token![Z],
                    d: PhantomData<Z>,
                    #[serde(skip)]
                    e: Z,
                },
                #[serde(skip)]
                F(Z),
            }
        };

        let cont = Container::from_ast(&input).unwrap();

        // The original predicates come first, followed by the derived bounds.
        let expected_where_clause: syn::WhereClause = parse_quote!(
            where
                X: Trait,
                Z: OtherTrait,
                T: schemars::JsonSchema,
                U: schemars::JsonSchema,
                V: schemars::JsonSchema,
                W: schemars::JsonSchema,
                X: schemars::JsonSchema,
                Y: schemars::JsonSchema
        );
        assert_eq!(cont.generics.where_clause, Some(expected_where_clause));

        let relevant_names: Vec<String> = cont
            .relevant_type_params
            .into_iter()
            .map(Ident::to_string)
            .collect();
        assert_eq!(relevant_names, vec!["T", "U", "V", "W", "X", "Y"]);
    }
}