Skip to content

Commit

Permalink
Merge pull request #591 from puuuuh/typeid_mismatch_fix
Browse files Browse the repository at this point in the history
Add IngredientIndex to KeyStruct
  • Loading branch information
nikomatsakis authored Oct 17, 2024
2 parents c6c51a0 + 0744fd8 commit 82f2a7d
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 39 deletions.
16 changes: 8 additions & 8 deletions src/active_query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use rustc_hash::FxHashMap;

use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};
use crate::tracked_struct::IdentityHash;
use crate::{
accumulator::accumulated_map::AccumulatedMap,
durability::Durability,
hash::FxIndexSet,
key::{DatabaseKeyIndex, DependencyIndex},
tracked_struct::{Disambiguator, KeyStruct},
tracked_struct::{Disambiguator, Identity},
zalsa_local::EMPTY_DEPENDENCIES,
Cycle, Revision,
Cycle, Id, Revision,
};

use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};

#[derive(Debug)]
pub(crate) struct ActiveQuery {
/// What query is executing
Expand Down Expand Up @@ -45,11 +45,11 @@ pub(crate) struct ActiveQuery {
/// This table starts empty as the query begins and is gradually populated.
/// Note that if a query executes in 2 different revisions but creates the same
/// set of tracked structs, they will get the same disambiguator values.
disambiguator_map: FxHashMap<u64, Disambiguator>,
disambiguator_map: FxHashMap<IdentityHash, Disambiguator>,

/// Map from tracked struct keys (which include the hash + disambiguator) to their
/// final id.
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,
pub(crate) tracked_struct_ids: FxHashMap<Identity, Id>,

/// Stores the values accumulated to the given ingredient.
/// The type of accumulated value is erased but known to the ingredient.
Expand Down Expand Up @@ -155,10 +155,10 @@ impl ActiveQuery {
self.input_outputs.clone_from(&cycle_query.input_outputs);
}

pub(super) fn disambiguate(&mut self, hash: u64) -> Disambiguator {
pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator {
let disambiguator = self
.disambiguator_map
.entry(hash)
.entry(key)
.or_insert(Disambiguator(0));
let result = *disambiguator;
disambiguator.0 += 1;
Expand Down
6 changes: 3 additions & 3 deletions src/function/diff_outputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ where
if !old_outputs.is_empty() {
// Remove the outputs that are no longer present in the current revision
// to prevent that the next revision is seeded with a id mapping that no longer exists.
revisions.tracked_struct_ids.retain(|_k, value| {
revisions.tracked_struct_ids.retain(|k, value| {
!old_outputs.contains(&DependencyIndex {
ingredient_index: value.ingredient_index,
key_index: Some(value.key_index),
ingredient_index: k.ingredient_index(),
key_index: Some(*value),
})
});
}
Expand Down
45 changes: 34 additions & 11 deletions src/tracked_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,35 @@ where
/// stored in the [`ActiveQuery`](`crate::active_query::ActiveQuery`)
/// struct and later moved to the [`Memo`](`crate::function::memo::Memo`).
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
pub(crate) struct KeyStruct {
/// The hash of the `#[id]` fields of this struct.
/// Note that multiple structs may share the same hash.
data_hash: u64,
pub(crate) struct Identity {
/// Hash of fields with id attribute
identity_hash: IdentityHash,

/// The unique disambiguator assigned within the active query
/// to distinguish distinct tracked structs with the same hash.
/// to distinguish distinct tracked structs with the same identity_hash.
disambiguator: Disambiguator,
}

impl Identity {
pub(crate) fn ingredient_index(&self) -> IngredientIndex {
self.identity_hash.ingredient_index
}
}

/// Stores the data that (almost) uniquely identifies a tracked struct.
/// This includes the ingredient index of that struct type plus the hash of its id fields.
/// This is mapped to a disambiguator -- a value that starts as 0 but increments each round,
/// allowing for multiple tracked structs with the same hash and ingredient_index
/// created within the query to each have a unique id.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
pub struct IdentityHash {
/// Index of the tracked struct ingredient.
ingredient_index: IngredientIndex,

/// Hash of the id fields.
hash: u64,
}

// ANCHOR: ValueStruct
#[derive(Debug)]
pub struct Value<C>
Expand Down Expand Up @@ -255,17 +274,21 @@ where
) -> C::Struct<'db> {
let (zalsa, zalsa_local) = db.zalsas();

let data_hash = crate::hash::hash(&C::id_fields(&fields));
let identity_hash = IdentityHash {
ingredient_index: self.ingredient_index,
hash: crate::hash::hash(&C::id_fields(&fields)),
};

let (current_deps, disambiguator) = zalsa_local.disambiguate(identity_hash);

let (current_deps, disambiguator) = zalsa_local.disambiguate(data_hash);
let identity = Identity {
identity_hash,

let key_struct = KeyStruct {
disambiguator,
data_hash,
};

let current_revision = zalsa.current_revision();
match zalsa_local.tracked_struct_id(&key_struct) {
match zalsa_local.tracked_struct_id(&identity) {
Some(id) => {
// The struct already exists in the intern map.
zalsa_local.add_output(self.database_key_index(id).into());
Expand All @@ -278,7 +301,7 @@ where
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
let key = self.database_key_index(id);
zalsa_local.add_output(key.into());
zalsa_local.store_tracked_struct_id(key_struct, key);
zalsa_local.store_tracked_struct_id(identity, id);
C::struct_from_id(id)
}
}
Expand Down
28 changes: 11 additions & 17 deletions src/zalsa_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use crate::runtime::StampedValue;
use crate::table::PageIndex;
use crate::table::Slot;
use crate::table::Table;
use crate::tracked_struct::Disambiguator;
use crate::tracked_struct::KeyStruct;
use crate::tracked_struct::{Disambiguator, Identity, IdentityHash};
use crate::zalsa::IngredientIndex;
use crate::Accumulator;
use crate::Cancelled;
Expand Down Expand Up @@ -262,15 +261,15 @@ impl ZalsaLocal {
/// * the current dependencies (durability, changed_at) of current query
/// * the disambiguator index
#[track_caller]
pub(crate) fn disambiguate(&self, data_hash: u64) -> (StampedValue<()>, Disambiguator) {
pub(crate) fn disambiguate(&self, key: IdentityHash) -> (StampedValue<()>, Disambiguator) {
assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
);

self.with_query_stack(|stack| {
let top_query = stack.last_mut().unwrap();
let disambiguator = top_query.disambiguate(data_hash);
let disambiguator = top_query.disambiguate(key);
(
StampedValue {
value: (),
Expand All @@ -283,32 +282,30 @@ impl ZalsaLocal {
}

#[track_caller]
pub(crate) fn tracked_struct_id(&self, key_struct: &KeyStruct) -> Option<Id> {
pub(crate) fn tracked_struct_id(&self, identity: &Identity) -> Option<Id> {
debug_assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
);

self.with_query_stack(|stack| {
let top_query = stack.last().unwrap();
top_query
.tracked_struct_ids
.get(key_struct)
.map(|index| index.key_index())
top_query.tracked_struct_ids.get(identity).copied()
})
}

#[track_caller]
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: DatabaseKeyIndex) {
pub(crate) fn store_tracked_struct_id(&self, identity: Identity, id: Id) {
debug_assert!(
self.query_in_progress(),
"cannot create a tracked struct disambiguator outside of a tracked function"
);
self.with_query_stack(|stack| {
let top_query = stack.last_mut().unwrap();
let old_id = top_query.tracked_struct_ids.insert(key_struct, id);
let old_id = top_query.tracked_struct_ids.insert(identity, id);
assert!(
old_id.is_none(),
"overwrote a previous id for `{key_struct:?}`"
"overwrote a previous id for `{identity:?}`"
);
})
}
Expand Down Expand Up @@ -377,7 +374,7 @@ pub(crate) struct QueryRevisions {
/// previous revision. To handle this, `diff_outputs` compares
/// the structs from the old/new revision and retains
/// only entries that appeared in the new revision.
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,
pub(super) tracked_struct_ids: FxHashMap<Identity, Id>,

pub(super) accumulated: AccumulatedMap,
}
Expand Down Expand Up @@ -536,10 +533,7 @@ impl ActiveQueryGuard<'_> {
}

/// Initialize the tracked struct ids with the values from the prior execution.
pub(crate) fn seed_tracked_struct_ids(
&self,
tracked_struct_ids: &FxHashMap<KeyStruct, DatabaseKeyIndex>,
) {
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &FxHashMap<Identity, Id>) {
self.local_state.with_query_stack(|stack| {
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
Expand Down
32 changes: 32 additions & 0 deletions tests/hash_collision.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::hash::Hash;

#[test]
fn hello() {
use salsa::{Database, DatabaseImpl, Setter};

#[salsa::input]
struct Bool {
value: bool,
}

#[salsa::tracked]
struct True<'db> {}

#[salsa::tracked]
struct False<'db> {}

#[salsa::tracked]
fn hello(db: &dyn Database, bool: Bool) {
if bool.value(db) {
True::new(db);
} else {
False::new(db);
}
}

let mut db = DatabaseImpl::new();
let input = Bool::new(&db, false);
hello(&db, input);
input.set_value(&mut db).to(true);
hello(&db, input);
}

0 comments on commit 82f2a7d

Please sign in to comment.