Skip to content

Commit

Permalink
refactor(blockifier): state reader trait
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivYossef-starkware committed Dec 17, 2024
1 parent 4fb9c16 commit 2176de2
Show file tree
Hide file tree
Showing 26 changed files with 197 additions and 79 deletions.
9 changes: 7 additions & 2 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::blockifier::config::TransactionExecutorConfig;
use crate::bouncer::{Bouncer, BouncerWeights};
use crate::concurrency::worker_logic::WorkerExecutor;
use crate::context::BlockContext;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult};
Expand Down Expand Up @@ -156,12 +157,16 @@ impl<S: StateReader> TransactionExecutor<S> {
.visited_pcs
.iter()
.map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> {
let contract_class = self
let versioned_contract_class = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
Ok((
*class_hash,
RunnableCompiledClass::from(versioned_contract_class)
.get_visited_segments(class_visited_pcs)?,
))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand Down
3 changes: 2 additions & 1 deletion crates/blockifier/src/bouncer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::blockifier::transaction_executor::{
TransactionExecutorResult,
};
use crate::execution::call_info::ExecutionSummary;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::fee::gas_usage::get_onchain_data_segment_length;
use crate::fee::resources::TransactionResources;
use crate::state::cached_state::{StateChangesKeys, StorageEntry};
Expand Down Expand Up @@ -565,7 +566,7 @@ pub fn get_casm_hash_calculation_resources<S: StateReader>(
let mut casm_hash_computation_resources = ExecutionResources::default();

for class_hash in executed_class_hashes {
let class = state_reader.get_compiled_class(*class_hash)?;
let class: RunnableCompiledClass = state_reader.get_compiled_class(*class_hash)?.into();
casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources();
}

Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/concurrency/worker_logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::concurrency::utils::lock_mutex_in_array;
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{ContractClassMapping, StateMaps, TransactionalState};
use crate::state::cached_state::{StateMaps, TransactionalState, VersionedContractClassMapping};
use crate::state::state_api::{StateReader, UpdatableState};
use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult};
use crate::transaction::transaction_execution::Transaction;
Expand All @@ -32,7 +32,7 @@ pub struct ExecutionTaskOutput {
pub reads: StateMaps,
// TODO(Yoni): rename to state_diff.
pub writes: StateMaps,
pub contract_classes: ContractClassMapping,
pub contract_classes: VersionedContractClassMapping,
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
pub result: TransactionExecutionResult<TransactionExecutionInfo>,
}
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub enum RunnableCompiledClass {
}

/// Represents a runnable compiled class for Cairo, with the Sierra version (for Cairo 1).
#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum VersionedRunnableCompiledClass {
Cairo0(RunnableCompiledClass),
Cairo1((RunnableCompiledClass, SierraVersion)),
Expand Down
12 changes: 8 additions & 4 deletions crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use starknet_api::transaction::fields::{
use starknet_api::transaction::TransactionVersion;
use starknet_types_core::felt::Felt;

use super::contract_class::RunnableCompiledClass;
use crate::context::{BlockContext, TransactionContext};
use crate::execution::call_info::CallInfo;
use crate::execution::common_hints::ExecutionMode;
Expand Down Expand Up @@ -148,7 +149,7 @@ impl CallEntryPoint {
}
// Add class hash to the call, that will appear in the output (call info).
self.class_hash = Some(class_hash);
let compiled_class = state.get_compiled_class(class_hash)?;
let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into();

context.revert_infos.0.push(EntryPointRevertInfo::new(
self.storage_address,
Expand Down Expand Up @@ -407,9 +408,12 @@ pub fn execute_constructor_entry_point(
remaining_gas: &mut u64,
) -> ConstructorEntryPointExecutionResult<CallInfo> {
// Ensure the class is declared (by reading it).
let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?;
let compiled_class: RunnableCompiledClass = state
.get_compiled_class(ctor_context.class_hash)
.map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?
.into();
let Some(constructor_selector) = compiled_class.constructor_selector() else {
// Contract has no constructor.
return handle_empty_constructor(&ctor_context, calldata, *remaining_gas)
Expand Down
4 changes: 3 additions & 1 deletion crates/blockifier/src/execution/syscalls/syscall_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::exceeds_event_size_limit;
use crate::abi::constants;
use crate::execution::call_info::{CallInfo, MessageToL1, OrderedEvent, OrderedL2ToL1Message};
use crate::execution::common_hints::ExecutionMode;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::entry_point::{
CallEntryPoint,
ConstructorContext,
Expand Down Expand Up @@ -164,7 +165,8 @@ impl<'state> SyscallHandlerBase<'state> {

pub fn replace_class(&mut self, class_hash: ClassHash) -> SyscallResult<()> {
// Ensure the class is declared (by reading it), and of type V1.
let compiled_class = self.state.get_compiled_class(class_hash)?;
let compiled_class: RunnableCompiledClass =
self.state.get_compiled_class(class_hash)?.into();

if !is_cairo1(&compiled_class) {
return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash });
Expand Down
9 changes: 6 additions & 3 deletions crates/blockifier/src/state/state_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use super::cached_state::{ContractClassMapping, StateMaps};
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::state::errors::StateError;

pub type StateResult<T> = Result<T, StateError>;
Expand Down Expand Up @@ -39,8 +39,11 @@ pub trait StateReader {
/// Default: 0 (uninitialized class hash) for an uninitialized contract address.
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash>;

/// Returns the compiled class of the given class hash.
fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass>;
/// Returns the versioned runnable compiled class of the given class hash.
fn get_compiled_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass>;

/// Returns the compiled class hash of the given class hash.
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash>;
Expand Down
1 change: 0 additions & 1 deletion crates/blockifier/src/test_utils/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ impl FeatureContract {
self.get_class().try_into().unwrap()
}

#[allow(dead_code)]
pub fn get_versioned_runnable_class(&self) -> VersionedRunnableCompiledClass {
let runnable_class = self.get_runnable_class();
match self.cairo_version() {
Expand Down
15 changes: 9 additions & 6 deletions crates/blockifier/src/test_utils/dict_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::state::cached_state::StorageEntry;
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult};
Expand All @@ -15,7 +15,7 @@ pub struct DictStateReader {
pub storage_view: HashMap<StorageEntry, Felt>,
pub address_to_nonce: HashMap<ContractAddress, Nonce>,
pub address_to_class_hash: HashMap<ContractAddress, ClassHash>,
pub class_hash_to_class: HashMap<ClassHash, RunnableCompiledClass>,
pub class_hash_to_class: HashMap<ClassHash, VersionedRunnableCompiledClass>,
pub class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
}

Expand All @@ -35,10 +35,13 @@ impl StateReader for DictStateReader {
Ok(nonce)
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
let contract_class = self.class_hash_to_class.get(&class_hash).cloned();
match contract_class {
Some(contract_class) => Ok(contract_class),
fn get_compiled_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
let versioned_contract_class = self.class_hash_to_class.get(&class_hash).cloned();
match versioned_contract_class {
Some(versioned_contract_class) => Ok(versioned_contract_class),
_ => Err(StateError::UndeclaredClassHash(class_hash)),
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/test_utils/initial_test_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub fn test_state_inner(

// Declare and deploy account and ERC20 contracts.
let erc20 = FeatureContract::ERC20(erc20_contract_version);
class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_runnable_class());
class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_versioned_runnable_class());
address_to_class_hash
.insert(chain_info.fee_token_address(&FeeType::Eth), erc20.get_class_hash());
address_to_class_hash
Expand All @@ -58,7 +58,7 @@ pub fn test_state_inner(
// Set up the rest of the requested contracts.
for (contract, n_instances) in contract_instances.iter() {
let class_hash = contract.get_class_hash();
class_hash_to_class.insert(class_hash, contract.get_runnable_class());
class_hash_to_class.insert(class_hash, contract.get_versioned_runnable_class());
for instance in 0..*n_instances {
let instance_address = contract.get_instance_address(instance);
address_to_class_hash.insert(instance_address, class_hash);
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/transaction/account_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ impl ValidatableTransaction for AccountTransaction {
})?;

// Validate return data.
let compiled_class = state.get_compiled_class(class_hash)?;
let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into();
if is_cairo1(&compiled_class) {
// The account contract class is a Cairo 1.0 contract; the `validate` entry point should
// return `VALID`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,10 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) {
create_test_init_data(chain_info, CairoVersion::Cairo0);
let class_hash = class_hash!(0xdeadeadeaf72_u128);
let contract_class =
FeatureContract::Empty(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class();
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class();
let versioned_contract_class =
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_versioned_runnable_class();
let next_nonce = nonce_manager.next(account_address);

// Cannot fail executing a declare tx unless it's V2 or above, and already declared.
Expand All @@ -789,7 +792,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) {
sender_address: account_address,
..Default::default()
};
state.set_contract_class(class_hash, contract_class.clone().try_into().unwrap()).unwrap();
state.set_contract_class(class_hash, versioned_contract_class.clone()).unwrap();
state.set_compiled_class_hash(class_hash, declare_tx_v2.compiled_class_hash).unwrap();
let class_info = calculate_class_info_for_testing(contract_class);
let executable_declare = ApiExecutableDeclareTransaction {
Expand Down
35 changes: 33 additions & 2 deletions crates/blockifier/src/transaction/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use starknet_api::transaction::{

use crate::context::{BlockContext, TransactionContext};
use crate::execution::call_info::CallInfo;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::execution::entry_point::{
CallEntryPoint,
CallType,
Expand Down Expand Up @@ -204,7 +205,23 @@ impl<S: State> Executable<S> for DeclareTransaction {
// We allow redeclaration of the class for backward compatibility.
// In the past, we allowed redeclaration of Cairo 0 contracts since there was
// no class commitment (so no need to check if the class is already declared).
state.set_contract_class(class_hash, self.contract_class().try_into()?)?;
let compiled_contract_class = self.contract_class();
let versioned_compiled_contract_class = {
match compiled_contract_class {
starknet_api::contract_class::ContractClass::V0(_) => {
VersionedRunnableCompiledClass::Cairo0(
compiled_contract_class.try_into()?,
)
}
starknet_api::contract_class::ContractClass::V1(_) => {
VersionedRunnableCompiledClass::Cairo1((
compiled_contract_class.try_into()?,
self.class_info.sierra_version.clone(),
))
}
}
};
state.set_contract_class(class_hash, versioned_compiled_contract_class)?;
}
}
starknet_api::transaction::DeclareTransaction::V2(DeclareTransactionV2 {
Expand Down Expand Up @@ -417,7 +434,21 @@ fn try_declare<S: State>(
match state.get_compiled_class(class_hash) {
Err(StateError::UndeclaredClassHash(_)) => {
// Class is undeclared; declare it.
state.set_contract_class(class_hash, tx.contract_class().try_into()?)?;
let compiled_contract_class = tx.contract_class();
let versioned_compiled_contract_class = {
match compiled_contract_class {
starknet_api::contract_class::ContractClass::V0(_) => {
VersionedRunnableCompiledClass::Cairo0(compiled_contract_class.try_into()?)
}
starknet_api::contract_class::ContractClass::V1(_) => {
VersionedRunnableCompiledClass::Cairo1((
compiled_contract_class.try_into()?,
tx.class_info.sierra_version.clone(),
))
}
}
};
state.set_contract_class(class_hash, versioned_compiled_contract_class)?;
if let Some(compiled_class_hash) = compiled_class_hash {
state.set_compiled_class_hash(class_hash, compiled_class_hash)?;
}
Expand Down
5 changes: 4 additions & 1 deletion crates/blockifier/src/transaction/transactions_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,8 @@ fn test_declare_tx(
#[case] empty_contract_version: CairoVersion,
#[values(false, true)] use_kzg_da: bool,
) {
use crate::execution::contract_class::RunnableCompiledClass;

let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da);
let versioned_constants = &block_context.versioned_constants;
let empty_contract = FeatureContract::Empty(empty_contract_version);
Expand Down Expand Up @@ -1698,7 +1700,8 @@ fn test_declare_tx(
);

// Verify class declaration.
let contract_class_from_state = state.get_compiled_class(class_hash).unwrap();
let contract_class_from_state: RunnableCompiledClass =
state.get_compiled_class(class_hash).unwrap().into();
assert_eq!(contract_class_from_state, class_info.contract_class().try_into().unwrap());

// Checks that redeclaring the same contract fails.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig;
use blockifier::blockifier::transaction_executor::TransactionExecutor;
use blockifier::bouncer::BouncerConfig;
use blockifier::context::BlockContext;
use blockifier::execution::contract_class::RunnableCompiledClass;
use blockifier::execution::contract_class::{
RunnableCompiledClass,
VersionedRunnableCompiledClass,
};
use blockifier::state::cached_state::{CommitmentStateDiff, StateMaps};
use blockifier::state::errors::StateError;
use blockifier::state::state_api::{StateReader, StateResult};
Expand Down Expand Up @@ -157,15 +160,22 @@ impl StateReader for OfflineStateReader {
)?)
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
fn get_compiled_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
match self.get_contract_class(&class_hash)? {
StarknetContractClass::Sierra(sierra) => {
let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap();
Ok(casm.try_into().unwrap())
}
StarknetContractClass::Legacy(legacy) => {
Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap())
let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap();
let runnable_compiled_class: RunnableCompiledClass = casm.try_into().unwrap();
Ok(VersionedRunnableCompiledClass::Cairo1((
runnable_compiled_class,
sierra_version,
)))
}
StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0(
legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(),
)),
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig;
use blockifier::blockifier::transaction_executor::TransactionExecutor;
use blockifier::bouncer::BouncerConfig;
use blockifier::context::BlockContext;
use blockifier::execution::contract_class::RunnableCompiledClass;
use blockifier::execution::contract_class::{
RunnableCompiledClass,
VersionedRunnableCompiledClass,
};
use blockifier::state::cached_state::CommitmentStateDiff;
use blockifier::state::errors::StateError;
use blockifier::state::state_api::{StateReader, StateResult};
Expand Down Expand Up @@ -128,18 +131,25 @@ impl StateReader for TestStateReader {

/// Returns the contract class of the given class hash.
/// Compile the contract class if it is Sierra.
fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
fn get_compiled_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
let contract_class =
retry_request!(self.retry_config, || self.get_contract_class(&class_hash))?;

match contract_class {
StarknetContractClass::Sierra(sierra) => {
let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap();
Ok(RunnableCompiledClass::try_from(casm).unwrap())
}
StarknetContractClass::Legacy(legacy) => {
Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap())
let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap();
let runnable_contract_class: RunnableCompiledClass = casm.try_into().unwrap();
Ok(VersionedRunnableCompiledClass::Cairo1((
runnable_contract_class,
sierra_version,
)))
}
StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0(
legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(),
)),
}
}

Expand Down
Loading

0 comments on commit 2176de2

Please sign in to comment.