diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index dbb84243e7e..e433ef07933 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -58,6 +58,11 @@ impl CachedState { self.to_state_diff() } + pub fn to_state_cache(&mut self) -> StateResult { + self.update_initial_values_of_write_only_access()?; + Ok(self.cache.borrow().clone()) + } + pub fn update_cache( &mut self, write_updates: &StateMaps, @@ -387,7 +392,7 @@ impl StateMaps { /// The tracked changes are needed for block state commitment. // Invariant: keys cannot be deleted from fields (only used internally by the cached state). -#[derive(Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct StateCache { // Reader's cached information; initial values, read before any write operation (per cell). pub(crate) initial_reads: StateMaps, @@ -406,6 +411,21 @@ impl StateCache { StateChanges { state_maps, allocated_keys } } + /// Merges the given state caches into a single one. Note that the order of the state caches + /// is important. + pub fn merge(state_caches: Vec) -> Self { + let mut merged_state_cache = StateCache::default(); + // Gives priority to early initial reads. + state_caches.iter().rev().for_each(|state_cache| { + merged_state_cache.initial_reads.extend(&state_cache.initial_reads) + }); + // Gives priority to late writes. + state_caches + .iter() + .for_each(|state_cache| merged_state_cache.writes.extend(&state_cache.writes)); + merged_state_cache + } + fn declare_contract(&mut self, class_hash: ClassHash) { self.writes.declared_contracts.insert(class_hash, true); } @@ -733,17 +753,6 @@ pub struct StateChanges { } impl StateChanges { - /// Merges the given state changes into a single one. Note that the order of the state changes - /// is important. The state changes are merged in the order they appear in the given vector. - pub fn merge(state_changes: Vec) -> Self { - let mut merged_state_changes = Self::default(); - for state_change in state_changes { - merged_state_changes.state_maps.extend(&state_change.state_maps); - merged_state_changes.allocated_keys.update(&state_change); - } - merged_state_changes - } - pub fn count_for_fee_charge( &self, sender_address: Option, diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index bcc91ab2f57..719bd66f0d2 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -289,11 +289,11 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().state_maps.into()); } -fn create_state_changes_for_test( +fn create_state_cache_for_test( state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, -) -> StateChanges { +) -> StateCache { let contract_address = contract_address!(CONTRACT_ADDRESS); let contract_address2 = contract_address!("0x101"); let class_hash = class_hash!("0x10"); @@ -323,7 +323,7 @@ fn create_state_changes_for_test( let sender_balance_key = get_fee_token_var_address(sender_address); state.set_storage_at(fee_token_address, sender_balance_key, felt!("0x1999")).unwrap(); } - state.get_actual_state_changes().unwrap() + state.to_state_cache().unwrap() } #[rstest] @@ -333,7 +333,7 @@ fn test_from_state_changes_for_fee_charge( let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = - create_state_changes_for_test(&mut state, sender_address, fee_token_address); + create_state_cache_for_test(&mut state, sender_address, fee_token_address).to_state_diff(); let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address); let n_expected_storage_updates = 1 + usize::from(sender_address.is_some()); let expected_state_changes_count = StateChangesCountForFee { @@ -350,37 +350,32 @@ fn test_from_state_changes_for_fee_charge( } #[rstest] -fn test_state_changes_merge( +fn test_state_cache_merge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the - // state changes and then commit. + // state cache and then commit. let mut state: CachedState = CachedState::default(); let mut transactional_state = TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; - let state_changes1 = - create_state_changes_for_test(&mut transactional_state, sender_address, fee_token_address); + let state_cache1 = + create_state_cache_for_test(&mut transactional_state, sender_address, fee_token_address); transactional_state.commit(); // After performing `commit`, the transactional state is moved (into state). We need to create // a new transactional state that wraps `state` to continue. let mut transactional_state = TransactionalState::create_transactional(&mut state); - // Make sure that `get_actual_state_changes` on a newly created transactional state returns null - // state changes and that merging null state changes with non-null state changes results in the - // non-null state changes, no matter the order. - let state_changes2 = transactional_state.get_actual_state_changes().unwrap(); - assert_eq!(state_changes2, StateChanges::default()); - assert_eq!( - StateChanges::merge(vec![state_changes1.clone(), state_changes2.clone()]), - state_changes1 - ); - assert_eq!( - StateChanges::merge(vec![state_changes2.clone(), state_changes1.clone()]), - state_changes1 - ); - - // Get the storage updates addresses and keys from the state_changes1, to overwrite. + // Make sure that the state_changes of a newly created transactional state returns null + // state cache and that merging null state cache with non-null state cache results in the + // non-null state cache, no matter the order. + let state_cache2 = transactional_state.to_state_cache().unwrap(); + assert_eq!(state_cache2, StateCache::default()); + assert_eq!(StateCache::merge(vec![state_cache1.clone(), state_cache2.clone()]), state_cache1); + assert_eq!(StateCache::merge(vec![state_cache2.clone(), state_cache1.clone()]), state_cache1); + + // Get the storage updates addresses and keys from the state_cache1, to overwrite. + let state_changes1 = state_cache1.to_state_diff(); let mut storage_updates_keys = state_changes1.state_maps.storage.keys(); let &(contract_address, storage_key) = storage_updates_keys .find(|(contract_address, _)| contract_address == &contract_address!(CONTRACT_ADDRESS)) @@ -394,8 +389,8 @@ fn test_state_changes_merge( .set_storage_at(new_contract_address, storage_key, felt!("0x43210")) .unwrap(); transactional_state.increment_nonce(contract_address).unwrap(); - // Get the new state changes and then commit the transactional state. - let state_changes3 = transactional_state.get_actual_state_changes().unwrap(); + // Get the new state cache and then commit the transactional state. + let state_cache3 = transactional_state.to_state_cache().unwrap(); transactional_state.commit(); // Get the total state changes of the CachedState underlying all the temporary transactional @@ -403,15 +398,12 @@ fn test_state_changes_merge( // states, but only when done in the right order. let state_changes_final = state.get_actual_state_changes().unwrap(); assert_eq!( - StateChanges::merge(vec![ - state_changes1.clone(), - state_changes2.clone(), - state_changes3.clone() - ]), + StateCache::merge(vec![state_cache1.clone(), state_cache2.clone(), state_cache3.clone()]) + .to_state_diff(), state_changes_final ); assert_ne!( - StateChanges::merge(vec![state_changes3, state_changes1, state_changes2]), + StateCache::merge(vec![state_cache3, state_cache1, state_cache2]).to_state_diff(), state_changes_final ); } @@ -422,9 +414,10 @@ fn test_state_changes_merge( #[case(true, vec![felt!("0x7")], true)] #[case(false, vec![felt!("0x7")], false)] #[case(true, vec![felt!("0x7"), felt!("0x0")], false)] -#[case(false, vec![felt!("0x0"), felt!("0x8")], true)] +#[case(false, vec![felt!("0x7"), felt!("0x1")], false)] +#[case(false, vec![felt!("0x0"), felt!("0x8")], false)] #[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)] -fn test_allocated_keys_commit_and_merge( +fn test_state_cache_commit_and_merge( #[case] is_base_empty: bool, #[case] storage_updates: Vec, #[case] charged: bool, @@ -433,22 +426,33 @@ fn test_allocated_keys_commit_and_merge( let storage_key = StorageKey::from(0x10_u16); // Set initial state let mut state: CachedState = CachedState::default(); + + let non_empty_base_value = felt!("0x1"); if !is_base_empty { - state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + state.set_storage_at(contract_address, storage_key, non_empty_base_value).unwrap(); } - let mut state_changes = vec![]; + let mut state_caches = vec![]; - for value in storage_updates { + for value in storage_updates.iter() { // In the end of the previous loop, state has moved into the transactional state. let mut transactional_state = TransactionalState::create_transactional(&mut state); // Update state and collect the state changes. - transactional_state.set_storage_at(contract_address, storage_key, value).unwrap(); - state_changes.push(transactional_state.get_actual_state_changes().unwrap()); + transactional_state.set_storage_at(contract_address, storage_key, *value).unwrap(); + state_caches.push(transactional_state.to_state_cache().unwrap()); transactional_state.commit(); } - let merged_changes = StateChanges::merge(state_changes); + let merged_changes = StateCache::merge(state_caches).to_state_diff(); assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + + // Test the storage diff. + let base_value = if is_base_empty { Felt::ZERO } else { non_empty_base_value }; + let last_value = storage_updates.last().unwrap(); + let expected_storage_diff = if &base_value == last_value { None } else { Some(last_value) }; + assert_eq!( + merged_changes.state_maps.storage.get(&(contract_address, storage_key)), + expected_storage_diff, + ); } // Test that allocations in validate and execute phases are properly squashed. @@ -456,8 +460,7 @@ fn test_allocated_keys_commit_and_merge( #[case(false, felt!("0x7"), felt!("0x8"), false)] #[case(true, felt!("0x0"), felt!("0x8"), true)] #[case(true, felt!("0x7"), felt!("0x7"), true)] -// TODO: not charge in the following case. -#[case(false, felt!("0x0"), felt!("0x8"), true)] +#[case(false, felt!("0x0"), felt!("0x8"), false)] #[case(true, felt!("0x7"), felt!("0x0"), false)] fn test_allocated_keys_two_transactions( #[case] is_base_empty: bool, @@ -475,13 +478,14 @@ fn test_allocated_keys_two_transactions( let mut first_state = TransactionalState::create_transactional(&mut state); first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap(); - let first_state_changes = first_state.get_actual_state_changes().unwrap(); + let first_state_changes = first_state.to_state_cache().unwrap(); let mut second_state = TransactionalState::create_transactional(&mut first_state); second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap(); - let second_state_changes = second_state.get_actual_state_changes().unwrap(); + let second_state_changes = second_state.to_state_cache().unwrap(); - let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]); + let merged_changes = + StateCache::merge(vec![first_state_changes, second_state_changes]).to_state_diff(); assert_ne!(merged_changes.allocated_keys.is_empty(), charged); } diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index a4b245dd0d3..35bf6ce0d19 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -44,7 +44,7 @@ use crate::fee::fee_utils::{ use crate::fee::gas_usage::estimate_minimal_gas_vector; use crate::fee::receipt::TransactionReceipt; use crate::retdata; -use crate::state::cached_state::{StateChanges, TransactionalState}; +use crate::state::cached_state::{StateCache, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; use crate::transaction::errors::{ TransactionExecutionError, @@ -612,7 +612,8 @@ impl AccountTransaction { // Save the state changes resulting from running `validate_tx`, to be used later for // resource and fee calculation. - let validate_state_changes = state.get_actual_state_changes()?; + let validate_state_cache = state.to_state_cache()?; + let validate_state_changes = validate_state_cache.to_state_diff(); // Create copies of state and validate_resources for the execution. // Both will be rolled back if the execution is reverted or committed upon success. @@ -643,10 +644,11 @@ impl AccountTransaction { let tx_receipt = TransactionReceipt::from_account_tx( self, &tx_context, - &StateChanges::merge(vec![ - validate_state_changes, - execution_state.get_actual_state_changes()?, - ]), + &StateCache::merge(vec![ + validate_state_cache, + execution_state.to_state_cache()?, + ]) + .to_state_diff(), CallInfo::summarize_many( validate_call_info.iter().chain(execute_call_info.iter()), &tx_context.block_context.versioned_constants,