From 8f2cccda4fd426d743a183f8b361a466c5cc304d Mon Sep 17 00:00:00 2001 From: Yoav Gross Date: Sun, 10 Nov 2024 15:38:08 +0200 Subject: [PATCH] feat(blockifier): implement the allocated keys logic --- crates/blockifier/src/state/cached_state.rs | 27 ++++++++++--- .../blockifier/src/state/cached_state_test.rs | 40 ++++++++++++++++++- .../src/transaction/post_execution_test.rs | 13 +++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index ac26d64bd3a..cd9a2189409 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -684,9 +684,16 @@ impl StateChangesKeys { pub struct AllocatedKeys(HashSet); impl AllocatedKeys { + /// Extend the set of allocated keys with the allocated_keys of the given state changes. + /// Remove storage keys that are set back to zero. pub fn update(&mut self, state_change: &StateChanges) { self.0.extend(&state_change.allocated_keys.0); - // TODO: Remove keys that are set back to zero. + // Remove keys that are set back to zero. + state_change.state_maps.storage.iter().for_each(|(k, v)| { + if v == &Felt::ZERO { + self.0.remove(k); + } + }); } pub fn len(&self) -> usize { @@ -699,12 +706,22 @@ impl AllocatedKeys { /// Collect entries that turn zero -> nonzero. pub fn from_storage_diff( - _updated_storage: &HashMap, - _base_storage: &HashMap, + updated_storage: &HashMap, + base_storage: &HashMap, ) -> Self { Self( - HashSet::new(), - // TODO: Calculate the difference between the updated_storage and the base_storage. + updated_storage + .iter() + .filter_map(|(k, v)| { + let base_value = base_storage.get(k); + if *v != Felt::ZERO && (base_value.is_none() || base_value == Some(&Felt::ZERO)) + { + Some(*k) + } else { + None + } + }) + .collect(), ) } } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index b23a189261d..e2117503e7b 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -338,15 +338,16 @@ fn test_from_state_changes_for_fee_charge( sender_address, fee_token_address, ); + let n_expected_storage_update = 1 + usize::from(sender_address.is_some()); let expected_state_changes_count = StateChangesCountForFee { // 1 for storage update + 1 for sender balance update if sender is defined. state_changes_count: StateChangesCount { - n_storage_updates: 1 + usize::from(sender_address.is_some()), + n_storage_updates: n_expected_storage_update, n_class_hash_updates: 1, n_compiled_class_hash_updates: 1, n_modified_contracts: 2, }, - n_allocated_keys: 0, + n_allocated_keys: if enable_stateful_compression { n_expected_storage_update } else { 0 }, }; assert_eq!(state_changes_count, expected_state_changes_count); } @@ -418,6 +419,41 @@ fn test_state_changes_merge( ); } +/// Test that `allocated_keys` collects zero -> nonzero updates. +#[rstest] +#[case(false, vec![felt!("0x0")], false)] +#[case(true, vec![felt!("0x7")], true)] +#[case(false, vec![felt!("0x7")], false)] +#[case(true, vec![felt!("0x7"), felt!("0x0")], false)] +#[case(false, vec![felt!("0x0"), felt!("0x8")], true)] +#[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)] +fn test_allocated_keys_update( + #[case] is_base_empty: bool, + #[case] storage_updates: Vec, + #[case] charged: bool, +) { + let contract_address = contract_address!(CONTRACT_ADDRESS); + let storage_key = StorageKey::from(0x10_u16); + // Set initial state + let mut state: CachedState = CachedState::default(); + if !is_base_empty { + state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + } + let mut state_changes = vec![]; + + for value in storage_updates { + // In the end of the previous loop, state has moved into the transactional state. + let mut transactional_state = TransactionalState::create_transactional(&mut state); + // Update state and collect the state changes. + transactional_state.set_storage_at(contract_address, storage_key, value).unwrap(); + state_changes.push(transactional_state.get_actual_state_changes().unwrap()); + transactional_state.commit(); + } + + let merged_changes = StateChanges::merge(state_changes); + assert_ne!(merged_changes.allocated_keys.is_empty(), charged); +} + #[test] fn test_contract_cache_is_used() { // Initialize the global cache with a single class, and initialize an empty state with this diff --git a/crates/blockifier/src/transaction/post_execution_test.rs b/crates/blockifier/src/transaction/post_execution_test.rs index 98956e08131..3ac147f0a36 100644 --- a/crates/blockifier/src/transaction/post_execution_test.rs +++ b/crates/blockifier/src/transaction/post_execution_test.rs @@ -289,7 +289,7 @@ fn test_revert_on_resource_overuse( // We need this kind of invocation, to be able to test the specific scenario: the resource // bounds must be enough to allow completion of the transaction, and yet must still fail // post-execution bounds check. - let execution_info_measure = run_invoke_tx( + let mut execution_info_measure = run_invoke_tx( &mut state, &block_context, invoke_tx_args! { @@ -337,6 +337,17 @@ fn test_revert_on_resource_overuse( .unwrap(); assert_eq!(execution_info_tight.revert_error, None); assert_eq!(execution_info_tight.receipt.fee, actual_fee); + // The only difference between the two executions should be the number of allocated keys, as the + // second execution writes to the same keys as the first. + let n_allocated_keys = &mut execution_info_measure + .receipt + .resources + .starknet_resources + .state + .state_changes_for_fee + .n_allocated_keys; + assert_eq!(n_allocated_keys, &usize::from(n_writes)); + *n_allocated_keys = 0; assert_eq!(execution_info_tight.receipt.resources, execution_info_measure.receipt.resources); // Re-run the same function with max bounds slightly below the actual usage, and verify it's