Skip to content

Commit

Permalink
refactor(blockifier): add sierra version to get compiled class
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivYossef-starkware committed Dec 4, 2024
1 parent 1083f6b commit 390a375
Show file tree
Hide file tree
Showing 58 changed files with 564 additions and 318 deletions.
6 changes: 3 additions & 3 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ impl<S: StateReader> TransactionExecutor<S> {
.visited_pcs
.iter()
.map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> {
let contract_class = self
let (contract_class,_) = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
.get_compiled_contract_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/bouncer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ pub fn get_casm_hash_calculation_resources<S: StateReader>(
let mut casm_hash_computation_resources = ExecutionResources::default();

for class_hash in executed_class_hashes {
let class = state_reader.get_compiled_class(*class_hash)?;
let (class, _) = state_reader.get_compiled_contract_class(*class_hash)?;
casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources();
}

Expand Down
11 changes: 7 additions & 4 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt;

use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};
Expand All @@ -34,7 +34,7 @@ pub struct VersionedState<S: StateReader> {
// the compiled contract classes mapping. Each key with value false, sohuld not apprear
// in the compiled contract classes mapping.
declared_contracts: VersionedStorage<ClassHash, bool>,
compiled_contract_classes: VersionedStorage<ClassHash, RunnableCompiledClass>,
compiled_contract_classes: VersionedStorage<ClassHash, VersionedRunnableCompiledClass>,
}

impl<S: StateReader> VersionedState<S> {
Expand Down Expand Up @@ -336,11 +336,14 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {
}
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
let mut state = self.state();
match state.compiled_contract_classes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
None => match state.initial_state.get_compiled_class(class_hash) {
None => match state.initial_state.get_compiled_contract_class(class_hash) {
Ok(initial_value) => {
state.declared_contracts.set_initial_value(class_hash, true);
state
Expand Down
62 changes: 38 additions & 24 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() {
let class_hash = class_hash!(27_u8);
let another_class_hash = class_hash!(28_u8);
let compiled_class_hash = compiled_class_hash!(29_u8);
let contract_class = test_contract.get_runnable_class();
let contract_class = (test_contract.get_runnable_class(), test_contract.get_sierra_version());

// Create the versioned state
let cached_state = CachedState::from(DictStateReader {
Expand All @@ -98,9 +98,12 @@ fn test_versioned_state_proxy() {
versioned_state_proxys[5].get_compiled_class_hash(class_hash).unwrap(),
compiled_class_hash
);
assert_eq!(versioned_state_proxys[7].get_compiled_class(class_hash).unwrap(), contract_class);
assert_eq!(
versioned_state_proxys[7].get_compiled_contract_class(class_hash).unwrap(),
contract_class
);
assert_matches!(
versioned_state_proxys[7].get_compiled_class(another_class_hash).unwrap_err(),
versioned_state_proxys[7].get_compiled_contract_class(another_class_hash).unwrap_err(),
StateError::UndeclaredClassHash(class_hash) if
another_class_hash == class_hash
);
Expand All @@ -115,8 +118,10 @@ fn test_versioned_state_proxy() {
let class_hash_v7 = class_hash!(28_u8);
let class_hash_v10 = class_hash!(29_u8);
let compiled_class_hash_v18 = compiled_class_hash!(30_u8);
let contract_class_v11 =
FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();
let contract_class_v11 = (
FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(),
FeatureContract::TestContract(CairoVersion::Cairo1).get_sierra_version(),
);

versioned_state_proxys[3].state().apply_writes(
3,
Expand Down Expand Up @@ -195,7 +200,7 @@ fn test_versioned_state_proxy() {
compiled_class_hash_v18
);
assert_eq!(
versioned_state_proxys[15].get_compiled_class(class_hash).unwrap(),
versioned_state_proxys[15].get_compiled_contract_class(class_hash).unwrap(),
contract_class_v11
);
}
Expand Down Expand Up @@ -323,7 +328,7 @@ fn test_validate_reads(

assert!(transactional_state.cache.borrow().initial_reads.declared_contracts.is_empty());
assert_matches!(
transactional_state.get_compiled_class(class_hash),
transactional_state.get_compiled_contract_class(class_hash),
Err(StateError::UndeclaredClassHash(err_class_hash)) if
err_class_hash == class_hash
);
Expand Down Expand Up @@ -405,9 +410,10 @@ fn test_false_validate_reads_declared_contracts(
..Default::default()
};
let version_state_proxy = safe_versioned_state.pin_version(0);
let compiled_contract_calss =
FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();
let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]);
let compiled_contract_class =
FeatureContract::TestContract(CairoVersion::Cairo1).get_compiled_contract_class();

let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_class)]);
version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class);
assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads));
}
Expand All @@ -431,9 +437,12 @@ fn test_apply_writes(
assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1);

// Transaction 0 contract class.
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();
let contract_class_0 =
FeatureContract::TestContract(CairoVersion::Cairo1).get_compiled_contract_class();
assert!(transactional_states[0].class_hash_to_class.borrow().is_empty());
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
transactional_states[0]
.set_compiled_contract_class(class_hash, contract_class_0.clone())
.unwrap();
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);

safe_versioned_state.pin_version(0).apply_writes(
Expand All @@ -442,7 +451,7 @@ fn test_apply_writes(
&HashMap::default(),
);
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
assert!(transactional_states[1].get_compiled_class(class_hash).unwrap() == contract_class_0);
assert!(transactional_states[1].get_compiled_contract_class(class_hash).unwrap() == contract_class_0);
}

#[rstest]
Expand Down Expand Up @@ -508,9 +517,9 @@ fn test_delete_writes(
}
// Modify the `class_hash_to_class` member of the CachedState.
tx_state
.set_contract_class(
.set_compiled_contract_class(
feature_contract.get_class_hash(),
feature_contract.get_runnable_class(),
feature_contract.get_compiled_contract_class(),
)
.unwrap();
safe_versioned_state.pin_version(i).apply_writes(
Expand Down Expand Up @@ -569,8 +578,10 @@ fn test_delete_writes_completeness(
)]),
declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]),
};
let class_hash_to_class_writes =
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]);
let class_hash_to_class_writes = HashMap::from([(
feature_contract.get_class_hash(),
feature_contract.get_compiled_contract_class(),
)]);

let tx_index = 0;
let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index);
Expand Down Expand Up @@ -633,12 +644,15 @@ fn test_versioned_proxy_state_flow(
transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap();

// Clients contract class values.
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class();
let contract_class_2 =
FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_runnable_class();

transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap();
transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap();
let contract_class_0 =
FeatureContract::TestContract(CairoVersion::Cairo0).get_compiled_contract_class();
let contract_class_2 = FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1)
.get_compiled_contract_class();

transactional_states[0].set_compiled_contract_class(class_hash, contract_class_0).unwrap();
transactional_states[2]
.set_compiled_contract_class(class_hash, contract_class_2.clone())
.unwrap();

// Apply the changes.
for (i, transactional_state) in transactional_states.iter_mut().enumerate() {
Expand All @@ -658,5 +672,5 @@ fn test_versioned_proxy_state_flow(
.commit_chunk_and_recover_block_state(4, HashMap::new());

assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3);
assert!(modified_block_state.get_compiled_class(class_hash).unwrap() == contract_class_2);
assert!(modified_block_state.get_compiled_contract_class(class_hash).unwrap() == contract_class_2);
}
4 changes: 3 additions & 1 deletion crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use itertools::Itertools;
use semver::Version;
use serde::de::Error as DeserializationError;
use serde::{Deserialize, Deserializer, Serialize};
use starknet_api::contract_class::{ContractClass, EntryPointType};
use starknet_api::contract_class::{ContractClass, EntryPointType, SierraVersion};
use starknet_api::core::EntryPointSelector;
use starknet_api::deprecated_contract_class::{
ContractClass as DeprecatedContractClass,
Expand Down Expand Up @@ -68,6 +68,8 @@ pub enum RunnableCompiledClass {
V1Native(NativeCompiledClassV1),
}

pub type VersionedRunnableCompiledClass = (RunnableCompiledClass, SierraVersion);

impl TryFrom<ContractClass> for RunnableCompiledClass {
type Error = ProgramError;

Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/deprecated_syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ pub fn replace_class(
syscall_handler: &mut DeprecatedSyscallHintProcessor<'_>,
) -> DeprecatedSyscallResult<ReplaceClassResponse> {
// Ensure the class is declared (by reading it).
syscall_handler.state.get_compiled_class(request.class_hash)?;
syscall_handler.state.get_compiled_contract_class(request.class_hash)?;
syscall_handler.state.set_class_hash_at(syscall_handler.storage_address, request.class_hash)?;

Ok(ReplaceClassResponse {})
Expand Down
19 changes: 13 additions & 6 deletions crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl CallEntryPoint {
}
// Add class hash to the call, that will appear in the output (call info).
self.class_hash = Some(class_hash);
let compiled_class = state.get_compiled_class(class_hash)?;
let (runnable_compiled_class, _) = state.get_compiled_contract_class(class_hash)?;

context.revert_infos.0.push(EntryPointRevertInfo::new(
self.storage_address,
Expand All @@ -157,7 +157,13 @@ impl CallEntryPoint {
));

// This is the last operation of this function.
execute_entry_point_call_wrapper(self, compiled_class, state, context, remaining_gas)
execute_entry_point_call_wrapper(
self,
runnable_compiled_class,
state,
context,
remaining_gas,
)
}

/// Similar to `execute`, but returns an error if the outer call is reverted.
Expand Down Expand Up @@ -402,10 +408,11 @@ pub fn execute_constructor_entry_point(
remaining_gas: &mut u64,
) -> ConstructorEntryPointExecutionResult<CallInfo> {
// Ensure the class is declared (by reading it).
let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?;
let Some(constructor_selector) = compiled_class.constructor_selector() else {
let (contract_class, _) =
state.get_compiled_contract_class(ctor_context.class_hash).map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?;
let Some(constructor_selector) = contract_class.constructor_selector() else {
// Contract has no constructor.
return handle_empty_constructor(&ctor_context, calldata, *remaining_gas)
.map_err(|error| ConstructorEntryPointExecutionError::new(error, &ctor_context, None));
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/native/syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {
.map_err(|err| self.handle_error(remaining_gas, err))?;
Ok(())
}

fn library_call(
&mut self,
class_hash: Felt,
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/syscalls/syscall_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl<'state> SyscallHandlerBase<'state> {

pub fn replace_class(&mut self, class_hash: ClassHash) -> SyscallResult<()> {
// Ensure the class is declared (by reading it), and of type V1.
let compiled_class = self.state.get_compiled_class(class_hash)?;
let (compiled_class, _) = self.state.get_compiled_contract_class(class_hash)?;

if !is_cairo1(&compiled_class) {
return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash });
Expand Down
22 changes: 14 additions & 8 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use crate::context::TransactionContext;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::state::errors::StateError;
use crate::state::state_api::{State, StateReader, StateResult, UpdatableState};
use crate::transaction::objects::TransactionExecutionInfo;
Expand All @@ -18,7 +18,7 @@ use crate::utils::{strict_subtract_mappings, subtract_mappings};
#[path = "cached_state_test.rs"]
mod test;

pub type ContractClassMapping = HashMap<ClassHash, RunnableCompiledClass>;
pub type ContractClassMapping = HashMap<ClassHash, VersionedRunnableCompiledClass>;

/// Caches read and write requests.
///
Expand Down Expand Up @@ -173,14 +173,17 @@ impl<S: StateReader> StateReader for CachedState<S> {
Ok(*class_hash)
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
let mut cache = self.cache.borrow_mut();
let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut();

if let std::collections::hash_map::Entry::Vacant(vacant_entry) =
class_hash_to_class.entry(class_hash)
{
match self.state.get_compiled_class(class_hash) {
match self.state.get_compiled_contract_class(class_hash) {
Err(StateError::UndeclaredClassHash(class_hash)) => {
cache.set_declared_contract_initial_value(class_hash, false);
cache.set_compiled_class_hash_initial_value(
Expand Down Expand Up @@ -253,10 +256,10 @@ impl<S: StateReader> State for CachedState<S> {
Ok(())
}

fn set_contract_class(
fn set_compiled_contract_class(
&mut self,
class_hash: ClassHash,
contract_class: RunnableCompiledClass,
contract_class: VersionedRunnableCompiledClass,
) -> StateResult<()> {
self.class_hash_to_class.get_mut().insert(class_hash, contract_class);
let mut cache = self.cache.borrow_mut();
Expand Down Expand Up @@ -524,8 +527,11 @@ impl<S: StateReader + ?Sized> StateReader for MutRefState<'_, S> {
self.0.get_class_hash_at(contract_address)
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
self.0.get_compiled_class(class_hash)
fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
self.0.get_compiled_contract_class(class_hash)
}

fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
Expand Down
Loading

0 comments on commit 390a375

Please sign in to comment.