From 5f0904ae4a9a54944fbf34300b380bed13b27621 Mon Sep 17 00:00:00 2001
From: David Barsky
Date: Fri, 23 Aug 2024 14:30:59 -0400
Subject: [PATCH] introduce parallel salsa

---
 Cargo.toml                         |  1 +
 examples/calc/db.rs                |  2 +-
 examples/lazy-input/main.rs        | 15 +++--
 src/database.rs                    |  2 +-
 src/database_impl.rs               |  2 +-
 src/lib.rs                         |  2 +
 src/par_map.rs                     | 54 ++++++++++++++++
 src/storage.rs                     |  6 +-
 src/zalsa.rs                       |  4 ++
 tests/common/mod.rs                | 14 +++--
 tests/parallel/main.rs             |  1 +
 tests/parallel/parallel_map.rs     | 98 ++++++++++++++++++++++++++++++
 tests/tracked_struct_durability.rs |  2 +-
 13 files changed, 188 insertions(+), 15 deletions(-)
 create mode 100644 src/par_map.rs
 create mode 100644 tests/parallel/parallel_map.rs

diff --git a/Cargo.toml b/Cargo.toml
index 8a8e9063..5e1f5b55 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
 salsa-macros = { path = "components/salsa-macros" }
 smallvec = "1"
 lazy_static = "1"
+rayon = "1.10.0"
 
 [dev-dependencies]
 annotate-snippets = "0.11.4"
diff --git a/examples/calc/db.rs b/examples/calc/db.rs
index 2873ed5b..924205c2 100644
--- a/examples/calc/db.rs
+++ b/examples/calc/db.rs
@@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};
 
 // ANCHOR: db_struct
 #[salsa::db]
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct CalcDatabaseImpl {
     storage: salsa::Storage<Self>,
 
diff --git a/examples/lazy-input/main.rs b/examples/lazy-input/main.rs
index 792b7f34..ff998fa3 100644
--- a/examples/lazy-input/main.rs
+++ b/examples/lazy-input/main.rs
@@ -1,6 +1,10 @@
 #![allow(unreachable_patterns)] // FIXME(rust-lang/rust#129031): regression in nightly
 
-use std::{path::PathBuf, sync::Mutex, time::Duration};
+use std::{
+    path::PathBuf,
+    sync::{Arc, Mutex},
+    time::Duration,
+};
 
 use crossbeam::channel::{unbounded, Sender};
 use dashmap::{mapref::entry::Entry, DashMap};
@@ -77,11 +81,12 @@ trait Db: salsa::Database {
 }
 
 #[salsa::db]
+#[derive(Clone)]
 struct LazyInputDatabase {
     storage: Storage<Self>,
-    logs: Mutex<Vec<String>>,
+    logs: Arc<Mutex<Vec<String>>>,
     files: DashMap<PathBuf, File>,
-    file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
+    file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
 }
 
 impl LazyInputDatabase {
@@ -90,7 +95,9 @@ impl LazyInputDatabase {
             storage: Default::default(),
             logs: Default::default(),
             files: DashMap::new(),
-            file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
+            file_watcher: Arc::new(Mutex::new(
+                new_debouncer(Duration::from_secs(1), tx).unwrap(),
+            )),
         }
     }
 }
diff --git a/src/database.rs b/src/database.rs
index 5a32bd9b..a978df0e 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -90,7 +90,7 @@ impl dyn Database {
     ///
     /// # Panics
     ///
-    /// If the view has not been added to the database (see [`DatabaseView`][])
+    /// If the view has not been added to the database (see [`crate::views::Views`]).
     #[track_caller]
     pub fn as_view(&self) -> &DbView {
         self.zalsa().views().try_view_as(self).unwrap()
diff --git a/src/database_impl.rs b/src/database_impl.rs
index 71da9fff..e31c6ed7 100644
--- a/src/database_impl.rs
+++ b/src/database_impl.rs
@@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
 #[salsa::db]
 /// Default database implementation that you can use if you don't
 /// require any custom user data.
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct DatabaseImpl {
     storage: Storage<Self>,
 }
diff --git a/src/lib.rs b/src/lib.rs
index 0ee7d3ec..8cc739fa 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,6 +16,7 @@ mod input;
 mod interned;
 mod key;
 mod nonce;
+mod par_map;
 mod revision;
 mod runtime;
 mod salsa_struct;
@@ -45,6 +46,7 @@ pub use self::storage::Storage;
 pub use self::update::Update;
 pub use self::zalsa::IngredientIndex;
 pub use crate::attach::with_attached_database;
+pub use par_map::par_map;
 pub use salsa_macros::accumulator;
 pub use salsa_macros::db;
 pub use salsa_macros::input;
diff --git a/src/par_map.rs b/src/par_map.rs
new file mode 100644
index 00000000..1f93a1c5
--- /dev/null
+++ b/src/par_map.rs
@@ -0,0 +1,54 @@
+use std::ops::Deref;
+
+use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
+
+use crate::Database;
+
+pub fn par_map<Db, D, E, C>(
+    db: &Db,
+    inputs: impl IntoParallelIterator<Item = D>,
+    op: fn(&Db, D) -> E,
+) -> C
+where
+    Db: Database + ?Sized,
+    D: Send,
+    E: Send + Sync,
+    C: FromParallelIterator<E>,
+{
+    let parallel_db = ParallelDb::Ref(db.as_dyn_database());
+
+    inputs
+        .into_par_iter()
+        .map_with(parallel_db, |parallel_db, element| {
+            let db = parallel_db.as_view::<Db>();
+            op(db, element)
+        })
+        .collect()
+}
+
+/// This enum _must not_ be public or used outside of `par_map`.
+enum ParallelDb<'db> {
+    Ref(&'db dyn Database),
+    Fork(Box<dyn Database + Send>),
+}
+
+/// SAFETY: the contents of the database are never accessed on the thread
+/// where this wrapper type is created.
+unsafe impl Send for ParallelDb<'_> {}
+
+impl Deref for ParallelDb<'_> {
+    type Target = dyn Database;
+
+    fn deref(&self) -> &Self::Target {
+        match self {
+            ParallelDb::Ref(db) => *db,
+            ParallelDb::Fork(db) => db.as_dyn_database(),
+        }
+    }
+}
+
+impl Clone for ParallelDb<'_> {
+    fn clone(&self) -> Self {
+        ParallelDb::Fork(self.fork_db())
+    }
+}
diff --git a/src/storage.rs b/src/storage.rs
index c9e1273b..40986291 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -15,7 +15,7 @@ use crate::{
 ///
 /// The `storage` and `storage_mut` fields must both return a reference to the same
 /// storage field which must be owned by `self`.
-pub unsafe trait HasStorage: Database + Sized {
+pub unsafe trait HasStorage: Database + Clone + Sized {
     fn storage(&self) -> &Storage<Self>;
     fn storage_mut(&mut self) -> &mut Storage<Self>;
 }
@@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
     fn zalsa_local(&self) -> &ZalsaLocal {
         &self.storage().zalsa_local
     }
+
+    fn fork_db(&self) -> Box<dyn Database + Send> {
+        Box::new(self.clone())
+    }
 }
 
 impl RefUnwindSafe for Storage {}
diff --git a/src/zalsa.rs b/src/zalsa.rs
index 2f8fa95f..e92e2891 100644
--- a/src/zalsa.rs
+++ b/src/zalsa.rs
@@ -50,6 +50,10 @@ pub unsafe trait ZalsaDatabase: Any {
     /// Access the thread-local state associated with this database
     #[doc(hidden)]
    fn zalsa_local(&self) -> &ZalsaLocal;
+
+    /// Clone the database.
+    #[doc(hidden)]
+    fn fork_db(&self) -> Box<dyn Database + Send>;
 }
 
 pub fn views(db: &Db) -> &Views {
diff --git a/tests/common/mod.rs b/tests/common/mod.rs
index 4c4e9fc7..19f818b6 100644
--- a/tests/common/mod.rs
+++ b/tests/common/mod.rs
@@ -2,15 +2,17 @@
 
 #![allow(dead_code)]
 
+use std::sync::{Arc, Mutex};
+
 use salsa::{Database, Storage};
 
 /// Logging userdata: provides [`LogDatabase`][] trait.
 ///
 /// If you wish to use it along with other userdata,
 /// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct Logger {
-    logs: std::sync::Mutex<Vec<String>>,
+    logs: Arc<Mutex<Vec<String>>>,
 }
 
 /// Trait implemented by databases that lets them log events.
@@ -48,7 +50,7 @@ impl LogDatabase for Db {}
 
 /// Database that provides logging but does not log salsa event.
 #[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct LoggerDatabase {
     storage: Storage<Self>,
     logger: Logger,
@@ -67,7 +69,7 @@ impl Database for LoggerDatabase {
 
 /// Database that provides logging and logs salsa events.
 #[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct EventLoggerDatabase {
     storage: Storage<Self>,
     logger: Logger,
@@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
 }
 
 #[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct DiscardLoggerDatabase {
     storage: Storage<Self>,
     logger: Logger,
@@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
 }
 
 #[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct ExecuteValidateLoggerDatabase {
     storage: Storage<Self>,
     logger: Logger,
diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs
index 578a83cb..e01e4654 100644
--- a/tests/parallel/main.rs
+++ b/tests/parallel/main.rs
@@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
 mod parallel_cycle_mid_recover;
 mod parallel_cycle_none_recover;
 mod parallel_cycle_one_recover;
+mod parallel_map;
 mod signal;
diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs
new file mode 100644
index 00000000..b1b6cd4e
--- /dev/null
+++ b/tests/parallel/parallel_map.rs
@@ -0,0 +1,98 @@
+// test for rayon interactions.
+
+use salsa::Cancelled;
+use salsa::Setter;
+
+use crate::setup::Knobs;
+use crate::setup::KnobsDatabase;
+
+#[salsa::input]
+struct ParallelInput {
+    field: Vec,
+}
+
+#[salsa::tracked]
+fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec {
+    salsa::par_map(db, input.field(db), |_db, field| field + 1)
+}
+
+#[test]
+fn execute() {
+    let db = salsa::DatabaseImpl::new();
+
+    let counts = (1..=10).collect::>();
+    let input = ParallelInput::new(&db, counts);
+
+    tracked_fn(&db, input);
+}
+
+#[salsa::tracked]
+fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec {
+    db.signal(1);
+    salsa::par_map(db, input.field(db), |db, field| {
+        db.wait_for(2);
+        field + 1
+    })
+}
+
+#[salsa::tracked]
+fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput {
+    panic!("should never get here!")
+}
+
+// we expect this to panic, as `salsa::par_map` needs to be called from a query.
+#[test]
+#[should_panic]
+fn direct_calls_panic() {
+    let db = salsa::DatabaseImpl::new();
+
+    let counts = (1..=10).collect::>();
+    let input = ParallelInput::new(&db, counts);
+    let _: Vec = salsa::par_map(&db, input.field(&db), |_db, field| field + 1);
+}
+
+// Cancellation signalling test
+//
+// The pattern is as follows.
+//
+//   Thread A                       Thread B
+//   --------                       --------
+//   a1
+//   |                              wait for stage 1
+//   signal stage 1                 set input, triggers cancellation
+//   wait for stage 2 (blocks)      triggering cancellation sends stage 2
+//   |
+//   (unblocked)
+//   dummy
+//   panics
+
+#[test]
+fn execute_cancellation() {
+    let mut db = Knobs::default();
+
+    let counts = (1..=10).collect::>();
+    let input = ParallelInput::new(&db, counts);
+
+    let thread_a = std::thread::spawn({
+        let db = db.clone();
+        move || a1(&db, input)
+    });
+
+    let counts = (2..=20).collect::>();
+
+    db.signal_on_did_cancel.store(2);
+    input.set_field(&mut db).to(counts);
+
+    // Assert that thread A was cancelled
+    let cancelled = thread_a
+        .join()
+        .unwrap_err()
+        .downcast::<Cancelled>()
+        .unwrap();
+
+    // and inspect the output
+    expect_test::expect![[r#"
+        PendingWrite
+    "#]]
+    .assert_debug_eq(&cancelled);
+}
diff --git a/tests/tracked_struct_durability.rs b/tests/tracked_struct_durability.rs
index c1fb8b2b..f4ffaa8d 100644
--- a/tests/tracked_struct_durability.rs
+++ b/tests/tracked_struct_durability.rs
@@ -83,7 +83,7 @@ fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> {
 #[test]
 fn execute() {
     #[salsa::db]
-    #[derive(Default)]
+    #[derive(Default, Clone)]
     struct Database {
         storage: salsa::Storage<Self>,
         files: Vec,
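
Usage note: a minimal sketch of how a caller might exercise the new `salsa::par_map` API from inside a tracked query. The `Numbers` input, its `values` field, and the `doubled` query are illustrative names rather than code from this patch; the call pattern mirrors the `tracked_fn` test above.

    use salsa::Setter;

    // Hypothetical input struct with a Vec field to fan out over.
    #[salsa::input]
    struct Numbers {
        values: Vec<u32>,
    }

    // `par_map` is called from inside a tracked query; each rayon worker
    // operates on a forked database handle (via `fork_db`) rather than
    // sharing the caller's `&dyn Database` reference directly.
    #[salsa::tracked]
    fn doubled(db: &dyn salsa::Database, numbers: Numbers) -> Vec<u32> {
        salsa::par_map(db, numbers.values(db), |_db, n| n * 2)
    }

    fn main() {
        let mut db = salsa::DatabaseImpl::new();
        let numbers = Numbers::new(&db, (1..=4).collect());
        assert_eq!(doubled(&db, numbers), vec![2, 4, 6, 8]);

        // Changing the input re-runs the query on the next call.
        numbers.set_values(&mut db).to(vec![10, 20]);
        assert_eq!(doubled(&db, numbers), vec![20, 40]);
    }

This is also why the patch adds `Clone` to `HasStorage` and derives `Clone` on the example databases: cloning the database (with its side tables behind `Arc`) is what `fork_db` hands to each worker thread.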