From dc573fee4b1f18b2b542fe5aa2bee381582314de Mon Sep 17 00:00:00 2001 From: Yoav Gross Date: Wed, 27 Nov 2024 15:20:53 +0200 Subject: [PATCH] fix(blockifier): merge state diff with squash --- crates/blockifier/src/state/cached_state.rs | 70 +++++++---- .../blockifier/src/state/cached_state_test.rs | 109 ++++++++++-------- .../src/transaction/account_transaction.rs | 17 +-- 3 files changed, 118 insertions(+), 78 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 356b7b8f89..75b524e1a9 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -1,4 +1,4 @@ -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::collections::{HashMap, HashSet}; use indexmap::IndexMap; @@ -58,6 +58,11 @@ impl CachedState { self.to_state_diff() } + pub fn borrow_updated_state_cache(&mut self) -> StateResult> { + self.update_initial_values_of_write_only_access()?; + Ok(self.cache.borrow()) + } + pub fn update_cache( &mut self, write_updates: &StateMaps, @@ -383,7 +388,7 @@ impl StateMaps { /// The tracked changes are needed for block state commitment. // Invariant: keys cannot be deleted from fields (only used internally by the cached state). -#[derive(Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct StateCache { // Reader's cached information; initial values, read before any write operation (per cell). pub(crate) initial_reads: StateMaps, @@ -402,6 +407,44 @@ impl StateCache { StateChanges { state_maps, allocated_keys } } + /// Squashes the given state caches into a single one and returns the state diff. Note that the + /// order of the state caches is important. + pub fn squash_state_caches(state_caches: Vec<&Self>) -> Self { + let mut squashed_state_cache = StateCache::default(); + + // Gives priority to early initial reads. + state_caches.iter().rev().for_each(|state_cache| { + squashed_state_cache.initial_reads.extend(&state_cache.initial_reads) + }); + // Gives priority to late writes. + state_caches + .iter() + .for_each(|state_cache| squashed_state_cache.writes.extend(&state_cache.writes)); + squashed_state_cache + } + + /// Squashes the given state caches into a single one and returns the state diff. Note that the + /// order of the state caches is important. + /// If 'comprehensive_state_diff' is false, opposite updates may not be canceled out. Used for + /// backward compatibility. + pub fn squash_state_diff_backward_compatible( + state_caches: Vec<&Self>, + comprehensive_state_diff: bool, + ) -> StateChanges { + if comprehensive_state_diff { + return Self::squash_state_caches(state_caches).to_state_diff(); + } + + // Backward compatibility. + let mut merged_state_changes = StateChanges::default(); + for state_cache in state_caches { + let state_change = state_cache.to_state_diff(); + merged_state_changes.state_maps.extend(&state_change.state_maps); + merged_state_changes.allocated_keys.0.extend(&state_change.allocated_keys.0); + } + merged_state_changes + } + fn declare_contract(&mut self, class_hash: ClassHash) { self.writes.declared_contracts.insert(class_hash, true); } @@ -680,18 +723,6 @@ impl StateChangesKeys { pub struct AllocatedKeys(HashSet); impl AllocatedKeys { - /// Extends the set of allocated keys with the allocated_keys of the given state changes. - /// Removes storage keys that are set back to zero. - pub fn update(&mut self, state_change: &StateChanges) { - self.0.extend(&state_change.allocated_keys.0); - // Remove keys that are set back to zero. - state_change.state_maps.storage.iter().for_each(|(k, v)| { - if v == &Felt::ZERO { - self.0.remove(k); - } - }); - } - pub fn len(&self) -> usize { self.0.len() } @@ -726,17 +757,6 @@ pub struct StateChanges { } impl StateChanges { - /// Merges the given state changes into a single one. Note that the order of the state changes - /// is important. The state changes are merged in the order they appear in the given vector. - pub fn merge(state_changes: Vec) -> Self { - let mut merged_state_changes = Self::default(); - for state_change in state_changes { - merged_state_changes.state_maps.extend(&state_change.state_maps); - merged_state_changes.allocated_keys.update(&state_change); - } - merged_state_changes - } - pub fn count_for_fee_charge( &self, sender_address: Option, diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 20f191a99c..16e8e62430 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -289,11 +289,11 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().state_maps.into()); } -fn create_state_changes_for_test( +fn create_state_cache_for_test( state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, -) -> StateChanges { +) -> StateCache { let contract_address = contract_address!(CONTRACT_ADDRESS); let contract_address2 = contract_address!("0x101"); let class_hash = class_hash!("0x10"); @@ -323,7 +323,7 @@ fn create_state_changes_for_test( let sender_balance_key = get_fee_token_var_address(sender_address); state.set_storage_at(fee_token_address, sender_balance_key, felt!("0x1999")).unwrap(); } - state.get_actual_state_changes().unwrap() + state.borrow_updated_state_cache().unwrap().clone() } #[rstest] @@ -333,7 +333,7 @@ fn test_from_state_changes_for_fee_charge( let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = - create_state_changes_for_test(&mut state, sender_address, fee_token_address); + create_state_cache_for_test(&mut state, sender_address, fee_token_address).to_state_diff(); let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address); let n_expected_storage_updates = 1 + usize::from(sender_address.is_some()); let expected_state_changes_count = StateChangesCountForFee { @@ -350,37 +350,32 @@ fn test_from_state_changes_for_fee_charge( } #[rstest] -fn test_state_changes_merge( +fn test_state_cache_merge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the - // state changes and then commit. + // state cache and then commit. let mut state: CachedState = CachedState::default(); let mut transactional_state = TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; - let state_changes1 = - create_state_changes_for_test(&mut transactional_state, sender_address, fee_token_address); + let state_cache1 = + create_state_cache_for_test(&mut transactional_state, sender_address, fee_token_address); transactional_state.commit(); // After performing `commit`, the transactional state is moved (into state). We need to create // a new transactional state that wraps `state` to continue. let mut transactional_state = TransactionalState::create_transactional(&mut state); - // Make sure that `get_actual_state_changes` on a newly created transactional state returns null - // state changes and that merging null state changes with non-null state changes results in the - // non-null state changes, no matter the order. - let state_changes2 = transactional_state.get_actual_state_changes().unwrap(); - assert_eq!(state_changes2, StateChanges::default()); - assert_eq!( - StateChanges::merge(vec![state_changes1.clone(), state_changes2.clone()]), - state_changes1 - ); - assert_eq!( - StateChanges::merge(vec![state_changes2.clone(), state_changes1.clone()]), - state_changes1 - ); - - // Get the storage updates addresses and keys from the state_changes1, to overwrite. + // Make sure that the state_changes of a newly created transactional state returns null + // state cache and that merging null state cache with non-null state cache results in the + // non-null state cache, no matter the order. + let state_cache2 = transactional_state.borrow_updated_state_cache().unwrap().clone(); + assert_eq!(state_cache2, StateCache::default()); + assert_eq!(StateCache::squash_state_caches(vec![&state_cache1, &state_cache2]), state_cache1); + assert_eq!(StateCache::squash_state_caches(vec![&state_cache2, &state_cache1]), state_cache1); + + // Get the storage updates addresses and keys from the state_cache1, to overwrite. + let state_changes1 = state_cache1.to_state_diff(); let mut storage_updates_keys = state_changes1.state_maps.storage.keys(); let &(contract_address, storage_key) = storage_updates_keys .find(|(contract_address, _)| contract_address == &contract_address!(CONTRACT_ADDRESS)) @@ -394,8 +389,8 @@ fn test_state_changes_merge( .set_storage_at(new_contract_address, storage_key, felt!("0x43210")) .unwrap(); transactional_state.increment_nonce(contract_address).unwrap(); - // Get the new state changes and then commit the transactional state. - let state_changes3 = transactional_state.get_actual_state_changes().unwrap(); + // Get the new state cache and then commit the transactional state. + let state_cache3 = transactional_state.borrow_updated_state_cache().unwrap().clone(); transactional_state.commit(); // Get the total state changes of the CachedState underlying all the temporary transactional @@ -403,15 +398,13 @@ fn test_state_changes_merge( // states, but only when done in the right order. let state_changes_final = state.get_actual_state_changes().unwrap(); assert_eq!( - StateChanges::merge(vec![ - state_changes1.clone(), - state_changes2.clone(), - state_changes3.clone() - ]), + StateCache::squash_state_caches(vec![&state_cache1, &state_cache2, &state_cache3]) + .to_state_diff(), state_changes_final ); assert_ne!( - StateChanges::merge(vec![state_changes3, state_changes1, state_changes2]), + StateCache::squash_state_caches(vec![&state_cache3, &state_cache1, &state_cache2]) + .to_state_diff(), state_changes_final ); } @@ -422,33 +415,57 @@ fn test_state_changes_merge( #[case(true, vec![felt!("0x7")], true)] #[case(false, vec![felt!("0x7")], false)] #[case(true, vec![felt!("0x7"), felt!("0x0")], false)] -#[case(false, vec![felt!("0x0"), felt!("0x8")], true)] +#[case(false, vec![felt!("0x7"), felt!("0x1")], false)] +#[case(false, vec![felt!("0x0"), felt!("0x8")], false)] #[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)] -fn test_allocated_keys_commit_and_merge( +fn test_state_cache_commit_and_merge( #[case] is_base_empty: bool, #[case] storage_updates: Vec, #[case] charged: bool, + #[values(true, false)] comprehensive_state_diff: bool, ) { let contract_address = contract_address!(CONTRACT_ADDRESS); let storage_key = StorageKey::from(0x10_u16); // Set initial state let mut state: CachedState = CachedState::default(); + + let non_empty_base_value = felt!("0x1"); if !is_base_empty { - state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + state.set_storage_at(contract_address, storage_key, non_empty_base_value).unwrap(); } - let mut state_changes = vec![]; + let mut state_caches = vec![]; - for value in storage_updates { + for value in storage_updates.iter() { // In the end of the previous loop, state has moved into the transactional state. let mut transactional_state = TransactionalState::create_transactional(&mut state); // Update state and collect the state changes. - transactional_state.set_storage_at(contract_address, storage_key, value).unwrap(); - state_changes.push(transactional_state.get_actual_state_changes().unwrap()); + transactional_state.set_storage_at(contract_address, storage_key, *value).unwrap(); + state_caches.push(transactional_state.borrow_updated_state_cache().unwrap().clone()); transactional_state.commit(); } - let merged_changes = StateChanges::merge(state_changes); - assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + let merged_changes = StateCache::squash_state_diff_backward_compatible( + state_caches.iter().collect(), + comprehensive_state_diff, + ); + if comprehensive_state_diff { + // The comprehensive_state_diff is needed for backward compatibility of versions before the + // allocated keys feature was inserted. + assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + } + + // Test the storage diff. + let base_value = if is_base_empty { Felt::ZERO } else { non_empty_base_value }; + let last_value = storage_updates.last().unwrap(); + let expected_storage_diff = if (&base_value == last_value) && comprehensive_state_diff { + None + } else { + Some(last_value) + }; + assert_eq!( + merged_changes.state_maps.storage.get(&(contract_address, storage_key)), + expected_storage_diff, + ); } // Test that allocations in validate and execute phases are properly squashed. @@ -456,8 +473,7 @@ fn test_allocated_keys_commit_and_merge( #[case(false, felt!("0x7"), felt!("0x8"), false)] #[case(true, felt!("0x0"), felt!("0x8"), true)] #[case(true, felt!("0x7"), felt!("0x7"), true)] -// TODO: not charge in the following case. -#[case(false, felt!("0x0"), felt!("0x8"), true)] +#[case(false, felt!("0x0"), felt!("0x8"), false)] #[case(true, felt!("0x7"), felt!("0x0"), false)] fn test_allocated_keys_two_transactions( #[case] is_base_empty: bool, @@ -475,14 +491,15 @@ fn test_allocated_keys_two_transactions( let mut first_state = TransactionalState::create_transactional(&mut state); first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap(); - let first_state_changes = first_state.get_actual_state_changes().unwrap(); + let first_state_changes = first_state.borrow_updated_state_cache().unwrap().clone(); let mut second_state = TransactionalState::create_transactional(&mut first_state); second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap(); - let second_state_changes = second_state.get_actual_state_changes().unwrap(); + let second_state_changes = second_state.borrow_updated_state_cache().unwrap().clone(); - let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]); - assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + let merged_changes = + StateCache::squash_state_caches(vec![&first_state_changes, &second_state_changes]); + assert_ne!(merged_changes.to_state_diff().allocated_keys.is_empty(), charged); } #[test] diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 3d87f93c85..e2899de3b8 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -39,7 +39,7 @@ use crate::fee::fee_utils::{ use crate::fee::gas_usage::estimate_minimal_gas_vector; use crate::fee::receipt::TransactionReceipt; use crate::retdata; -use crate::state::cached_state::{StateChanges, TransactionalState}; +use crate::state::cached_state::{StateCache, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; use crate::transaction::errors::{ TransactionExecutionError, @@ -575,7 +575,7 @@ impl AccountTransaction { // Save the state changes resulting from running `validate_tx`, to be used later for // resource and fee calculation. - let validate_state_changes = state.get_actual_state_changes()?; + let validate_state_cache = state.borrow_updated_state_cache()?.clone(); // Create copies of state and validate_resources for the execution. // Both will be rolled back if the execution is reverted or committed upon success. @@ -591,7 +591,7 @@ impl AccountTransaction { let revert_receipt = TransactionReceipt::from_account_tx( self, &tx_context, - &validate_state_changes, + &validate_state_cache.to_state_diff(), CallInfo::summarize_many( validate_call_info.iter(), &tx_context.block_context.versioned_constants, @@ -606,10 +606,13 @@ impl AccountTransaction { let tx_receipt = TransactionReceipt::from_account_tx( self, &tx_context, - &StateChanges::merge(vec![ - validate_state_changes, - execution_state.get_actual_state_changes()?, - ]), + &StateCache::squash_state_diff_backward_compatible( + vec![ + &validate_state_cache, + &execution_state.borrow_updated_state_cache()?.clone(), + ], + tx_context.block_context.versioned_constants.comprehensive_state_diff, + ), CallInfo::summarize_many( validate_call_info.iter().chain(execute_call_info.iter()), &tx_context.block_context.versioned_constants,