use heck::ToSnakeCase; | |
use proc_macro::TokenStream; | |
use syn::parse::{Parse, ParseStream}; | |
use syn::punctuated::Punctuated; | |
use syn::{Ident, ItemStruct, Path, Token}; | |
type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>; | |
pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { | |
let args = syn::parse_macro_input!(args as QueryGroupList); | |
let input = syn::parse_macro_input!(input as ItemStruct); | |
let query_groups = &args.query_groups; | |
let database_name = &input.ident; | |
let visibility = &input.vis; | |
let db_storage_field = quote! { storage }; | |
let mut output = proc_macro2::TokenStream::new(); | |
output.extend(quote! { #input }); | |
let query_group_names_snake: Vec<_> = query_groups | |
.iter() | |
.map(|query_group| { | |
let group_name = query_group.name(); | |
Ident::new(&group_name.to_string().to_snake_case(), group_name.span()) | |
}) | |
.collect(); | |
let query_group_storage_names: Vec<_> = query_groups | |
.iter() | |
.map(|QueryGroup { group_path }| { | |
quote! { | |
<#group_path as salsa::plumbing::QueryGroup>::GroupStorage | |
} | |
}) | |
.collect(); | |
// For each query group `foo::MyGroup` create a link to its | |
// `foo::MyGroupGroupStorage` | |
let mut storage_fields = proc_macro2::TokenStream::new(); | |
let mut storage_initializers = proc_macro2::TokenStream::new(); | |
let mut has_group_impls = proc_macro2::TokenStream::new(); | |
for (((query_group, group_name_snake), group_storage), group_index) in query_groups | |
.iter() | |
.zip(&query_group_names_snake) | |
.zip(&query_group_storage_names) | |
.zip(0_u16..) | |
{ | |
let group_path = &query_group.group_path; | |
// rewrite the last identifier (`MyGroup`, above) to | |
// (e.g.) `MyGroupGroupStorage`. | |
storage_fields.extend(quote! { | |
#group_name_snake: #group_storage, | |
}); | |
// rewrite the last identifier (`MyGroup`, above) to | |
// (e.g.) `MyGroupGroupStorage`. | |
storage_initializers.extend(quote! { | |
#group_name_snake: #group_storage::new(#group_index), | |
}); | |
// ANCHOR:HasQueryGroup | |
has_group_impls.extend(quote! { | |
impl salsa::plumbing::HasQueryGroup<#group_path> for #database_name { | |
fn group_storage(&self) -> &#group_storage { | |
&self.#db_storage_field.query_store().#group_name_snake | |
} | |
fn group_storage_mut(&mut self) -> (&#group_storage, &mut salsa::Runtime) { | |
let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut(); | |
(&query_store_mut.#group_name_snake, runtime) | |
} | |
} | |
}); | |
// ANCHOR_END:HasQueryGroup | |
} | |
// create group storage wrapper struct | |
output.extend(quote! { | |
#[doc(hidden)] | |
#visibility struct __SalsaDatabaseStorage { | |
#storage_fields | |
} | |
impl Default for __SalsaDatabaseStorage { | |
fn default() -> Self { | |
Self { | |
#storage_initializers | |
} | |
} | |
} | |
}); | |
// Create a tuple (D1, D2, ...) where Di is the data for a given query group. | |
let mut database_data = vec![]; | |
for QueryGroup { group_path } in query_groups { | |
database_data.push(quote! { | |
<#group_path as salsa::plumbing::QueryGroup>::GroupData | |
}); | |
} | |
// ANCHOR:DatabaseStorageTypes | |
output.extend(quote! { | |
impl salsa::plumbing::DatabaseStorageTypes for #database_name { | |
type DatabaseStorage = __SalsaDatabaseStorage; | |
} | |
}); | |
// ANCHOR_END:DatabaseStorageTypes | |
// ANCHOR:DatabaseOps | |
let mut fmt_ops = proc_macro2::TokenStream::new(); | |
let mut maybe_changed_ops = proc_macro2::TokenStream::new(); | |
let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); | |
let mut for_each_ops = proc_macro2::TokenStream::new(); | |
for ((QueryGroup { group_path }, group_storage), group_index) in query_groups | |
.iter() | |
.zip(&query_group_storage_names) | |
.zip(0_u16..) | |
{ | |
fmt_ops.extend(quote! { | |
#group_index => { | |
let storage: &#group_storage = | |
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); | |
storage.fmt_index(self, input, fmt) | |
} | |
}); | |
maybe_changed_ops.extend(quote! { | |
#group_index => { | |
let storage: &#group_storage = | |
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); | |
storage.maybe_changed_after(self, input, revision) | |
} | |
}); | |
cycle_recovery_strategy_ops.extend(quote! { | |
#group_index => { | |
let storage: &#group_storage = | |
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); | |
storage.cycle_recovery_strategy(self, input) | |
} | |
}); | |
for_each_ops.extend(quote! { | |
let storage: &#group_storage = | |
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); | |
storage.for_each_query(runtime, &mut op); | |
}); | |
} | |
output.extend(quote! { | |
impl salsa::plumbing::DatabaseOps for #database_name { | |
fn ops_database(&self) -> &dyn salsa::Database { | |
self | |
} | |
fn ops_salsa_runtime(&self) -> &salsa::Runtime { | |
self.#db_storage_field.salsa_runtime() | |
} | |
fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime { | |
self.#db_storage_field.salsa_runtime_mut() | |
} | |
fn fmt_index( | |
&self, | |
input: salsa::DatabaseKeyIndex, | |
fmt: &mut std::fmt::Formatter<'_>, | |
) -> std::fmt::Result { | |
match input.group_index() { | |
#fmt_ops | |
i => panic!("salsa: invalid group index {}", i) | |
} | |
} | |
fn maybe_changed_after( | |
&self, | |
input: salsa::DatabaseKeyIndex, | |
revision: salsa::Revision | |
) -> bool { | |
match input.group_index() { | |
#maybe_changed_ops | |
i => panic!("salsa: invalid group index {}", i) | |
} | |
} | |
fn cycle_recovery_strategy( | |
&self, | |
input: salsa::DatabaseKeyIndex, | |
) -> salsa::plumbing::CycleRecoveryStrategy { | |
match input.group_index() { | |
#cycle_recovery_strategy_ops | |
i => panic!("salsa: invalid group index {}", i) | |
} | |
} | |
fn for_each_query( | |
&self, | |
mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps), | |
) { | |
let runtime = salsa::Database::salsa_runtime(self); | |
#for_each_ops | |
} | |
} | |
}); | |
// ANCHOR_END:DatabaseOps | |
output.extend(has_group_impls); | |
if std::env::var("SALSA_DUMP").is_ok() { | |
println!("~~~ database_storage"); | |
println!("{}", output.to_string()); | |
println!("~~~ database_storage"); | |
} | |
output.into() | |
} | |
#[derive(Clone, Debug)] | |
struct QueryGroupList { | |
query_groups: PunctuatedQueryGroups, | |
} | |
impl Parse for QueryGroupList { | |
fn parse(input: ParseStream) -> syn::Result<Self> { | |
let query_groups: PunctuatedQueryGroups = | |
input.parse_terminated(QueryGroup::parse, Token![,])?; | |
Ok(QueryGroupList { query_groups }) | |
} | |
} | |
#[derive(Clone, Debug)] | |
struct QueryGroup { | |
group_path: Path, | |
} | |
impl QueryGroup { | |
/// The name of the query group trait. | |
fn name(&self) -> Ident { | |
self.group_path.segments.last().unwrap().ident.clone() | |
} | |
} | |
impl Parse for QueryGroup { | |
/// ```ignore | |
/// impl HelloWorldDatabase; | |
/// ``` | |
fn parse(input: ParseStream) -> syn::Result<Self> { | |
let group_path: Path = input.parse()?; | |
Ok(QueryGroup { group_path }) | |
} | |
} | |
/// Marker type whose `Parse` impl succeeds without reading any tokens.
// NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
struct Nothing;
impl Parse for Nothing { | |
fn parse(_input: ParseStream) -> syn::Result<Self> { | |
Ok(Nothing) | |
} | |
} |