Skip to content

Commit

Permalink
introduce parallel salsa
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbarsky committed Oct 15, 2024
1 parent af2ec49 commit 132ce16
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { path = "components/salsa-macros" }
smallvec = "1"
lazy_static = "1"
rayon = "1.10.0"

[dev-dependencies]
annotate-snippets = "0.11.4"
Expand Down
2 changes: 1 addition & 1 deletion examples/calc/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};

// ANCHOR: db_struct
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
pub struct CalcDatabaseImpl {
storage: salsa::Storage<Self>,

Expand Down
15 changes: 11 additions & 4 deletions examples/lazy-input/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#![allow(unreachable_patterns)]
// FIXME(rust-lang/rust#129031): regression in nightly
use std::{path::PathBuf, sync::Mutex, time::Duration};
use std::{
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};

use crossbeam::channel::{unbounded, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
Expand Down Expand Up @@ -77,11 +81,12 @@ trait Db: salsa::Database {
}

#[salsa::db]
#[derive(Clone)]
struct LazyInputDatabase {
storage: Storage<Self>,
logs: Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
files: DashMap<PathBuf, File>,
file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
}

impl LazyInputDatabase {
Expand All @@ -90,7 +95,9 @@ impl LazyInputDatabase {
storage: Default::default(),
logs: Default::default(),
files: DashMap::new(),
file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
file_watcher: Arc::new(Mutex::new(
new_debouncer(Duration::from_secs(1), tx).unwrap(),
)),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl dyn Database {
///
/// # Panics
///
/// If the view has not been added to the database (see [`DatabaseView`][])
/// If the view has not been added to the database (see [`crate::views::Views`]).
#[track_caller]
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.zalsa().views().try_view_as(self).unwrap()
Expand Down
2 changes: 1 addition & 1 deletion src/database_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
#[salsa::db]
/// Default database implementation that you can use if you don't
/// require any custom user data.
#[derive(Default)]
#[derive(Default, Clone)]
pub struct DatabaseImpl {
storage: Storage<Self>,
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod input;
mod interned;
mod key;
mod nonce;
mod par_map;
mod revision;
mod runtime;
mod salsa_struct;
Expand Down Expand Up @@ -45,6 +46,7 @@ pub use self::storage::Storage;
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
pub use par_map::par_map;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
Expand Down
54 changes: 54 additions & 0 deletions src/par_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::ops::Deref;

use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::Database;

pub fn par_map<Db, D, E, C>(
db: &Db,
inputs: impl IntoParallelIterator<Item = D>,
op: fn(&Db, D) -> E,
) -> C
where
Db: Database + ?Sized,
D: Send,
E: Send + Sync,
C: FromParallelIterator<E>,
{
let parallel_db = ParallelDb::Ref(db.as_dyn_database());

inputs
.into_par_iter()
.map_with(parallel_db, |parallel_db, element| {
let db = parallel_db.as_view::<Db>();
op(db, element)
})
.collect()
}

/// This enum _must not_ be public or used outside of `par_map`.
enum ParallelDb<'db> {
Ref(&'db dyn Database),
Fork(Box<dyn Database + Send>),
}

/// SAFETY: the contents of the database are never accessed on the thread
/// where this wrapper type is created.
unsafe impl<'db> Send for ParallelDb<'db> {}

impl<'db> Deref for ParallelDb<'db> {
type Target = dyn Database;

fn deref(&self) -> &Self::Target {
match self {
ParallelDb::Ref(db) => *db,
ParallelDb::Fork(db) => db.as_dyn_database(),
}
}
}

impl Clone for ParallelDb<'_> {
fn clone(&self) -> Self {
ParallelDb::Fork(self.fork_db())
}
}
6 changes: 5 additions & 1 deletion src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
///
/// The `storage` and `storage_mut` fields must both return a reference to the same
/// storage field which must be owned by `self`.
pub unsafe trait HasStorage: Database + Sized {
pub unsafe trait HasStorage: Database + Clone + Sized {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&mut self) -> &mut Storage<Self>;
}
Expand Down Expand Up @@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
fn zalsa_local(&self) -> &ZalsaLocal {
&self.storage().zalsa_local
}

fn fork_db(&self) -> Box<dyn Database> {
Box::new(self.clone())
}
}

impl<Db: Database> RefUnwindSafe for Storage<Db> {}
Expand Down
4 changes: 4 additions & 0 deletions src/zalsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ pub unsafe trait ZalsaDatabase: Any {
/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &ZalsaLocal;

/// Clone the database.
#[doc(hidden)]
fn fork_db(&self) -> Box<dyn Database>;
}

pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {
Expand Down
14 changes: 8 additions & 6 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
#![allow(dead_code)]

use std::sync::{Arc, Mutex};

use salsa::{Database, Storage};

/// Logging userdata: provides [`LogDatabase`][] trait.
///
/// If you wish to use it along with other userdata,
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
#[derive(Default)]
#[derive(Clone, Default)]
pub struct Logger {
logs: std::sync::Mutex<Vec<String>>,
logs: Arc<Mutex<Vec<String>>>,
}

/// Trait implemented by databases that lets them log events.
Expand Down Expand Up @@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}

/// Database that provides logging but does not log salsa event.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct LoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -67,7 +69,7 @@ impl Database for LoggerDatabase {

/// Database that provides logging and logs salsa events.
#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct EventLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct DiscardLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand All @@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
}

#[salsa::db]
#[derive(Default)]
#[derive(Clone, Default)]
pub struct ExecuteValidateLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
Expand Down
1 change: 1 addition & 0 deletions tests/parallel/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recover;
mod parallel_map;
mod signal;
98 changes: 98 additions & 0 deletions tests/parallel/parallel_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// test for rayon interations.

use salsa::Cancelled;
use salsa::Setter;

use crate::setup::Knobs;
use crate::setup::KnobsDatabase;

#[salsa::input]
struct ParallelInput {
field: Vec<u32>,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec<u32> {
salsa::par_map(db, input.field(db), |_db, field| field + 1)
}

#[test]
fn execute() {
let db = salsa::DatabaseImpl::new();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);

tracked_fn(&db, input);
}

#[salsa::tracked]
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec<u32> {
db.signal(1);
salsa::par_map(db, input.field(db), |db, field| {
db.wait_for(2);
field + 1
})
}

#[salsa::tracked]
fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput {
panic!("should never get here!")
}

// we expect this to panic, as `salsa::par_map` needs to be called from a query.
#[test]
#[should_panic]
fn direct_calls_panic() {
let db = salsa::DatabaseImpl::new();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);
let _: Vec<u32> = salsa::par_map(&db, input.field(&db), |_db, field| field + 1);
}

// Cancellation signalling test
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1
// | wait for stage 1
// signal stage 1 set input, triggers cancellation
// wait for stage 2 (blocks) triggering cancellation sends stage 2
// |
// (unblocked)
// dummy
// panics

#[test]
fn execute_cancellation() {
let mut db = Knobs::default();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);

let thread_a = std::thread::spawn({
let db = db.clone();
move || a1(&db, input)
});

let counts = (2..=20).collect::<Vec<u32>>();

db.signal_on_did_cancel.store(2);
input.set_field(&mut db).to(counts);

// Assert thread A *should* was cancelled
let cancelled = thread_a
.join()
.unwrap_err()
.downcast::<Cancelled>()
.unwrap();

// and inspect the output
expect_test::expect![[r#"
PendingWrite
"#]]
.assert_debug_eq(&cancelled);
}
2 changes: 1 addition & 1 deletion tests/tracked_struct_durability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> {
#[test]
fn execute() {
#[salsa::db]
#[derive(Default)]
#[derive(Default, Clone)]
struct Database {
storage: salsa::Storage<Self>,
files: Vec<File>,
Expand Down

0 comments on commit 132ce16

Please sign in to comment.