diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 5a46ce7866..ac26d64bd3 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -401,7 +401,9 @@ impl StateCache { /// reads. Assumes (and enforces) all initial reads are cached. pub fn to_state_diff(&self) -> StateChanges { let state_maps = self.writes.diff(&self.initial_reads); - StateChanges { state_maps } + let allocated_keys = + AllocatedKeys::from_storage_diff(&self.writes.storage, &self.initial_reads.storage); + StateChanges { state_maps, allocated_keys } } fn declare_contract(&mut self, class_hash: ClassHash) { @@ -677,11 +679,42 @@ impl StateChangesKeys { } } +#[cfg_attr(any(feature = "testing", test), derive(Clone))] +#[derive(Debug, Default, Eq, PartialEq)] +pub struct AllocatedKeys(HashSet); + +impl AllocatedKeys { + pub fn update(&mut self, state_change: &StateChanges) { + self.0.extend(&state_change.allocated_keys.0); + // TODO: Remove keys that are set back to zero. + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Collect entries that turn zero -> nonzero. + pub fn from_storage_diff( + _updated_storage: &HashMap, + _base_storage: &HashMap, + ) -> Self { + Self( + HashSet::new(), + // TODO: Calculate the difference between the updated_storage and the base_storage. + ) + } +} + /// Holds the state changes. #[cfg_attr(any(feature = "testing", test), derive(Clone))] #[derive(Debug, Default, Eq, PartialEq)] pub struct StateChanges { pub state_maps: StateMaps, + pub allocated_keys: AllocatedKeys, } impl StateChanges { @@ -691,8 +724,8 @@ impl StateChanges { let mut merged_state_changes = Self::default(); for state_change in state_changes { merged_state_changes.state_maps.extend(&state_change.state_maps); + merged_state_changes.allocated_keys.update(&state_change); } - merged_state_changes } @@ -727,8 +760,7 @@ impl StateChanges { n_compiled_class_hash_updates: self.state_maps.compiled_class_hashes.len(), n_modified_contracts: modified_contracts.len(), }, - // TODO: Set number of allocated keys. - n_allocated_keys: { 0 }, + n_allocated_keys: self.allocated_keys.len(), } } }