From 69a2cb8f934bd0f58443129d875ad4f0718298f6 Mon Sep 17 00:00:00 2001
From: Niko Matsakis
Date: Sat, 27 Jul 2024 19:57:47 +0000
Subject: [PATCH] merge handle into the database

Separate handles are no longer needed.
---
 .../src/setup_input_struct.rs                 |   3 +-
 src/database.rs                               |  97 ++++++++++++--
 src/handle.rs                                 | 125 ------------------
 src/lib.rs                                    |   2 -
 src/storage.rs                                |  47 +++----
 tests/parallel/parallel_cancellation.rs       |   9 +-
 tests/parallel/parallel_cycle_all_recover.rs  |   9 +-
 tests/parallel/parallel_cycle_mid_recover.rs  |  10 +-
 tests/parallel/parallel_cycle_none_recover.rs |  13 +-
 tests/parallel/parallel_cycle_one_recover.rs  |  10 +-
 tests/preverify-struct-with-leaked-data.rs    |   1 +
 11 files changed, 131 insertions(+), 195 deletions(-)
 delete mode 100644 src/handle.rs

diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs
index ccdc60c9..d89d63d6 100644
--- a/components/salsa-macro-rules/src/setup_input_struct.rs
+++ b/components/salsa-macro-rules/src/setup_input_struct.rs
@@ -89,7 +89,8 @@ macro_rules! setup_input_struct {
             pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, $zalsa::Revision) {
                 let zalsa_mut = db.zalsa_mut();
                 let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default());
-                let (ingredient, current_revision) = zalsa_mut.lookup_ingredient_mut(index);
+                let current_revision = zalsa_mut.current_revision();
+                let ingredient = zalsa_mut.lookup_ingredient_mut(index);
                 let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>();
                 (ingredient, current_revision)
             }
diff --git a/src/database.rs b/src/database.rs
index 55a2b6d6..45c5c515 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -1,9 +1,11 @@
-use std::{any::Any, panic::RefUnwindSafe};
+use std::{any::Any, panic::RefUnwindSafe, sync::Arc};
+
+use parking_lot::{Condvar, Mutex};
 
 use crate::{
     self as salsa, local_state,
     storage::{Zalsa, ZalsaImpl},
-    Durability, Event, Revision,
+    Durability, Event, EventKind, Revision,
 };
 
 /// The trait implemented by all Salsa databases.
@@ -34,7 +36,6 @@ pub unsafe trait Database: AsDynDatabase + Any {
     /// is owned by the current thread, this could trigger deadlock.
     fn synthetic_write(&mut self, durability: Durability) {
         let zalsa_mut = self.zalsa_mut();
-        zalsa_mut.new_revision();
         zalsa_mut.report_tracked_write(durability);
     }
 
@@ -57,10 +58,14 @@ pub unsafe trait Database: AsDynDatabase + Any {
         local_state::attach(self, |_state| op(self))
     }
 
-    /// Plumbing methods.
+    /// Plumbing method: Access the internal salsa methods.
     #[doc(hidden)]
     fn zalsa(&self) -> &dyn Zalsa;
 
+    /// Plumbing method: Access the internal salsa methods for mutating the database.
+    ///
+    /// **WARNING:** Triggers a new revision, canceling other database handles.
+    /// This can lead to deadlock!
     #[doc(hidden)]
     fn zalsa_mut(&mut self) -> &mut dyn Zalsa;
 }
@@ -102,7 +107,11 @@ impl dyn Database {
 /// Concrete implementation of the [`Database`][] trait.
 /// Takes an optional type parameter `U` that allows you to thread your own data.
 pub struct DatabaseImpl<U: UserData = ()> {
-    storage: ZalsaImpl<U>,
+    /// Reference to the database. This is always `Some` except during destruction.
+    zalsa_impl: Option<Arc<ZalsaImpl<U>>>,
+
+    /// Coordination data.
+    coordinate: Arc<Coordinate>,
 }
 
 impl<U: UserData + Default> Default for DatabaseImpl<U> {
@@ -116,9 +125,7 @@ impl DatabaseImpl<()> {
     /// Create a new database with the given user data.
     ///
     /// You can also use the [`Default`][] trait if your userdata implements it.
     pub fn new() -> Self {
-        Self {
-            storage: ZalsaImpl::with(()),
-        }
+        Self::with(())
     }
 }
 
@@ -128,16 +135,47 @@ impl<U: UserData> DatabaseImpl<U> {
     /// You can also use the [`Default`][] trait if your userdata implements it.
     pub fn with(u: U) -> Self {
         Self {
-            storage: ZalsaImpl::with(u),
+            zalsa_impl: Some(Arc::new(ZalsaImpl::with(u))),
+            coordinate: Arc::new(Coordinate {
+                clones: Mutex::new(1),
+                cvar: Default::default(),
+            }),
+        }
+    }
+
+    fn zalsa_impl(&self) -> &Arc<ZalsaImpl<U>> {
+        self.zalsa_impl.as_ref().unwrap()
+    }
+
+    // ANCHOR: cancel_other_workers
+    /// Sets cancellation flag and blocks until all other workers with access
+    /// to this storage have completed.
+    ///
+    /// This could deadlock if there is a single worker with two handles to the
+    /// same database!
+    fn cancel_others(&mut self) {
+        let zalsa = self.zalsa_impl();
+        zalsa.set_cancellation_flag();
+
+        self.salsa_event(&|| Event {
+            thread_id: std::thread::current().id(),
+
+            kind: EventKind::DidSetCancellationFlag,
+        });
+
+        let mut clones = self.coordinate.clones.lock();
+        while *clones != 1 {
+            self.coordinate.cvar.wait(&mut clones);
         }
     }
+    // ANCHOR_END: cancel_other_workers
 }
 
 impl<U: UserData> std::ops::Deref for DatabaseImpl<U> {
     type Target = U;
 
     fn deref(&self) -> &U {
-        &self.storage.user_data()
+        self.zalsa_impl().user_data()
     }
 }
 
@@ -146,11 +184,17 @@ impl<U: UserData> RefUnwindSafe for DatabaseImpl<U> {}
 #[salsa_macros::db]
 unsafe impl<U: UserData> Database for DatabaseImpl<U> {
     fn zalsa(&self) -> &dyn Zalsa {
-        &self.storage
+        &**self.zalsa_impl()
     }
 
     fn zalsa_mut(&mut self) -> &mut dyn Zalsa {
-        &mut self.storage
+        self.cancel_others();
+
+        // The ref count on the `Arc` should now be 1
+        let arc_zalsa_mut = self.zalsa_impl.as_mut().unwrap();
+        let zalsa_mut = Arc::get_mut(arc_zalsa_mut).unwrap();
+        zalsa_mut.new_revision();
+        zalsa_mut
     }
 
     // Report a salsa event.
@@ -159,6 +203,28 @@ unsafe impl<U: UserData> Database for DatabaseImpl<U> {
     }
 }
 
+impl<U: UserData> Clone for DatabaseImpl<U> {
+    fn clone(&self) -> Self {
+        *self.coordinate.clones.lock() += 1;
+
+        Self {
+            zalsa_impl: self.zalsa_impl.clone(),
+            coordinate: Arc::clone(&self.coordinate),
+        }
+    }
+}
+
+impl<U: UserData> Drop for DatabaseImpl<U> {
+    fn drop(&mut self) {
+        // Drop the database handle *first*
+        self.zalsa_impl.take();
+
+        // *Now* decrement the number of clones and notify once we have completed
+        *self.coordinate.clones.lock() -= 1;
+        self.coordinate.cvar.notify_all();
+    }
+}
+
 pub trait UserData: Any + Sized {
     /// Callback invoked by the [`Database`][] at key points during salsa execution.
     /// By overriding this method, you can inject logging or other custom behavior.
@@ -174,3 +240,10 @@ pub trait UserData: Any + Sized {
 }
 
 impl UserData for () {}
+
+struct Coordinate {
+    /// Counter of the number of clones of actor. Begins at 1.
+    /// Incremented when cloned, decremented when dropped.
+    clones: Mutex<usize>,
+    cvar: Condvar,
+}
diff --git a/src/handle.rs b/src/handle.rs
deleted file mode 100644
index e3c2ecd9..00000000
--- a/src/handle.rs
+++ /dev/null
@@ -1,125 +0,0 @@
-use std::sync::Arc;
-
-use parking_lot::{Condvar, Mutex};
-
-use crate::{Database, Event, EventKind};
-
-/// A database "handle" allows coordination of multiple async tasks accessing the same database.
-/// So long as you are just doing reads, you can freely clone.
-/// When you attempt to modify the database, you call `get_mut`, which will set the cancellation flag,
-/// causing other handles to get panics. Once all other handles are dropped, you can proceed.
-pub struct Handle<Db: Database> {
-    /// Reference to the database. This is always `Some` except during destruction.
-    db: Option<Arc<Db>>,
-
-    /// Coordination data.
-    coordinate: Arc<Coordinate>,
-}
-
-struct Coordinate {
-    /// Counter of the number of clones of actor. Begins at 1.
-    /// Incremented when cloned, decremented when dropped.
-    clones: Mutex<usize>,
-    cvar: Condvar,
-}
-
-impl<Db: Database> Handle<Db> {
-    /// Create a new handle wrapping `db`.
-    pub fn new(db: Db) -> Self {
-        Self {
-            db: Some(Arc::new(db)),
-            coordinate: Arc::new(Coordinate {
-                clones: Mutex::new(1),
-                cvar: Default::default(),
-            }),
-        }
-    }
-
-    fn db(&self) -> &Arc<Db> {
-        self.db.as_ref().unwrap()
-    }
-
-    fn db_mut(&mut self) -> &mut Arc<Db> {
-        self.db.as_mut().unwrap()
-    }
-
-    /// Returns a mutable reference to the inner database.
-    /// If other handles are active, this method sets the cancellation flag
-    /// and blocks until they are dropped.
-    pub fn get_mut(&mut self) -> &mut Db {
-        self.cancel_others();
-
-        // Once cancellation above completes, the other handles are being dropped.
-        // However, because the signal is sent before the destructor completes, it's
-        // possible that they have not *yet* dropped.
-        //
-        // Therefore, we may have to do a (short) bit of
-        // spinning before we observe the thread-count reducing to 0.
-        //
-        // An alternative would be to
-        Arc::get_mut(self.db_mut()).expect("other threads remain active despite cancellation")
-    }
-
-    /// Returns the inner database, consuming the handle.
-    ///
-    /// If other handles are active, this method sets the cancellation flag
-    /// and blocks until they are dropped.
-    pub fn into_inner(mut self) -> Db {
-        self.cancel_others();
-        Arc::into_inner(self.db.take().unwrap())
-            .expect("other threads remain active despite cancellation")
-    }
-
-    // ANCHOR: cancel_other_workers
-    /// Sets cancellation flag and blocks until all other workers with access
-    /// to this storage have completed.
-    ///
-    /// This could deadlock if there is a single worker with two handles to the
-    /// same database!
-    fn cancel_others(&mut self) {
-        let zalsa = self.db().zalsa();
-        zalsa.set_cancellation_flag();
-
-        self.db().salsa_event(&|| Event {
-            thread_id: std::thread::current().id(),
-
-            kind: EventKind::DidSetCancellationFlag,
-        });
-
-        let mut clones = self.coordinate.clones.lock();
-        while *clones != 1 {
-            self.coordinate.cvar.wait(&mut clones);
-        }
-    }
-    // ANCHOR_END: cancel_other_workers
-}
-
-impl<Db: Database> Drop for Handle<Db> {
-    fn drop(&mut self) {
-        // Drop the database handle *first*
-        self.db.take();
-
-        // *Now* decrement the number of clones and notify once we have completed
-        *self.coordinate.clones.lock() -= 1;
-        self.coordinate.cvar.notify_all();
-    }
-}
-
-impl<Db: Database> std::ops::Deref for Handle<Db> {
-    type Target = Db;
-
-    fn deref(&self) -> &Self::Target {
-        self.db()
-    }
-}
-
-impl<Db: Database> Clone for Handle<Db> {
-    fn clone(&self) -> Self {
-        *self.coordinate.clones.lock() += 1;
-
-        Self {
-            db: Some(Arc::clone(self.db())),
-            coordinate: Arc::clone(&self.coordinate),
-        }
-    }
-}
diff --git a/src/lib.rs b/src/lib.rs
index 098cd9d9..8aa9685e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,7 +8,6 @@ mod database;
 mod durability;
 mod event;
 mod function;
-mod handle;
 mod hash;
 mod id;
 mod ingredient;
@@ -36,7 +35,6 @@ pub use self::database::UserData;
 pub use self::durability::Durability;
 pub use self::event::Event;
 pub use self::event::EventKind;
-pub use self::handle::Handle;
 pub use self::id::Id;
 pub use self::input::setter::Setter;
 pub use self::key::DatabaseKeyIndex;
diff --git a/src/storage.rs b/src/storage.rs
index 43ac7670..77fc6210 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -41,24 +41,16 @@ pub trait Zalsa {
     fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient;
 
     /// Gets an `&mut`-ref to an ingredient by index.
-    ///
-    /// **Triggers a new revision.** Returns the `&mut` reference
-    /// along with the new revision index.
     fn lookup_ingredient_mut(
         &mut self,
         index: IngredientIndex,
-    ) -> (&mut dyn Ingredient, Revision);
+    ) -> &mut dyn Ingredient;
 
     fn runtimex(&self) -> &Runtime;
 
     /// Return the current revision
     fn current_revision(&self) -> Revision;
 
-    /// Increment revision counter.
-    ///
-    /// **Triggers a new revision.**
-    fn new_revision(&mut self) -> Revision;
-
     /// Return the time when an input of durability `durability` last changed
     fn last_changed_revision(&self, durability: Durability) -> Revision;
 
@@ -126,22 +118,10 @@ impl<U: UserData> Zalsa for ZalsaImpl<U> {
     fn lookup_ingredient_mut(
         &mut self,
         index: IngredientIndex,
-    ) -> (&mut dyn Ingredient, Revision) {
-        let new_revision = self.runtime.new_revision();
-
-        for index in self.ingredients_requiring_reset.iter() {
-            self.ingredients_vec
-                .get_mut(index.as_usize())
-                .unwrap()
-                .reset_for_new_revision();
-        }
-
-        (
-            &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap(),
-            new_revision,
-        )
+    ) -> &mut dyn Ingredient {
+        &mut **self.ingredients_vec.get_mut(index.as_usize()).unwrap()
     }
-    
+
     fn current_revision(&self) -> Revision {
         self.runtime.current_revision()
     }
@@ -165,10 +145,6 @@ impl<U: UserData> Zalsa for ZalsaImpl<U> {
     fn set_cancellation_flag(&self) {
         self.runtime.set_cancellation_flag()
     }
-
-    fn new_revision(&mut self) -> Revision {
-        self.runtime.new_revision()
-    }
 }
 
 /// Nonce type representing the underlying database storage.
@@ -264,6 +240,21 @@ impl<U: UserData> ZalsaImpl<U> {
     pub(crate) fn user_data(&self) -> &U {
         &self.user_data
    }
+
+    /// Triggers a new revision. Invoked automatically when you call `zalsa_mut`
+    /// and so doesn't need to be called otherwise.
+ pub(crate) fn new_revision(&mut self) -> Revision { + let new_revision = self.runtime.new_revision(); + + for index in self.ingredients_requiring_reset.iter() { + self.ingredients_vec + .get_mut(index.as_usize()) + .unwrap() + .reset_for_new_revision(); + } + + new_revision + } } /// Caches a pointer to an ingredient in a database. diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 0e35ab25..55f81e8d 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -4,7 +4,6 @@ use salsa::Cancelled; use salsa::DatabaseImpl; -use salsa::Handle; use salsa::Setter; use crate::setup::Knobs; @@ -44,17 +43,17 @@ fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { #[test] fn execute() { - let mut db = Handle::new(>::default()); + let mut db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); - input.set_field(db.get_mut()).to(2); + input.set_field(&mut db).to(2); // Assert thread A *should* was cancelled let cancelled = thread_a diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs index 7706d6ec..ac20b504 100644 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -3,7 +3,6 @@ //! both intra and cross thread. use salsa::DatabaseImpl; -use salsa::Handle; use crate::setup::Knobs; use crate::setup::KnobsDatabase; @@ -87,19 +86,19 @@ fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); assert_eq!(thread_a.join().unwrap(), 11); diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs index 0c5e3475..8bca2f61 100644 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ b/tests/parallel/parallel_cycle_mid_recover.rs @@ -2,7 +2,7 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. 
-use salsa::{DatabaseImpl, Handle}; +use salsa::DatabaseImpl; use crate::setup::{Knobs, KnobsDatabase}; @@ -81,19 +81,19 @@ fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i3 #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); // We expect that the recovery function yields diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs index 39b6299c..d74aa5b0 100644 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ b/tests/parallel/parallel_cycle_none_recover.rs @@ -5,9 +5,8 @@ use crate::setup::Knobs; use crate::setup::KnobsDatabase; use expect_test::expect; -use salsa::Database as _; +use salsa::Database; use salsa::DatabaseImpl; -use salsa::Handle; #[salsa::input] pub(crate) struct MyInput { @@ -38,19 +37,19 @@ pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, -1); + let input = MyInput::new(&db, -1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a(&*db, input) + move || a(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b(&*db, input) + move || b(&db, input) }); // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). @@ -64,7 +63,7 @@ fn execute() { b(0), ] "#]]; - expected.assert_debug_eq(&c.all_participants(&*db)); + expected.assert_debug_eq(&c.all_participants(&db)); } else { panic!("b failed in an unexpected way: {:?}", err_b); } diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs index 7a32d95c..2bf53857 100644 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ b/tests/parallel/parallel_cycle_one_recover.rs @@ -2,7 +2,7 @@ //! See `../cycles.rs` for a complete listing of cycle tests, //! both intra and cross thread. 
-use salsa::{DatabaseImpl, Handle}; +use salsa::DatabaseImpl; use crate::setup::{Knobs, KnobsDatabase}; @@ -70,19 +70,19 @@ pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { #[test] fn execute() { - let db = Handle::new(>::default()); + let db = >::default(); db.knobs().signal_on_will_block.store(3); - let input = MyInput::new(&*db, 1); + let input = MyInput::new(&db, 1); let thread_a = std::thread::spawn({ let db = db.clone(); - move || a1(&*db, input) + move || a1(&db, input) }); let thread_b = std::thread::spawn({ let db = db.clone(); - move || b1(&*db, input) + move || b1(&db, input) }); // We expect that the recovery function yields diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 2c5bdfd5..99391709 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -66,6 +66,7 @@ fn test_leaked_inputs_ignored() { let result_in_rev_2 = function(&db, input); db.assert_logs(expect![[r#" [ + "Event { thread_id: ThreadId(2), kind: DidSetCancellationFlag }", "Event { thread_id: ThreadId(2), kind: WillCheckCancellation }", "Event { thread_id: ThreadId(2), kind: WillExecute { database_key: function(0) } }", ]"#]]);
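
Usage note: after this patch, `DatabaseImpl` itself is `Clone` and takes over the role of the removed `Handle`. Clones share the same storage for reads, and any call that reaches `zalsa_mut` (for example an input setter) first sets the cancellation flag and blocks until every other clone has been dropped, then bumps the revision. Below is a minimal sketch of that pattern, modeled on the updated tests; the `MyInput` input and the `double` query are illustrative stand-ins and not part of the patch.

```rust
use salsa::{DatabaseImpl, Setter};

#[salsa::input]
struct MyInput {
    field: i32,
}

#[salsa::tracked]
fn double(db: &dyn salsa::Database, input: MyInput) -> i32 {
    input.field(db) * 2
}

fn main() {
    let mut db = DatabaseImpl::new();
    let input = MyInput::new(&db, 22);

    // Read-only work can run on a clone in another thread; the clone bumps
    // the shared `Coordinate::clones` counter.
    let reader = std::thread::spawn({
        let db = db.clone();
        move || double(&db, input)
    });
    assert_eq!(reader.join().unwrap(), 44);

    // Mutation goes through `zalsa_mut`: it sets the cancellation flag and
    // waits until the clone above has been dropped before starting a new revision.
    input.set_field(&mut db).to(50);
    assert_eq!(double(&db, input), 100);
}
```

This mirrors what the parallel tests now do: they clone the `DatabaseImpl` for each worker thread and call setters on the original, whereas before they had to wrap the database in a `Handle` and call `get_mut`.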