diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 6ebac551..90d52e81 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -1,4 +1,5 @@ use proc_macro2::{Literal, Span, TokenStream}; +use quote::ToTokens; use syn::{spanned::Spanned, ItemFn}; use crate::{db_lifetime, fn_util, hygiene::Hygiene, options::Options}; @@ -154,7 +155,8 @@ impl Macro { )); } - let (db_ident, db_path) = check_db_argument(&item.sig.inputs[0])?; + let (db_ident, db_path) = + check_db_argument(&item.sig.inputs[0], item.sig.generics.lifetimes().next())?; Ok(ValidFn { db_ident, db_path }) } @@ -202,6 +204,7 @@ fn function_type(item_fn: &syn::ItemFn) -> FunctionType { pub fn check_db_argument<'arg>( fn_arg: &'arg syn::FnArg, + explicit_lt: Option<&'arg syn::LifetimeParam>, ) -> syn::Result<(&'arg syn::Ident, &'arg syn::Path)> { match fn_arg { syn::FnArg::Receiver(_) => { @@ -256,11 +259,23 @@ pub fn check_db_argument<'arg>( )); } - let extract_db_path = || -> Result<&'arg syn::Path, Span> { - let syn::Type::Reference(ref_type) = &*typed.ty else { - return Err(typed.ty.span()); - }; + let tykind_error_msg = + "must have type `&dyn Db`, where `Db` is some Salsa Database trait"; + let syn::Type::Reference(ref_type) = &*typed.ty else { + return Err(syn::Error::new(typed.ty.span(), tykind_error_msg)); + }; + + if let Some(lt) = explicit_lt { + if ref_type.lifetime.is_none() { + return Err(syn::Error::new_spanned( + ref_type.and_token, + format!("must have a `{}` lifetime", lt.lifetime.to_token_stream()), + )); + } + } + + let extract_db_path = || -> Result<&'arg syn::Path, Span> { if let Some(m) = &ref_type.mutability { return Err(m.span()); } diff --git a/tests/compile-fail/tracked_fn_incompatibles.rs b/tests/compile-fail/tracked_fn_incompatibles.rs index 309e4fba..dea5e7d1 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.rs +++ b/tests/compile-fail/tracked_fn_incompatibles.rs @@ -34,4 +34,30 @@ fn tracked_fn_with_too_many_arguments_for_specify( ) -> u32 { } +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn_with_lt_param_and_elided_lt_on_db_arg<'db>( + db: &dyn Db, + interned: MyInterned<'db>, +) -> u32 { + interned.field(db) * 2 +} + +#[salsa::tracked] +fn tracked_fn_with_lt_param_and_elided_lt_on_input<'db>( + db: &'db dyn Db, + interned: MyInterned, +) -> u32 { + interned.field(db) * 2 +} + +#[salsa::tracked] +fn tracked_fn_with_multiple_lts<'db1, 'db2>(db: &'db1 dyn Db, interned: MyInterned<'db2>) -> u32 { + interned.field(db) * 2 +} + fn main() {} diff --git a/tests/compile-fail/tracked_fn_incompatibles.stderr b/tests/compile-fail/tracked_fn_incompatibles.stderr index 5851bf7e..9efda648 100644 --- a/tests/compile-fail/tracked_fn_incompatibles.stderr +++ b/tests/compile-fail/tracked_fn_incompatibles.stderr @@ -28,6 +28,29 @@ error: only functions with a single salsa struct as their input can be specified 29 | #[salsa::tracked(specify)] | ^^^^^^^ +error: must have a `'db` lifetime + --> tests/compile-fail/tracked_fn_incompatibles.rs:44:9 + | +44 | db: &dyn Db, + | ^ + +error: only a single lifetime parameter is accepted + --> tests/compile-fail/tracked_fn_incompatibles.rs:59:39 + | +59 | fn tracked_fn_with_multiple_lts<'db1, 'db2>(db: &'db1 dyn Db, interned: MyInterned<'db2>) -> u32 { + | ^^^^ + +error[E0106]: missing lifetime specifier + --> tests/compile-fail/tracked_fn_incompatibles.rs:53:15 + | +53 | interned: MyInterned, + | ^^^^^^^^^^ expected named lifetime parameter + | +help: consider using the `'db` lifetime + | +53 | interned: MyInterned<'db>, + | +++++ + error[E0308]: mismatched types --> tests/compile-fail/tracked_fn_incompatibles.rs:24:46 | diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index 426d18a7..c91bac60 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -22,7 +22,7 @@ fn tracked_fn(db: &dyn LogDatabase, input: MyInput) -> u32 { } #[salsa::tracked(specify)] -fn tracked_fn_extra<'db>(db: &dyn LogDatabase, input: MyTracked<'db>) -> u32 { +fn tracked_fn_extra<'db>(db: &'db dyn LogDatabase, input: MyTracked<'db>) -> u32 { db.push_log(format!("tracked_fn_extra({input:?})")); 0 }