Skip to content

Commit

Permalink
Merge pull request #573 from dhruvmanila/dhruv/recreate-panic
Browse files Browse the repository at this point in the history
Fix panic when recreating tracked struct that was deleted in previous revision
  • Loading branch information
nikomatsakis authored Sep 22, 2024
2 parents 198c43f + 8094e0c commit 4a7c955
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/active_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
key::{DatabaseKeyIndex, DependencyIndex},
tracked_struct::{Disambiguator, KeyStruct},
zalsa_local::EMPTY_DEPENDENCIES,
Cycle, Id, Revision,
Cycle, Revision,
};

use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};
Expand Down Expand Up @@ -49,7 +49,7 @@ pub(crate) struct ActiveQuery {

/// Map from tracked struct keys (which include the hash + disambiguator) to their
/// final id.
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,

/// Stores the values accumulated to the given ingredient.
/// The type of accumulated value is erased but known to the ingredient.
Expand Down
28 changes: 18 additions & 10 deletions src/function/diff_outputs.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,43 @@
use super::{memo::Memo, Configuration, IngredientImpl};
use crate::{
hash::FxHashSet, key::DependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _,
DatabaseKeyIndex, Event, EventKind,
};

use super::{memo::Memo, Configuration, IngredientImpl};

impl<C> IngredientImpl<C>
where
C: Configuration,
{
/// Compute the old and new outputs and invoke the `clear_stale_output` callback
/// for each output that was generated before but is not generated now.
///
/// This function takes a `&mut` reference to `revisions` to remove outputs
/// that no longer exist in this revision from [`QueryRevisions::tracked_struct_ids`].
pub(super) fn diff_outputs(
&self,
db: &C::DbView,
key: DatabaseKeyIndex,
old_memo: &Memo<C::Output<'_>>,
revisions: &QueryRevisions,
revisions: &mut QueryRevisions,
) {
// Iterate over the outputs of the `old_memo` and put them into a hashset
let mut old_outputs = FxHashSet::default();
old_memo.revisions.origin.outputs().for_each(|i| {
old_outputs.insert(i);
});
let mut old_outputs: FxHashSet<_> = old_memo.revisions.origin.outputs().collect();

// Iterate over the outputs of the current query
// and remove elements from `old_outputs` when we find them
for new_output in revisions.origin.outputs() {
if old_outputs.contains(&new_output) {
old_outputs.remove(&new_output);
}
old_outputs.remove(&new_output);
}

if !old_outputs.is_empty() {
// Remove the outputs that are no longer present in the current revision
// to prevent that the next revision is seeded with a id mapping that no longer exists.
revisions.tracked_struct_ids.retain(|_k, value| {
!old_outputs.contains(&DependencyIndex {
ingredient_index: value.ingredient_index,
key_index: Some(value.key_index),
})
});
}

for old_output in old_outputs {
Expand Down
2 changes: 1 addition & 1 deletion src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where
// old value.
if let Some(old_memo) = &opt_old_memo {
self.backdate_if_appropriate(old_memo, &mut revisions, &value);
self.diff_outputs(db, database_key_index, old_memo, &revisions);
self.diff_outputs(db, database_key_index, old_memo, &mut revisions);
}

tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}");
Expand Down
2 changes: 1 addition & 1 deletion src/function/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ impl<C> IngredientImpl<C>
where
C: Configuration,
{
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &C::Output<'db> {
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> {
let (zalsa, zalsa_local) = db.zalsas();
zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database());

Expand Down
2 changes: 1 addition & 1 deletion src/function/specify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where

if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) {
self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
self.diff_outputs(db, database_key_index, &old_memo, &revisions);
self.diff_outputs(db, database_key_index, &old_memo, &mut revisions);
}

let memo = Memo {
Expand Down
5 changes: 3 additions & 2 deletions src/tracked_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,9 @@ where
None => {
// This is a new tracked struct, so create an entry in the struct map.
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
zalsa_local.add_output(self.database_key_index(id).into());
zalsa_local.store_tracked_struct_id(key_struct, id);
let key = self.database_key_index(id);
zalsa_local.add_output(key.into());
zalsa_local.store_tracked_struct_id(key_struct, key);
C::struct_from_id(id)
}
}
Expand Down
32 changes: 26 additions & 6 deletions src/zalsa_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,15 @@ impl ZalsaLocal {
);
self.with_query_stack(|stack| {
let top_query = stack.last().unwrap();
top_query.tracked_struct_ids.get(key_struct).cloned()
top_query
.tracked_struct_ids
.get(key_struct)
.map(|index| index.key_index())
})
}

#[track_caller]
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: Id) {
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: DatabaseKeyIndex) {
debug_assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
Expand Down Expand Up @@ -358,9 +361,23 @@ pub(crate) struct QueryRevisions {
pub(crate) origin: QueryOrigin,

/// The ids of tracked structs created by this query.
/// This is used to seed the next round if the query is
/// re-executed.
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
///
/// This table plays an important role when queries are
/// re-executed:
/// * A clone of this field is used as the initial set of
/// `TrackedStructId`s for the query on the next execution.
/// * The query will thus re-use the same ids if it creates
/// tracked structs with the same `KeyStruct` as before.
/// It may also create new tracked structs.
/// * One tricky case involves deleted structs. If
/// the old revision created a struct S but the new
/// revision did not, there will still be a map entry
/// for S. This is because queries only ever grow the map
/// and they start with the same entries as from the
/// previous revision. To handle this, `diff_outputs` compares
/// the structs from the old/new revision and retains
/// only entries that appeared in the new revision.
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,

pub(super) accumulated: AccumulatedMap,
}
Expand Down Expand Up @@ -519,7 +536,10 @@ impl ActiveQueryGuard<'_> {
}

/// Initialize the tracked struct ids with the values from the prior execution.
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &FxHashMap<KeyStruct, Id>) {
pub(crate) fn seed_tracked_struct_ids(
&self,
tracked_struct_ids: &FxHashMap<KeyStruct, DatabaseKeyIndex>,
) {
self.local_state.with_query_stack(|stack| {
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
Expand Down
35 changes: 35 additions & 0 deletions tests/tracked_struct_recreate_new_revision.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//! Test that re-creating a `tracked` struct after it was deleted in a previous
//! revision doesn't panic.
#![allow(warnings)]

use salsa::Setter;

#[salsa::input]
struct MyInput {
field: u32,
}

#[salsa::tracked]
struct TrackedStruct<'db> {
field: u32,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> Option<TrackedStruct<'_>> {
if input.field(db) == 1 {
Some(TrackedStruct::new(db, 1))
} else {
None
}
}

#[test]
fn execute() {
let mut db = salsa::DatabaseImpl::new();
let input = MyInput::new(&db, 1);
assert!(tracked_fn(&db, input).is_some());
input.set_field(&mut db).to(0);
assert_eq!(tracked_fn(&db, input), None);
input.set_field(&mut db).to(1);
assert!(tracked_fn(&db, input).is_some());
}

0 comments on commit 4a7c955

Please sign in to comment.