blob: ea761386dbdf3309fd10d982af60a934fa7cb821 [file] [log] [blame]
//! Core dependency injection macros
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{
braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path,
Token, Type,
};
/// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry
#[proc_macro_attribute]
pub fn provides(_attr: TokenStream, item: TokenStream) -> TokenStream {
let function: ItemFn = parse(item).expect("can only be applied to functions");
// Create the info needed to refer to the function & the injected version we generate
let ident = function.sig.ident.clone();
let injected_ident = format_ident!("__gddi_{}_injected", ident);
// Create the info needed to generate the call to the original function
let inputs = function.sig.inputs.iter().map(|arg| {
if let FnArg::Typed(t) = arg {
return t.ty.clone();
}
panic!("can't be applied to struct methods");
});
let local_var_idents = (0..inputs.len()).map(|i| format_ident!("__input{}", i));
let local_var_idents_for_call = local_var_idents.clone();
let emitted_code = quote! {
// Injecting wrapper
fn #injected_ident(registry: std::sync::Arc<gddi::Registry>) -> std::pin::Pin<gddi::ProviderFutureBox> {
Box::pin(async move {
// Create a local variable for each argument, to ensure they get generated in a
// deterministic order (compiler complains otherwise)
#(let #local_var_idents = registry.get::<#inputs>().await;)*
// Actually call the original function
Box::new(#ident(#(#local_var_idents_for_call),*).await) as Box<dyn std::any::Any>
})
}
#function
};
emitted_code.into()
}
struct ModuleDef {
name: Ident,
providers: Punctuated<ProviderDef, Token![,]>,
submodules: Punctuated<Path, Token![,]>,
}
enum ModuleEntry {
Providers(Punctuated<ProviderDef, Token![,]>),
Submodules(Punctuated<Path, Token![,]>),
}
struct ProviderDef {
ty: Type,
ident: Ident,
parts: bool,
}
impl Parse for ModuleDef {
fn parse(input: ParseStream) -> Result<Self> {
// first thing is the module name followed by a comma
let name = input.parse()?;
input.parse::<Token![,]>()?;
// Then comes submodules or provider sections, in any order
let entries: Punctuated<ModuleEntry, Token![,]> = Punctuated::parse_terminated(input)?;
let mut providers = Punctuated::new();
let mut submodules = Punctuated::new();
for entry in entries.into_iter() {
match entry {
ModuleEntry::Providers(value) => {
if !providers.is_empty() {
panic!("providers specified more than once");
}
providers = value;
}
ModuleEntry::Submodules(value) => {
if !submodules.is_empty() {
panic!("submodules specified more than once");
}
submodules = value;
}
}
}
Ok(ModuleDef { name, providers, submodules })
}
}
impl Parse for ProviderDef {
fn parse(input: ParseStream) -> Result<Self> {
let parts = input.peek3(Token![=>]);
if parts {
match input.parse::<Ident>()?.to_string().as_str() {
"parts" => {}
keyword => panic!("expected 'parts', got '{}'", keyword),
}
}
// A provider definition follows this format: <Type> -> <function name>
let ty = input.parse()?;
input.parse::<Token![=>]>()?;
let ident = input.parse()?;
Ok(ProviderDef { ty, ident, parts })
}
}
impl Parse for ModuleEntry {
fn parse(input: ParseStream) -> Result<Self> {
match input.parse::<Ident>()?.to_string().as_str() {
"providers" => {
let entries;
braced!(entries in input);
Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?))
}
"submodules" => {
let entries;
braced!(entries in input);
Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?))
}
keyword => {
panic!("unexpected keyword: {}", keyword);
}
}
}
}
/// Emits a module function that registers submodules & providers with the registry
#[proc_macro]
pub fn module(item: TokenStream) -> TokenStream {
let module = parse_macro_input!(item as ModuleDef);
let init_ident = module.name.clone();
let providers = module.providers.iter();
let types = providers.clone().map(|p| p.ty.clone());
let provider_idents =
providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
let parting_functions = providers.filter_map(|p| match &p.ty {
Type::Path(ty) if p.parts => Some(format_ident!(
"__gddi_part_out_{}",
ty.path.get_ident().unwrap().to_string().to_lowercase()
)),
_ => None,
});
let submodule_idents = module.submodules.iter();
let emitted_code = quote! {
#[doc(hidden)]
#[allow(missing_docs)]
pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
// Register all providers on this module
let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
// Register all submodules on this module
#(.register_module(#submodule_idents))*;
#(let ret = #parting_functions(ret);)*
ret
}
};
emitted_code.into()
}
/// Emits a default implementation for Stoppable that does nothing;
#[proc_macro_derive(Stoppable)]
pub fn derive_nop_stop(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let ident = input.ident;
let emitted_code = quote! {
impl gddi::Stoppable for #ident {}
};
emitted_code.into()
}
/// Generates the code necessary to split up a type into its components
#[proc_macro_attribute]
pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream {
let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions");
let struct_ident = struct_.ident.clone();
let fields = match struct_.fields.clone() {
Fields::Named(f) => f,
_ => panic!("can only be applied to structs with named fields"),
}
.named;
let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name"));
let field_types = fields.iter().map(|f| f.ty.clone());
let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase());
let emitted_code = quote! {
#struct_
fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
builder#(.register_provider::<#field_types>(Box::new(
|registry: std::sync::Arc<gddi::Registry>| -> std::pin::Pin<gddi::ProviderFutureBox> {
Box::pin(async move {
Box::new(async move {
registry.get::<#struct_ident>().await.#field_names
}.await) as Box<dyn std::any::Any>
})
})))*
}
};
emitted_code.into()
}