Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(blockifier): merge state diff with squash #2310

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::cell::RefCell;
use std::cell::{Ref, RefCell};
use std::collections::{HashMap, HashSet};

use indexmap::IndexMap;
Expand Down Expand Up @@ -58,6 +58,11 @@ impl<S: StateReader> CachedState<S> {
self.to_state_diff()
}

pub fn borrow_updated_state_cache(&mut self) -> StateResult<Ref<'_, StateCache>> {
self.update_initial_values_of_write_only_access()?;
Ok(self.cache.borrow())
}

pub fn update_cache(
&mut self,
write_updates: &StateMaps,
Expand Down Expand Up @@ -383,7 +388,7 @@ impl StateMaps {
/// The tracked changes are needed for block state commitment.

// Invariant: keys cannot be deleted from fields (only used internally by the cached state).
#[derive(Debug, Default, PartialEq, Eq)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct StateCache {
// Reader's cached information; initial values, read before any write operation (per cell).
pub(crate) initial_reads: StateMaps,
Expand All @@ -402,6 +407,44 @@ impl StateCache {
StateChanges { state_maps, allocated_keys }
}

/// Squashes the given state caches into a single one and returns the state diff. Note that the
/// order of the state caches is important.
pub fn squash_state_caches(state_caches: Vec<&Self>) -> Self {
let mut squashed_state_cache = StateCache::default();

// Gives priority to early initial reads.
state_caches.iter().rev().for_each(|state_cache| {
squashed_state_cache.initial_reads.extend(&state_cache.initial_reads)
});
// Gives priority to late writes.
state_caches
.iter()
.for_each(|state_cache| squashed_state_cache.writes.extend(&state_cache.writes));
squashed_state_cache
}

/// Squashes the given state caches into a single one and returns the state diff. Note that the
/// order of the state caches is important.
/// If 'comprehensive_state_diff' is false, opposite updates may not be canceled out. Used for
/// backward compatibility.
pub fn squash_state_diff(
state_caches: Vec<&Self>,
comprehensive_state_diff: bool,
) -> StateChanges {
if comprehensive_state_diff {
return Self::squash_state_caches(state_caches).to_state_diff();
}

// Backward compatibility.
let mut merged_state_changes = StateChanges::default();
for state_cache in state_caches {
let state_change = state_cache.to_state_diff();
merged_state_changes.state_maps.extend(&state_change.state_maps);
merged_state_changes.allocated_keys.0.extend(&state_change.allocated_keys.0);
}
merged_state_changes
}

fn declare_contract(&mut self, class_hash: ClassHash) {
self.writes.declared_contracts.insert(class_hash, true);
}
Expand Down Expand Up @@ -680,18 +723,6 @@ impl StateChangesKeys {
pub struct AllocatedKeys(HashSet<StorageEntry>);

impl AllocatedKeys {
/// Extends the set of allocated keys with the allocated_keys of the given state changes.
/// Removes storage keys that are set back to zero.
pub fn update(&mut self, state_change: &StateChanges) {
self.0.extend(&state_change.allocated_keys.0);
// Remove keys that are set back to zero.
state_change.state_maps.storage.iter().for_each(|(k, v)| {
if v == &Felt::ZERO {
self.0.remove(k);
}
});
}

pub fn len(&self) -> usize {
self.0.len()
}
Expand Down Expand Up @@ -726,17 +757,6 @@ pub struct StateChanges {
}

impl StateChanges {
/// Merges the given state changes into a single one. Note that the order of the state changes
/// is important. The state changes are merged in the order they appear in the given vector.
pub fn merge(state_changes: Vec<Self>) -> Self {
let mut merged_state_changes = Self::default();
for state_change in state_changes {
merged_state_changes.state_maps.extend(&state_change.state_maps);
merged_state_changes.allocated_keys.update(&state_change);
}
merged_state_changes
}

pub fn count_for_fee_charge(
&self,
sender_address: Option<ContractAddress>,
Expand Down
107 changes: 61 additions & 46 deletions crates/blockifier/src/state/cached_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,11 @@ fn cached_state_state_diff_conversion() {
assert_eq!(expected_state_diff, state.to_state_diff().unwrap().state_maps.into());
}

fn create_state_changes_for_test<S: StateReader>(
fn create_state_cache_for_test<S: StateReader>(
state: &mut CachedState<S>,
sender_address: Option<ContractAddress>,
fee_token_address: ContractAddress,
) -> StateChanges {
) -> StateCache {
let contract_address = contract_address!(CONTRACT_ADDRESS);
let contract_address2 = contract_address!("0x101");
let class_hash = class_hash!("0x10");
Expand Down Expand Up @@ -323,7 +323,7 @@ fn create_state_changes_for_test<S: StateReader>(
let sender_balance_key = get_fee_token_var_address(sender_address);
state.set_storage_at(fee_token_address, sender_balance_key, felt!("0x1999")).unwrap();
}
state.get_actual_state_changes().unwrap()
state.borrow_updated_state_cache().unwrap().clone()
}

#[rstest]
Expand All @@ -333,7 +333,7 @@ fn test_from_state_changes_for_fee_charge(
let mut state: CachedState<DictStateReader> = CachedState::default();
let fee_token_address = contract_address!("0x17");
let state_changes =
create_state_changes_for_test(&mut state, sender_address, fee_token_address);
create_state_cache_for_test(&mut state, sender_address, fee_token_address).to_state_diff();
let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address);
let n_expected_storage_updates = 1 + usize::from(sender_address.is_some());
let expected_state_changes_count = StateChangesCountForFee {
Expand All @@ -350,37 +350,32 @@ fn test_from_state_changes_for_fee_charge(
}

#[rstest]
fn test_state_changes_merge(
fn test_state_cache_merge(
#[values(Some(contract_address!("0x102")), None)] sender_address: Option<ContractAddress>,
) {
// Create a transactional state containing the `create_state_changes_for_test` logic, get the
// state changes and then commit.
// state cache and then commit.
let mut state: CachedState<DictStateReader> = CachedState::default();
let mut transactional_state = TransactionalState::create_transactional(&mut state);
let block_context = BlockContext::create_for_testing();
let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address;
let state_changes1 =
create_state_changes_for_test(&mut transactional_state, sender_address, fee_token_address);
let state_cache1 =
create_state_cache_for_test(&mut transactional_state, sender_address, fee_token_address);
transactional_state.commit();

// After performing `commit`, the transactional state is moved (into state). We need to create
// a new transactional state that wraps `state` to continue.
let mut transactional_state = TransactionalState::create_transactional(&mut state);
// Make sure that `get_actual_state_changes` on a newly created transactional state returns null
// state changes and that merging null state changes with non-null state changes results in the
// non-null state changes, no matter the order.
let state_changes2 = transactional_state.get_actual_state_changes().unwrap();
assert_eq!(state_changes2, StateChanges::default());
assert_eq!(
StateChanges::merge(vec![state_changes1.clone(), state_changes2.clone()]),
state_changes1
);
assert_eq!(
StateChanges::merge(vec![state_changes2.clone(), state_changes1.clone()]),
state_changes1
);

// Get the storage updates addresses and keys from the state_changes1, to overwrite.
// Make sure that the state_changes of a newly created transactional state returns null
// state cache and that merging null state cache with non-null state cache results in the
// non-null state cache, no matter the order.
let state_cache2 = transactional_state.borrow_updated_state_cache().unwrap().clone();
assert_eq!(state_cache2, StateCache::default());
assert_eq!(StateCache::squash_state_caches(vec![&state_cache1, &state_cache2]), state_cache1);
assert_eq!(StateCache::squash_state_caches(vec![&state_cache2, &state_cache1]), state_cache1);

// Get the storage updates addresses and keys from the state_cache1, to overwrite.
let state_changes1 = state_cache1.to_state_diff();
let mut storage_updates_keys = state_changes1.state_maps.storage.keys();
let &(contract_address, storage_key) = storage_updates_keys
.find(|(contract_address, _)| contract_address == &contract_address!(CONTRACT_ADDRESS))
Expand All @@ -394,24 +389,22 @@ fn test_state_changes_merge(
.set_storage_at(new_contract_address, storage_key, felt!("0x43210"))
.unwrap();
transactional_state.increment_nonce(contract_address).unwrap();
// Get the new state changes and then commit the transactional state.
let state_changes3 = transactional_state.get_actual_state_changes().unwrap();
// Get the new state cache and then commit the transactional state.
let state_cache3 = transactional_state.borrow_updated_state_cache().unwrap().clone();
transactional_state.commit();

// Get the total state changes of the CachedState underlying all the temporary transactional
// states. We expect the state_changes to match the merged state_changes of the transactional
// states, but only when done in the right order.
let state_changes_final = state.get_actual_state_changes().unwrap();
assert_eq!(
StateChanges::merge(vec![
state_changes1.clone(),
state_changes2.clone(),
state_changes3.clone()
]),
StateCache::squash_state_caches(vec![&state_cache1, &state_cache2, &state_cache3])
.to_state_diff(),
state_changes_final
);
assert_ne!(
StateChanges::merge(vec![state_changes3, state_changes1, state_changes2]),
StateCache::squash_state_caches(vec![&state_cache3, &state_cache1, &state_cache2])
.to_state_diff(),
state_changes_final
);
}
Expand All @@ -422,42 +415,63 @@ fn test_state_changes_merge(
#[case(true, vec![felt!("0x7")], true)]
#[case(false, vec![felt!("0x7")], false)]
#[case(true, vec![felt!("0x7"), felt!("0x0")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8")], true)]
#[case(false, vec![felt!("0x7"), felt!("0x1")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)]
fn test_allocated_keys_commit_and_merge(
fn test_state_cache_commit_and_merge(
#[case] is_base_empty: bool,
#[case] storage_updates: Vec<Felt>,
#[case] charged: bool,
#[values(true, false)] comprehensive_state_diff: bool,
) {
let contract_address = contract_address!(CONTRACT_ADDRESS);
let storage_key = StorageKey::from(0x10_u16);
// Set initial state
let mut state: CachedState<DictStateReader> = CachedState::default();

let non_empty_base_value = felt!("0x1");
if !is_base_empty {
state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap();
state.set_storage_at(contract_address, storage_key, non_empty_base_value).unwrap();
}
let mut state_changes = vec![];
let mut state_caches = vec![];

for value in storage_updates {
for value in storage_updates.iter() {
// In the end of the previous loop, state has moved into the transactional state.
let mut transactional_state = TransactionalState::create_transactional(&mut state);
// Update state and collect the state changes.
transactional_state.set_storage_at(contract_address, storage_key, value).unwrap();
state_changes.push(transactional_state.get_actual_state_changes().unwrap());
transactional_state.set_storage_at(contract_address, storage_key, *value).unwrap();
state_caches.push(transactional_state.borrow_updated_state_cache().unwrap().clone());
transactional_state.commit();
}

let merged_changes = StateChanges::merge(state_changes);
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
let merged_changes =
StateCache::squash_state_diff(state_caches.iter().collect(), comprehensive_state_diff);
if comprehensive_state_diff {
// The comprehensive_state_diff is needed for backward compatibility of versions before the
// allocated keys feature was inserted.
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
}

// Test the storage diff.
let base_value = if is_base_empty { Felt::ZERO } else { non_empty_base_value };
let last_value = storage_updates.last().unwrap();
let expected_storage_diff = if (&base_value == last_value) && comprehensive_state_diff {
None
} else {
Some(last_value)
};
assert_eq!(
merged_changes.state_maps.storage.get(&(contract_address, storage_key)),
expected_storage_diff,
);
}

// Test that allocations in validate and execute phases are properly squashed.
#[rstest]
#[case(false, felt!("0x7"), felt!("0x8"), false)]
#[case(true, felt!("0x0"), felt!("0x8"), true)]
#[case(true, felt!("0x7"), felt!("0x7"), true)]
// TODO: not charge in the following case.
#[case(false, felt!("0x0"), felt!("0x8"), true)]
#[case(false, felt!("0x0"), felt!("0x8"), false)]
#[case(true, felt!("0x7"), felt!("0x0"), false)]
fn test_allocated_keys_two_transactions(
#[case] is_base_empty: bool,
Expand All @@ -475,14 +489,15 @@ fn test_allocated_keys_two_transactions(

let mut first_state = TransactionalState::create_transactional(&mut state);
first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap();
let first_state_changes = first_state.get_actual_state_changes().unwrap();
let first_state_changes = first_state.borrow_updated_state_cache().unwrap().clone();

let mut second_state = TransactionalState::create_transactional(&mut first_state);
second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap();
let second_state_changes = second_state.get_actual_state_changes().unwrap();
let second_state_changes = second_state.borrow_updated_state_cache().unwrap().clone();

let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]);
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
let merged_changes =
StateCache::squash_state_caches(vec![&first_state_changes, &second_state_changes]);
assert_ne!(merged_changes.to_state_diff().allocated_keys.is_empty(), charged);
}

#[test]
Expand Down
17 changes: 10 additions & 7 deletions crates/blockifier/src/transaction/account_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::fee::fee_utils::{
use crate::fee::gas_usage::estimate_minimal_gas_vector;
use crate::fee::receipt::TransactionReceipt;
use crate::retdata;
use crate::state::cached_state::{StateChanges, TransactionalState};
use crate::state::cached_state::{StateCache, TransactionalState};
use crate::state::state_api::{State, StateReader, UpdatableState};
use crate::transaction::errors::{
TransactionExecutionError,
Expand Down Expand Up @@ -575,7 +575,7 @@ impl AccountTransaction {

// Save the state changes resulting from running `validate_tx`, to be used later for
// resource and fee calculation.
let validate_state_changes = state.get_actual_state_changes()?;
let validate_state_cache = state.borrow_updated_state_cache()?.clone();

// Create copies of state and validate_resources for the execution.
// Both will be rolled back if the execution is reverted or committed upon success.
Expand All @@ -591,7 +591,7 @@ impl AccountTransaction {
let revert_receipt = TransactionReceipt::from_account_tx(
self,
&tx_context,
&validate_state_changes,
&validate_state_cache.to_state_diff(),
CallInfo::summarize_many(
validate_call_info.iter(),
&tx_context.block_context.versioned_constants,
Expand All @@ -606,10 +606,13 @@ impl AccountTransaction {
let tx_receipt = TransactionReceipt::from_account_tx(
self,
&tx_context,
&StateChanges::merge(vec![
validate_state_changes,
execution_state.get_actual_state_changes()?,
]),
&StateCache::squash_state_diff(
vec![
&validate_state_cache,
&execution_state.borrow_updated_state_cache()?.clone(),
],
tx_context.block_context.versioned_constants.comprehensive_state_diff,
),
CallInfo::summarize_many(
validate_call_info.iter().chain(execute_call_info.iter()),
&tx_context.block_context.versioned_constants,
Expand Down
Loading