Skip to content

Commit

Permalink
move local-state into DatabaseImpl
Browse files Browse the repository at this point in the history
Each clone gets an independent local state.
  • Loading branch information
nikomatsakis committed Jul 28, 2024
1 parent f3480bd commit 22a8157
Show file tree
Hide file tree
Showing 17 changed files with 449 additions and 466 deletions.
48 changes: 23 additions & 25 deletions src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
hash::FxDashMap,
ingredient::{fmt_index, Ingredient, Jar},
key::DependencyIndex,
local_state::{self, LocalState, QueryOrigin},
local_state::{LocalState, QueryOrigin},
storage::IngredientIndex,
Database, DatabaseKeyIndex, Event, EventKind, Id, Revision,
};
Expand Down Expand Up @@ -80,32 +80,30 @@ impl<A: Accumulator> IngredientImpl<A> {
}

pub fn push(&self, db: &dyn crate::Database, value: A) {
local_state::attach(db, |state| {
let current_revision = db.zalsa().current_revision();
let (active_query, _) = match state.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
}
};

let mut accumulated_values =
self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});

// When we call `push' in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query is out of date.
if !state.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
let state = db.zalsa_local();
let current_revision = db.zalsa().current_revision();
let (active_query, _) = match state.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
}
};

let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});

// When we call `push' in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query is out of date.
if !state.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
}

state.add_output(self.dependency_index());
accumulated_values.values.push(value);
})
state.add_output(self.dependency_index());
accumulated_values.values.push(value);
}

pub(crate) fn produced_by(
Expand Down
95 changes: 95 additions & 0 deletions src/attach.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::{cell::Cell, ptr::NonNull};

use crate::Database;

thread_local! {
/// The thread-local state salsa requires for a given thread
static ATTACHED: Attached = const { Attached::new() }
}

/// State that is specific to a single execution thread.
///
/// Internally, this type uses ref-cells.
///
/// **Note also that all mutations to the database handle (and hence
/// to the local-state) must be undone during unwinding.**
struct Attached {
/// Pointer to the currently attached database.
database: Cell<Option<NonNull<dyn Database>>>,
}

impl Attached {
const fn new() -> Self {
Self {
database: Cell::new(None),
}
}

fn attach<Db, R>(&self, db: &Db, op: impl FnOnce(&Db) -> R) -> R
where
Db: ?Sized + Database,
{
struct DbGuard<'s> {
state: Option<&'s Attached>,
}

impl<'s> DbGuard<'s> {
fn new(attached: &'s Attached, db: &dyn Database) -> Self {
if let Some(current_db) = attached.database.get() {
// Already attached? Assert that the database has not changed.
assert_eq!(
current_db,
NonNull::from(db),
"cannot change database mid-query",
);
Self { state: None }
} else {
// Otherwise, set the database.
attached.database.set(Some(NonNull::from(db)));
Self {
state: Some(attached),
}
}
}
}

impl Drop for DbGuard<'_> {
fn drop(&mut self) {
// Reset database to null if we did anything in `DbGuard::new`.
if let Some(attached) = self.state {
attached.database.set(None);
}
}
}

let _guard = DbGuard::new(self, db.as_dyn_database());
op(db)
}

/// Access the "attached" database. Returns `None` if no database is attached.
/// Databases are attached with `attach_database`.
fn with<R>(&self, op: impl FnOnce(&dyn Database) -> R) -> Option<R> {
if let Some(db) = self.database.get() {
// SAFETY: We always attach the database in for the entire duration of a function,
// so it cannot become "unattached" while this function is running.
Some(op(unsafe { db.as_ref() }))
} else {
None
}
}
}

/// Attach the database to the current thread and execute `op`.
/// Panics if a different database has already been attached.
pub(crate) fn attach<R, Db>(db: &Db, op: impl FnOnce(&Db) -> R) -> R
where
Db: ?Sized + Database,
{
ATTACHED.with(|a| a.attach(db, op))
}

/// Access the "attached" database. Returns `None` if no database is attached.
/// Databases are attached with `attach_database`.
pub fn with_attached_database<R>(op: impl FnOnce(&dyn Database) -> R) -> Option<R> {
ATTACHED.with(|a| a.with(op))
}
4 changes: 2 additions & 2 deletions src/cycle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{key::DatabaseKeyIndex, local_state, Database};
use crate::{key::DatabaseKeyIndex, Database};
use std::{panic::AssertUnwindSafe, sync::Arc};

/// Captures the participants of a cycle that occurred when executing a query.
Expand Down Expand Up @@ -74,7 +74,7 @@ impl Cycle {

impl std::fmt::Debug for Cycle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
local_state::with_attached_database(|db| {
crate::attach::with_attached_database(|db| {
f.debug_struct("UnexpectedCycle")
.field("all_participants", &self.all_participants(db))
.field("unexpected_participants", &self.unexpected_participants(db))
Expand Down
27 changes: 20 additions & 7 deletions src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::{any::Any, panic::RefUnwindSafe, sync::Arc};
use parking_lot::{Condvar, Mutex};

use crate::{
self as salsa, local_state,
self as salsa,
local_state::{self, LocalState},
storage::{Zalsa, ZalsaImpl},
Durability, Event, EventKind, Revision,
};
Expand All @@ -16,7 +17,7 @@ use crate::{
/// This trait can only safely be implemented by Salsa's [`DatabaseImpl`][] type.
/// FIXME: Document better the unsafety conditions we guarantee.
#[salsa_macros::db]
pub unsafe trait Database: AsDynDatabase + Any {
pub unsafe trait Database: Send + AsDynDatabase + Any {
/// This function is invoked by the salsa runtime at various points during execution.
/// You can customize what happens by implementing the [`UserData`][] trait.
/// By default, the event is logged at level debug using tracing facade.
Expand Down Expand Up @@ -45,17 +46,16 @@ pub unsafe trait Database: AsDynDatabase + Any {
/// revision.
fn report_untracked_read(&self) {
let db = self.as_dyn_database();
local_state::attach(db, |state| {
state.report_untracked_read(db.zalsa().current_revision())
})
let zalsa_local = db.zalsa_local();
zalsa_local.report_untracked_read(db.zalsa().current_revision())
}

/// Execute `op` with the database in thread-local storage for debug print-outs.
fn attach<R>(&self, op: impl FnOnce(&Self) -> R) -> R
where
Self: Sized,
{
local_state::attach(self, |_state| op(self))
crate::attach::attach(self, op)
}

/// Plumbing method: Access the internal salsa methods.
Expand All @@ -68,6 +68,10 @@ pub unsafe trait Database: AsDynDatabase + Any {
/// This can lead to deadlock!
#[doc(hidden)]
fn zalsa_mut(&mut self) -> &mut dyn Zalsa;

/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &LocalState;
}

/// Upcast to a `dyn Database`.
Expand Down Expand Up @@ -113,6 +117,9 @@ pub struct DatabaseImpl<U: UserData = ()> {
/// Coordination data for cancellation of other handles when `zalsa_mut` is called.
/// This could be stored in ZalsaImpl but it makes things marginally cleaner to keep it separate.
coordinate: Arc<Coordinate>,

/// Per-thread state
zalsa_local: local_state::LocalState,
}

impl<U: UserData + Default> Default for DatabaseImpl<U> {
Expand Down Expand Up @@ -141,6 +148,7 @@ impl<U: UserData> DatabaseImpl<U> {
clones: Mutex::new(1),
cvar: Default::default(),
}),
zalsa_local: LocalState::new(),
}
}

Expand Down Expand Up @@ -201,6 +209,10 @@ unsafe impl<U: UserData> Database for DatabaseImpl<U> {
zalsa_mut
}

fn zalsa_local(&self) -> &LocalState {
&self.zalsa_local
}

// Report a salsa event.
fn salsa_event(&self, event: &dyn Fn() -> Event) {
U::salsa_event(self, event)
Expand All @@ -214,6 +226,7 @@ impl<U: UserData> Clone for DatabaseImpl<U> {
Self {
zalsa_impl: self.zalsa_impl.clone(),
coordinate: Arc::clone(&self.coordinate),
zalsa_local: LocalState::new(),
}
}
}
Expand All @@ -229,7 +242,7 @@ impl<U: UserData> Drop for DatabaseImpl<U> {
}
}

pub trait UserData: Any + Sized {
pub trait UserData: Any + Sized + Send + Sync {
/// Callback invoked by the [`Database`][] at key points during salsa execution.
/// By overriding this method, you can inject logging or other custom behavior.
///
Expand Down
73 changes: 35 additions & 38 deletions src/function/accumulated.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{accumulator, hash::FxHashSet, local_state, Database, DatabaseKeyIndex, Id};
use crate::{accumulator, hash::FxHashSet, Database, DatabaseKeyIndex, Id};

use super::{Configuration, IngredientImpl};

Expand All @@ -12,44 +12,41 @@ where
where
A: accumulator::Accumulator,
{
local_state::attach(db, |local_state| {
let zalsa = db.zalsa();
let current_revision = zalsa.current_revision();

let Some(accumulator) = <accumulator::IngredientImpl<A>>::from_db(db) else {
return vec![];
};
let mut output = vec![];

// First ensure the result is up to date
self.fetch(db, key);

let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];

while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(current_revision, local_state, k, &mut output);

let origin = zalsa
.lookup_ingredient(k.ingredient_index)
.origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| {
TryInto::<DatabaseKeyIndex>::try_into(input).into_iter()
})
.rev(),
);
}
let zalsa = db.zalsa();
let zalsa_local = db.zalsa_local();
let current_revision = zalsa.current_revision();

let Some(accumulator) = <accumulator::IngredientImpl<A>>::from_db(db) else {
return vec![];
};
let mut output = vec![];

// First ensure the result is up to date
self.fetch(db, key);

let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];

while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(current_revision, zalsa_local, k, &mut output);

let origin = zalsa
.lookup_ingredient(k.ingredient_index)
.origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
.rev(),
);
}
}

output
})
output
}
}
Loading

0 comments on commit 22a8157

Please sign in to comment.