blob: e813e91c7416fc8c37f03b91626c5248e280d760 [file] [log] [blame]
use std::convert::TryFrom;
use crate::parenthesized::Parenthesized;
use heck::CamelCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, Attribute, FnArg, Ident, ItemTrait, Path, ReturnType, Token,
TraitBound, TraitBoundModifier, TraitItem, Type, TypeParamBound,
};
/// Implementation for `[salsa::query_group]` decorator.
///
/// `args` is parsed as a single identifier naming the query group struct, and
/// `input` is parsed as the trait definition the attribute is attached to.
/// Every method of the trait becomes one query (plus an automatically
/// generated `lookup_*` query for `#[salsa::interned]` methods).
///
/// The returned token stream contains, in order: the trait itself (with the
/// `salsa::*` attributes stripped), the group struct and its `QueryGroup`
/// impl, a blanket impl of the trait, one struct per query with its `Query`
/// (and, where needed, `QueryFunction`) impl, the group key enum, and the
/// group storage struct with its `Default` impl and `for_each_query` method.
///
/// # Panics
///
/// Panics (surfacing as a compile error at the macro call site) on an unknown
/// `salsa::*` attribute, on more than one storage attribute for a method, on
/// `#[salsa::invoke]` combined with `#[salsa::input]`, on a method whose first
/// argument is not an immutable receiver, on an unsupported argument form, and
/// on a method without an explicit return type.
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
let group_struct = parse_macro_input!(args as Ident);
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
// println!("args: {:#?}", args);
// println!("input: {:#?}", input);
// Split the trait-level attributes: `salsa::*` ones are interpreted here,
// everything else is re-emitted on the generated trait.
let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs);
// Extra bounds collected from `#[salsa::requires(Path)]`; joined with `+`
// when interpolated into the generated `where` clauses.
let mut requires: Punctuated<Path, Token![+]> = Punctuated::new();
for SalsaAttr { name, tts } in salsa_attrs {
match name.as_str() {
"requires" => {
requires.push(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
_ => panic!("unknown salsa attribute `{}`", name),
}
}
let trait_vis = input.vis;
let trait_name = input.ident;
// Not referenced again in this function; the generated items carry no
// generics taken from the trait.
let _generics = input.generics.clone();
// Decompose the trait into the corresponding queries.
let mut queries = vec![];
for item in input.items {
match item {
TraitItem::Method(method) => {
// Defaults: memoized storage, no cycle recovery, no custom
// implementation path, query type named `<CamelCaseMethodName>Query`.
let mut storage = QueryStorage::Memoized;
let mut cycle = None;
let mut invoke = None;
let mut query_type = Ident::new(
&format!("{}Query", method.sig.ident.to_string().to_camel_case()),
Span::call_site(),
);
// Counts storage attributes so that conflicting ones can be rejected below.
let mut num_storages = 0;
// Extract attributes.
let (attrs, salsa_attrs) = filter_attrs(method.attrs);
for SalsaAttr { name, tts } in salsa_attrs {
match name.as_str() {
"memoized" => {
storage = QueryStorage::Memoized;
num_storages += 1;
}
"dependencies" => {
storage = QueryStorage::Dependencies;
num_storages += 1;
}
"input" => {
storage = QueryStorage::Input;
num_storages += 1;
}
"interned" => {
storage = QueryStorage::Interned;
num_storages += 1;
}
"cycle" => {
cycle = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"invoke" => {
invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"query_type" => {
query_type = parse_macro_input!(tts as Parenthesized<Ident>).0;
}
"transparent" => {
storage = QueryStorage::Transparent;
num_storages += 1;
}
_ => panic!("unknown salsa attribute `{}`", name),
}
}
// Check attribute combinations.
if num_storages > 1 {
panic!("multiple storage attributes specified");
}
if invoke.is_some() && storage == QueryStorage::Input {
panic!("#[salsa::invoke] cannot be set on #[salsa::input] queries");
}
// Extract keys.
let mut iter = method.sig.inputs.iter();
// Only the absence of `mut` on the receiver is checked here.
// NOTE(review): the message says `&self`, but the guard does not
// inspect whether the receiver is a reference -- confirm intent.
match iter.next() {
Some(FnArg::Receiver(sr)) if sr.mutability.is_none() => (),
_ => panic!(
"first argument of query `{}` must be `&self`",
method.sig.ident
),
}
// Every remaining argument contributes its type as one key component;
// argument names are discarded and regenerated as `key0`, `key1`, ...
let mut keys: Vec<Type> = vec![];
for arg in iter {
match *arg {
FnArg::Typed(ref arg) => {
keys.push((*arg.ty).clone());
}
ref a => panic!("unsupported argument `{:?}` of `{}`", a, method.sig.ident),
}
}
// Extract value.
let value = match method.sig.output {
ReturnType::Type(_, ref ty) => ty.as_ref().clone(),
ref r => panic!(
"unsupported return type `{:?}` of `{}`",
r, method.sig.ident
),
};
// For `#[salsa::interned]` keys, we create a "lookup key" automatically.
//
// For a query like:
//
// fn foo(&self, x: Key1, y: Key2) -> u32
//
// we would create
//
// fn lookup_foo(&self, x: u32) -> (Key1, Key2)
let lookup_query = if let QueryStorage::Interned = storage {
let lookup_query_type = Ident::new(
&format!(
"{}LookupQuery",
method.sig.ident.to_string().to_camel_case()
),
Span::call_site(),
);
let lookup_fn_name = Ident::new(
&format!("lookup_{}", method.sig.ident.to_string()),
method.sig.ident.span(),
);
// The lookup query swaps the roles: its value is the tuple of the
// interned query's keys, and its single key is the interned value.
let keys = &keys;
let lookup_value: Type = parse_quote!((#(#keys),*));
let lookup_keys = vec![value.clone()];
Some(Query {
query_type: lookup_query_type,
fn_name: lookup_fn_name,
attrs: vec![], // FIXME -- some automatically generated docs on this method?
storage: QueryStorage::InternedLookup {
intern_query_type: query_type.clone(),
},
keys: lookup_keys,
value: lookup_value,
invoke: None,
cycle: cycle.clone(),
})
} else {
None
};
queries.push(Query {
query_type,
fn_name: method.sig.ident,
attrs,
storage,
keys,
value,
invoke,
cycle,
});
// Appends the lookup query (if any) directly after its interned query.
queries.extend(lookup_query);
}
// Trait items other than methods are ignored and not re-emitted.
_ => (),
}
}
// Names of the generated key enum and storage struct, derived from the trait name.
let group_key = Ident::new(
&format!("{}GroupKey__", trait_name.to_string()),
Span::call_site(),
);
let group_storage = Ident::new(
&format!("{}GroupStorage__", trait_name.to_string()),
Span::call_site(),
);
// Accumulators filled per query in the loop below and spliced into the
// generated items afterwards.
let mut query_fn_declarations = proc_macro2::TokenStream::new();
let mut query_fn_definitions = proc_macro2::TokenStream::new();
let mut query_descriptor_variants = proc_macro2::TokenStream::new();
let mut group_data_elements = vec![];
let mut storage_fields = proc_macro2::TokenStream::new();
let mut storage_defaults = proc_macro2::TokenStream::new();
for query in &queries {
// Synthetic argument names `key0`, `key1`, ... -- one per key type.
let key_names: &Vec<_> = &(0..query.keys.len())
.map(|i| Ident::new(&format!("key{}", i), Span::call_site()))
.collect();
let keys = &query.keys;
let value = &query.value;
let fn_name = &query.fn_name;
let qt = &query.query_type;
let attrs = &query.attrs;
// Trait method declaration, carrying the user's non-salsa attributes.
query_fn_declarations.extend(quote! {
#(#attrs)*
fn #fn_name(&self, #(#key_names: #keys),*) -> #value;
});
// Special case: transparent queries don't create actual storage,
// just inline the definition
if let QueryStorage::Transparent = query.storage {
let invoke = query.invoke_tt();
query_fn_definitions.extend(quote! {
fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
#invoke(self, #(#key_names),*)
}
});
// Skips the descriptor variant, group data entry and storage field below.
continue;
}
// Regular queries read their value through the query table.
query_fn_definitions.extend(quote! {
fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
<Self as salsa::plumbing::GetQueryTable<#qt>>::get_query_table(self).get((#(#key_names),*))
}
});
// For input queries, we need `set_foo` etc
if let QueryStorage::Input = query.storage {
let set_fn_name = Ident::new(&format!("set_{}", fn_name), fn_name.span());
let set_with_durability_fn_name =
Ident::new(&format!("set_{}_with_durability", fn_name), fn_name.span());
let set_fn_docs = format!(
"
Set the value of the `{fn_name}` input.
See `{fn_name}` for details.
*Note:* Setting values will trigger cancellation
of any ongoing queries; this method blocks until
those queries have been cancelled.
",
fn_name = fn_name
);
// NOTE(review): this text is attached to the `set_*_with_durability`
// setter below, yet it promises the value "will never change again" --
// confirm whether the wording still matches the durability API.
let set_constant_fn_docs = format!(
"
Set the value of the `{fn_name}` input and promise
that its value will never change again.
See `{fn_name}` for details.
*Note:* Setting values will trigger cancellation
of any ongoing queries; this method blocks until
those queries have been cancelled.
",
fn_name = fn_name
);
query_fn_declarations.extend(quote! {
# [doc = #set_fn_docs]
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value);
# [doc = #set_constant_fn_docs]
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability);
});
query_fn_definitions.extend(quote! {
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value) {
<Self as salsa::plumbing::GetQueryTable<#qt>>::get_query_table_mut(self).set((#(#key_names),*), value__)
}
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) {
<Self as salsa::plumbing::GetQueryTable<#qt>>::get_query_table_mut(self).set_with_durability((#(#key_names),*), value__, durability__)
}
});
}
// A variant for the group descriptor below
query_descriptor_variants.extend(quote! {
#fn_name((#(#keys),*)),
});
// Entry for the query group data tuple
group_data_elements.push(quote! {
(#(#keys,)* #value)
});
// A field for the storage struct
//
// FIXME(#120): the pub should not be necessary once we complete the transition
storage_fields.extend(quote! {
pub #fn_name: std::sync::Arc<<#qt as salsa::Query<DB__>>::Storage>,
});
storage_defaults.extend(quote! { #fn_name: Default::default(), });
}
// Emit the trait itself.
let mut output = {
let bounds = &input.supertraits;
quote! {
#(#trait_attrs)*
#trait_vis trait #trait_name : #bounds {
#query_fn_declarations
}
}
};
// Emit the query group struct and impl of `QueryGroup`.
output.extend(quote! {
/// Representative struct for the query group.
#trait_vis struct #group_struct { }
impl<DB__> salsa::plumbing::QueryGroup<DB__> for #group_struct
where
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
type GroupStorage = #group_storage<DB__>;
type GroupKey = #group_key;
type GroupData = (#(#group_data_elements),*);
}
});
// Emit an impl of the trait
output.extend({
// The blanket impl is bounded by the trait's supertraits plus every
// `#[salsa::requires]` path, each wrapped as a plain trait bound.
let mut bounds = input.supertraits.clone();
for path in requires.clone() {
bounds.push(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: TraitBoundModifier::None,
lifetimes: None,
path,
}));
}
quote! {
impl<T> #trait_name for T
where
T: #bounds,
T: salsa::plumbing::HasQueryGroup<#group_struct>
{
#query_fn_definitions
}
}
});
// Emit the query types.
for query in &queries {
let fn_name = &query.fn_name;
let qt = &query.query_type;
let db = quote! {DB};
// Pick the storage type for this query; transparent queries have no
// query struct at all and are skipped.
let storage = match &query.storage {
QueryStorage::Memoized => quote!(salsa::plumbing::MemoizedStorage<#db, Self>),
QueryStorage::Dependencies => quote!(salsa::plumbing::DependencyStorage<#db, Self>),
QueryStorage::Input => quote!(salsa::plumbing::InputStorage<#db, Self>),
QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<#db, Self>),
QueryStorage::InternedLookup { intern_query_type } => {
quote!(salsa::plumbing::LookupInternedStorage<#db, Self, #intern_query_type>)
}
QueryStorage::Transparent => continue,
};
let keys = &query.keys;
let value = &query.value;
// Emit the query struct and implement the Query trait on it.
output.extend(quote! {
#[derive(Default, Debug)]
#trait_vis struct #qt;
// Unsafe proof obligation: that our key/value are a part
// of the `GroupData`.
unsafe impl<#db> salsa::Query<#db> for #qt
where
DB: #trait_name + #requires,
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
DB: salsa::Database,
{
type Key = (#(#keys),*);
type Value = #value;
type Storage = #storage;
type Group = #group_struct;
type GroupStorage = #group_storage<#db>;
type GroupKey = #group_key;
fn query_storage(
group_storage: &Self::GroupStorage,
) -> &std::sync::Arc<Self::Storage> {
&group_storage.#fn_name
}
fn group_key(key: Self::Key) -> Self::GroupKey {
#group_key::#fn_name(key)
}
}
});
// Implement the QueryFunction trait for queries which need it.
if query.storage.needs_query_function() {
// Generated code is spanned at the user's method name.
let span = query.fn_name.span();
let key_names: &Vec<_> = &(0..query.keys.len())
.map(|i| Ident::new(&format!("key{}", i), Span::call_site()))
.collect();
// With exactly one key, `Key` is emitted as `(T)`, i.e. just `T`, so the
// pattern is a bare binding; otherwise the key is destructured as a tuple.
let key_pattern = if query.keys.len() == 1 {
quote! { #(#key_names),* }
} else {
quote! { (#(#key_names),*) }
};
let invoke = query.invoke_tt();
// When `#[salsa::cycle(path)]` was given, emit a `recover` method that
// forwards to that path, passing the cycle rendered as debug strings.
let recover = if let Some(cycle_recovery_fn) = &query.cycle {
quote! {
fn recover(db: &DB, cycle: &[DB::DatabaseKey], #key_pattern: &<Self as salsa::Query<DB>>::Key)
-> Option<<Self as salsa::Query<DB>>::Value> {
Some(#cycle_recovery_fn(
db,
&cycle.iter().map(|k| format!("{:?}", k)).collect::<Vec<String>>(),
#(#key_names),*
))
}
}
} else {
quote! {}
};
output.extend(quote_spanned! {span=>
impl<DB> salsa::plumbing::QueryFunction<DB> for #qt
where
DB: #trait_name + #requires,
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
DB: salsa::Database,
{
fn execute(db: &DB, #key_pattern: <Self as salsa::Query<DB>>::Key)
-> <Self as salsa::Query<DB>>::Value {
#invoke(db, #(#key_names),*)
}
#recover
}
});
}
}
// Emit query group descriptor
output.extend(quote! {
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
#trait_vis enum #group_key {
#query_descriptor_variants
}
});
// One `op(...)` call per storage field; transparent queries have no field.
let mut for_each_ops = proc_macro2::TokenStream::new();
for Query { fn_name, .. } in queries
.iter()
.filter(|q| q.storage != QueryStorage::Transparent)
{
for_each_ops.extend(quote! {
op(&*self.#fn_name);
});
}
// Emit query group storage struct
// It would derive Default, but then all database structs would have to implement Default
// as the derived version includes an unused `+ Default` constraint.
output.extend(quote! {
#trait_vis struct #group_storage<DB__>
where
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
#storage_fields
}
impl<DB__> Default for #group_storage<DB__>
where
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
DB__: salsa::Database,
{
#[inline]
fn default() -> Self {
#group_storage {
#storage_defaults
}
}
}
impl<DB__> #group_storage<DB__>
where
DB__: #trait_name + #requires,
DB__: salsa::plumbing::HasQueryGroup<#group_struct>,
{
#trait_vis fn for_each_query(
&self,
db: &DB__,
mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps<DB__>),
) {
#for_each_ops
}
}
});
// Debugging aid: when the `SALSA_DUMP` environment variable is set, print
// the generated code to stdout during macro expansion.
if std::env::var("SALSA_DUMP").is_ok() {
println!("~~~ query_group");
println!("{}", output.to_string());
println!("~~~ query_group");
}
output.into()
}
/// A `#[salsa::<name>...]` attribute, split into its name and its argument tokens.
struct SalsaAttr {
/// The second path segment of the attribute, e.g. `input` for `#[salsa::input]`.
name: String,
/// The tokens following the attribute path, not yet parsed.
tts: TokenStream,
}
/// Recognizes `salsa::*` attributes.
///
/// On success the attribute is reduced to its name (second path segment) and
/// its raw argument tokens. Any other attribute is handed back unchanged as
/// the error value so the caller can keep it.
impl TryFrom<syn::Attribute> for SalsaAttr {
    type Error = syn::Attribute;

    fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
        if !is_not_salsa_attr_path(&attr.path) {
            // The path check guarantees exactly two segments, so index 1 exists.
            let name = attr.path.segments[1].ident.to_string();
            return Ok(SalsaAttr {
                name,
                tts: attr.tokens.into(),
            });
        }
        Err(attr)
    }
}
/// Returns `true` unless `path` has exactly two segments and the first one is `salsa`.
fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
    if path.segments.len() != 2 {
        return true;
    }
    match path.segments.first() {
        Some(segment) => segment.ident != "salsa",
        None => true,
    }
}
/// Partitions `attrs` into `(other attributes, salsa attributes)`.
///
/// An attribute counts as a salsa attribute when its path has exactly two
/// segments and starts with `salsa::`; everything else is passed through
/// untouched. Relative order is preserved within each group.
fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
    let mut passthrough = Vec::new();
    let mut recognized = Vec::new();
    attrs
        .into_iter()
        .for_each(|attr| match SalsaAttr::try_from(attr) {
            Ok(salsa_attr) => recognized.push(salsa_attr),
            Err(plain_attr) => passthrough.push(plain_attr),
        });
    (passthrough, recognized)
}
/// Everything the macro knows about one query of the group.
#[derive(Debug)]
struct Query {
/// Name of the trait method (for generated lookup queries: `lookup_<name>`).
fn_name: Ident,
/// Non-salsa attributes of the method, re-emitted on the trait method declaration.
attrs: Vec<syn::Attribute>,
/// Name of the generated query struct.
query_type: Ident,
/// Which kind of storage the query uses.
storage: QueryStorage,
/// Types of the method arguments after the receiver.
keys: Vec<syn::Type>,
/// Return type of the method.
value: syn::Type,
/// Path given via `#[salsa::invoke(..)]`, called instead of `fn_name` when present.
invoke: Option<syn::Path>,
/// Path given via `#[salsa::cycle(..)]`, used to generate the `recover` method.
cycle: Option<syn::Path>,
}
impl Query {
    /// Tokens naming the function that computes this query: the
    /// `#[salsa::invoke]` path when one was given, otherwise the method name.
    fn invoke_tt(&self) -> proc_macro2::TokenStream {
        self.invoke
            .as_ref()
            .map(|path| path.into_token_stream())
            .unwrap_or_else(|| self.fn_name.clone().into_token_stream())
    }
}
/// The storage strategy selected for a query by its `salsa::*` attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryStorage {
/// `#[salsa::memoized]`, also the default when no storage attribute is given.
Memoized,
/// `#[salsa::dependencies]`.
Dependencies,
/// `#[salsa::input]`; additionally gets `set_*` methods generated.
Input,
/// `#[salsa::interned]`; additionally gets a `lookup_*` query generated.
Interned,
/// The automatically generated lookup query of an interned query.
/// `intern_query_type` is the query struct of the interning query.
InternedLookup { intern_query_type: Ident },
/// `#[salsa::transparent]`; no storage, the implementation is called directly.
Transparent,
}
impl QueryStorage {
    /// Whether a `QueryFunction` impl has to be generated for a query with
    /// this storage. Only memoized and dependency-tracked queries need one.
    fn needs_query_function(&self) -> bool {
        match *self {
            QueryStorage::Memoized | QueryStorage::Dependencies => true,
            QueryStorage::Input
            | QueryStorage::Interned
            | QueryStorage::InternedLookup { .. }
            | QueryStorage::Transparent => false,
        }
    }
}