diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index c9e94df6f..c4d587d91 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -48,6 +48,7 @@ macro_rules! setup_input_struct { $zalsa:ident, $zalsa_struct:ident, $Configuration:ident, + $Builder:ident, $CACHE:ident, $Db:ident, ] @@ -121,14 +122,24 @@ macro_rules! setup_input_struct { } impl $Struct { + #[inline] pub fn $new_fn<$Db>(db: &$Db, $($field_id: $field_ty),*) -> Self + where + // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` + $Db: ?Sized + salsa::Database, + { + Self::builder(db).new($($field_id,)*) + } + + pub fn builder<'db, $Db>(db: &'db $Db) -> $Builder<'db> where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { let current_revision = $zalsa::current_revision(db); - let stamps = $zalsa::Array::new([$zalsa::stamp(current_revision, Default::default()); $N]); - $Configuration::ingredient(db.as_salsa_database()).new_input(($($field_id,)*), stamps) + $Builder { + inner: $zalsa_struct::BuilderImpl::new(current_revision, $Configuration::ingredient(db.as_salsa_database())), + } } $( @@ -204,6 +215,24 @@ macro_rules! setup_input_struct { }) } } + + pub struct $Builder<'db> { + #[doc(hidden)] + inner: $zalsa_struct::BuilderImpl<'db, $Configuration>, + } + + impl<'db> $Builder<'db> { + /// Sets the durability for all fields. + pub fn durability(mut self, durability: $zalsa::Durability) -> Self { + self.inner.durability(durability); + self + } + + pub fn new(self, $($field_id: $field_ty),*) -> $Struct { + self.inner.build(($($field_id,)*)) + } + } + }; }; } diff --git a/components/salsa-macro-rules/src/setup_struct_fn.rs b/components/salsa-macro-rules/src/setup_struct_fn.rs deleted file mode 100644 index 8b1378917..000000000 --- a/components/salsa-macro-rules/src/setup_struct_fn.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 2ff9f18b5..4040bf610 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -93,6 +93,7 @@ impl Macro { let zalsa = self.hygiene.ident("zalsa"); let zalsa_struct = self.hygiene.ident("zalsa_struct"); let Configuration = self.hygiene.ident("Configuration"); + let Builder = self.hygiene.ident("Builder"); let CACHE = self.hygiene.ident("CACHE"); let Db = self.hygiene.ident("Db"); @@ -117,6 +118,7 @@ impl Macro { #zalsa, #zalsa_struct, #Configuration, + #Builder, #CACHE, #Db, ] diff --git a/src/database.rs b/src/database.rs index ce4514f3d..fc1fe5195 100644 --- a/src/database.rs +++ b/src/database.rs @@ -21,7 +21,9 @@ pub trait Database: DatabaseGen { /// will block until that snapshot is dropped -- if that snapshot /// is owned by the current thread, this could trigger deadlock. fn synthetic_write(&mut self, durability: Durability) { - self.runtime_mut().report_tracked_write(durability); + let runtime = self.runtime_mut(); + runtime.new_revision(); + runtime.report_tracked_write(durability); } /// Reports that the query depends on some state unknown to salsa. diff --git a/src/input.rs b/src/input.rs index ba1bd41a3..aef37adfb 100644 --- a/src/input.rs +++ b/src/input.rs @@ -5,6 +5,7 @@ use std::{ sync::atomic::{AtomicU32, Ordering}, }; +pub mod builder; pub mod input_field; pub mod setter; mod struct_map; diff --git a/src/input/builder/mod.rs b/src/input/builder/mod.rs new file mode 100644 index 000000000..d392b2b43 --- /dev/null +++ b/src/input/builder/mod.rs @@ -0,0 +1,41 @@ +use super::{Configuration, IngredientImpl}; +use crate::plumbing::Array; +use crate::runtime::Stamp; +use crate::{Durability, Revision}; + +pub struct BuilderImpl<'builder, C> +where + C: Configuration, +{ + stamps: C::Stamps, + + ingredient: &'builder IngredientImpl, +} + +impl<'builder, const N: usize, C> BuilderImpl<'builder, C> +where + C: Configuration>, +{ + pub fn new(revision: Revision, ingredient: &'builder IngredientImpl) -> Self { + Self { + ingredient, + stamps: Array::new([crate::plumbing::stamp(revision, Durability::default()); N]), + } + } + + /// Sets the durability of a specific field. + pub fn set_field_durability(&mut self, index: usize, durability: Durability) -> &mut Self { + self.stamps[index].durability = durability; + self + } + + pub fn durability(&mut self, durability: Durability) { + for stamp in &mut *self.stamps { + stamp.durability = durability; + } + } + + pub fn build(self, fields: C::Fields) -> C::Struct { + self.ingredient.new_input(fields, self.stamps) + } +} diff --git a/src/lib.rs b/src/lib.rs index 7d1397a13..f69bc3ed1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,6 +106,7 @@ pub mod plumbing { pub use crate::update::helper::Dispatch as UpdateDispatch; pub use crate::update::helper::Fallback as UpdateFallback; pub use crate::update::Update; + pub use crate::Durability; pub use salsa_macro_rules::macro_if; pub use salsa_macro_rules::maybe_backdate; @@ -125,6 +126,7 @@ pub mod plumbing { } pub mod input { + pub use crate::input::builder::BuilderImpl; pub use crate::input::input_field::FieldIngredientImpl; pub use crate::input::setter::SetterImpl; pub use crate::input::Configuration; diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index a60f08b8c..61d251cae 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -2,13 +2,13 @@ //! Then mutate the values so that the tracked function re-executes. //! Check that we accumulate the appropriate, new values. -mod common; -use common::{HasLogger, Logger}; - use expect_test::expect; -use salsa::{Accumulator, Setter}; use test_log::test; +use common::{HasLogger, Logger}; +use salsa::{Accumulator, Setter}; + +mod common; #[salsa::db] trait Db: salsa::Database + HasLogger {} diff --git a/tests/tracked_fn_on_input_with_high_durability.rs b/tests/tracked_fn_on_input_with_high_durability.rs new file mode 100644 index 000000000..33aa31ace --- /dev/null +++ b/tests/tracked_fn_on_input_with_high_durability.rs @@ -0,0 +1,77 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. +#![allow(warnings)] + +use expect_test::expect; + +use common::{HasLogger, Logger}; +use salsa::plumbing::HasStorage; +use salsa::{Database, Durability, Event, EventKind, Setter}; + +mod common; +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> u32 { + input.field(db) * 2 +} + +#[test] +fn execute() { + #[salsa::db] + #[derive(Default)] + struct Database { + storage: salsa::Storage, + logger: Logger, + } + + #[salsa::db] + impl salsa::Database for Database { + fn salsa_event(&self, event: Event) { + match event.kind { + EventKind::WillCheckCancellation => {} + _ => { + self.push_log(format!("salsa_event({:?})", event.kind)); + } + } + } + } + + impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } + } + + let mut db = Database::default(); + let input_low = MyInput::new(&db, 22); + let input_high = MyInput::builder(&db).durability(Durability::HIGH).new(2200); + + assert_eq!(tracked_fn(&db, input_low), 44); + assert_eq!(tracked_fn(&db, input_high), 4400); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: tracked_fn(0) })", + "salsa_event(WillExecute { database_key: tracked_fn(1) })", + ]"#]]); + + db.synthetic_write(Durability::LOW); + + assert_eq!(tracked_fn(&db, input_low), 44); + assert_eq!(tracked_fn(&db, input_high), 4400); + + // There's currently no good way to verify whether an input was validated using shallow or deep comparison. + // All we can do for now is verify that the values were validated. + // Note: It maybe confusing why it validates `input_high` when the write has `Durability::LOW`. + // This is because all values must be validated whenever a write occurs. It doesn't mean that it + // executed the query. + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(0) })", + "salsa_event(DidValidateMemoizedValue { database_key: tracked_fn(1) })", + ]"#]]); +}