-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
af2ec49
commit 132ce16
Showing
13 changed files
with
188 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
use std::ops::Deref; | ||
|
||
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; | ||
|
||
use crate::Database; | ||
|
||
pub fn par_map<Db, D, E, C>( | ||
db: &Db, | ||
inputs: impl IntoParallelIterator<Item = D>, | ||
op: fn(&Db, D) -> E, | ||
) -> C | ||
where | ||
Db: Database + ?Sized, | ||
D: Send, | ||
E: Send + Sync, | ||
C: FromParallelIterator<E>, | ||
{ | ||
let parallel_db = ParallelDb::Ref(db.as_dyn_database()); | ||
|
||
inputs | ||
.into_par_iter() | ||
.map_with(parallel_db, |parallel_db, element| { | ||
let db = parallel_db.as_view::<Db>(); | ||
op(db, element) | ||
}) | ||
.collect() | ||
} | ||
|
||
/// This enum _must not_ be public or used outside of `par_map`. | ||
enum ParallelDb<'db> { | ||
Ref(&'db dyn Database), | ||
Fork(Box<dyn Database + Send>), | ||
} | ||
|
||
/// SAFETY: the contents of the database are never accessed on the thread | ||
/// where this wrapper type is created. | ||
unsafe impl<'db> Send for ParallelDb<'db> {} | ||
|
||
impl<'db> Deref for ParallelDb<'db> { | ||
type Target = dyn Database; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
match self { | ||
ParallelDb::Ref(db) => *db, | ||
ParallelDb::Fork(db) => db.as_dyn_database(), | ||
} | ||
} | ||
} | ||
|
||
impl Clone for ParallelDb<'_> { | ||
fn clone(&self) -> Self { | ||
ParallelDb::Fork(self.fork_db()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// test for rayon interations. | ||
|
||
use salsa::Cancelled; | ||
use salsa::Setter; | ||
|
||
use crate::setup::Knobs; | ||
use crate::setup::KnobsDatabase; | ||
|
||
#[salsa::input] | ||
struct ParallelInput { | ||
field: Vec<u32>, | ||
} | ||
|
||
#[salsa::tracked] | ||
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec<u32> { | ||
salsa::par_map(db, input.field(db), |_db, field| field + 1) | ||
} | ||
|
||
#[test] | ||
fn execute() { | ||
let db = salsa::DatabaseImpl::new(); | ||
|
||
let counts = (1..=10).collect::<Vec<u32>>(); | ||
let input = ParallelInput::new(&db, counts); | ||
|
||
tracked_fn(&db, input); | ||
} | ||
|
||
#[salsa::tracked] | ||
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec<u32> { | ||
db.signal(1); | ||
salsa::par_map(db, input.field(db), |db, field| { | ||
db.wait_for(2); | ||
field + 1 | ||
}) | ||
} | ||
|
||
#[salsa::tracked] | ||
fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput { | ||
panic!("should never get here!") | ||
} | ||
|
||
// we expect this to panic, as `salsa::par_map` needs to be called from a query. | ||
#[test] | ||
#[should_panic] | ||
fn direct_calls_panic() { | ||
let db = salsa::DatabaseImpl::new(); | ||
|
||
let counts = (1..=10).collect::<Vec<u32>>(); | ||
let input = ParallelInput::new(&db, counts); | ||
let _: Vec<u32> = salsa::par_map(&db, input.field(&db), |_db, field| field + 1); | ||
} | ||
|
||
// Cancellation signalling test | ||
// | ||
// The pattern is as follows. | ||
// | ||
// Thread A Thread B | ||
// -------- -------- | ||
// a1 | ||
// | wait for stage 1 | ||
// signal stage 1 set input, triggers cancellation | ||
// wait for stage 2 (blocks) triggering cancellation sends stage 2 | ||
// | | ||
// (unblocked) | ||
// dummy | ||
// panics | ||
|
||
#[test] | ||
fn execute_cancellation() { | ||
let mut db = Knobs::default(); | ||
|
||
let counts = (1..=10).collect::<Vec<u32>>(); | ||
let input = ParallelInput::new(&db, counts); | ||
|
||
let thread_a = std::thread::spawn({ | ||
let db = db.clone(); | ||
move || a1(&db, input) | ||
}); | ||
|
||
let counts = (2..=20).collect::<Vec<u32>>(); | ||
|
||
db.signal_on_did_cancel.store(2); | ||
input.set_field(&mut db).to(counts); | ||
|
||
// Assert thread A *should* was cancelled | ||
let cancelled = thread_a | ||
.join() | ||
.unwrap_err() | ||
.downcast::<Cancelled>() | ||
.unwrap(); | ||
|
||
// and inspect the output | ||
expect_test::expect![[r#" | ||
PendingWrite | ||
"#]] | ||
.assert_debug_eq(&cancelled); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters