diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 4ad8bfb285..a9c7a7179c 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -158,12 +158,12 @@ impl TransactionExecutor { .visited_pcs .iter() .map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> { - let contract_class = self + let (contract_class,_) = self .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) - .get_compiled_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + .get_compiled_contract_class(*class_hash)?; + Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) }) .collect::>()?; diff --git a/crates/blockifier/src/bouncer.rs b/crates/blockifier/src/bouncer.rs index 7788b06be0..c1de901aef 100644 --- a/crates/blockifier/src/bouncer.rs +++ b/crates/blockifier/src/bouncer.rs @@ -551,7 +551,7 @@ pub fn get_casm_hash_calculation_resources( let mut casm_hash_computation_resources = ExecutionResources::default(); for class_hash in executed_class_hashes { - let class = state_reader.get_compiled_class(*class_hash)?; + let (class, _) = state_reader.get_compiled_contract_class(*class_hash)?; casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources(); } diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index 1d3c9a9270..ce396febd8 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; @@ -34,7 +34,7 @@ pub struct VersionedState { // the compiled contract classes mapping. Each key with value false, sohuld not apprear // in the compiled contract classes mapping. declared_contracts: VersionedStorage, - compiled_contract_classes: VersionedStorage, + compiled_contract_classes: VersionedStorage, } impl VersionedState { @@ -336,11 +336,14 @@ impl StateReader for VersionedStateProxy { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut state = self.state(); match state.compiled_contract_classes.read(self.tx_index, class_hash) { Some(value) => Ok(value), - None => match state.initial_state.get_compiled_class(class_hash) { + None => match state.initial_state.get_compiled_contract_class(class_hash) { Ok(initial_value) => { state.declared_contracts.set_initial_value(class_hash, true); state diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index 48a0a95065..7ad41171ef 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() { let class_hash = class_hash!(27_u8); let another_class_hash = class_hash!(28_u8); let compiled_class_hash = compiled_class_hash!(29_u8); - let contract_class = test_contract.get_runnable_class(); + let contract_class = (test_contract.get_runnable_class(), test_contract.get_sierra_version()); // Create the versioned state let cached_state = CachedState::from(DictStateReader { @@ -98,9 +98,12 @@ fn test_versioned_state_proxy() { versioned_state_proxys[5].get_compiled_class_hash(class_hash).unwrap(), compiled_class_hash ); - assert_eq!(versioned_state_proxys[7].get_compiled_class(class_hash).unwrap(), contract_class); + assert_eq!( + versioned_state_proxys[7].get_compiled_contract_class(class_hash).unwrap(), + contract_class + ); assert_matches!( - versioned_state_proxys[7].get_compiled_class(another_class_hash).unwrap_err(), + versioned_state_proxys[7].get_compiled_contract_class(another_class_hash).unwrap_err(), StateError::UndeclaredClassHash(class_hash) if another_class_hash == class_hash ); @@ -115,8 +118,10 @@ fn test_versioned_state_proxy() { let class_hash_v7 = class_hash!(28_u8); let class_hash_v10 = class_hash!(29_u8); let compiled_class_hash_v18 = compiled_class_hash!(30_u8); - let contract_class_v11 = - FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); + let contract_class_v11 = ( + FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(), + FeatureContract::TestContract(CairoVersion::Cairo1).get_sierra_version(), + ); versioned_state_proxys[3].state().apply_writes( 3, @@ -195,7 +200,7 @@ fn test_versioned_state_proxy() { compiled_class_hash_v18 ); assert_eq!( - versioned_state_proxys[15].get_compiled_class(class_hash).unwrap(), + versioned_state_proxys[15].get_compiled_contract_class(class_hash).unwrap(), contract_class_v11 ); } @@ -323,7 +328,7 @@ fn test_validate_reads( assert!(transactional_state.cache.borrow().initial_reads.declared_contracts.is_empty()); assert_matches!( - transactional_state.get_compiled_class(class_hash), + transactional_state.get_compiled_contract_class(class_hash), Err(StateError::UndeclaredClassHash(err_class_hash)) if err_class_hash == class_hash ); @@ -405,9 +410,10 @@ fn test_false_validate_reads_declared_contracts( ..Default::default() }; let version_state_proxy = safe_versioned_state.pin_version(0); - let compiled_contract_calss = - FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); - let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); + let compiled_contract_class = + FeatureContract::TestContract(CairoVersion::Cairo1).get_compiled_contract_class(); + + let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_class)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); } @@ -431,9 +437,12 @@ fn test_apply_writes( assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1); // Transaction 0 contract class. - let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); + let contract_class_0 = + FeatureContract::TestContract(CairoVersion::Cairo1).get_compiled_contract_class(); assert!(transactional_states[0].class_hash_to_class.borrow().is_empty()); - transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap(); + transactional_states[0] + .set_compiled_contract_class(class_hash, contract_class_0.clone()) + .unwrap(); assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1); safe_versioned_state.pin_version(0).apply_writes( @@ -442,7 +451,7 @@ fn test_apply_writes( &HashMap::default(), ); assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); - assert!(transactional_states[1].get_compiled_class(class_hash).unwrap() == contract_class_0); + assert!(transactional_states[1].get_compiled_contract_class(class_hash).unwrap() == contract_class_0); } #[rstest] @@ -508,9 +517,9 @@ fn test_delete_writes( } // Modify the `class_hash_to_class` member of the CachedState. tx_state - .set_contract_class( + .set_compiled_contract_class( feature_contract.get_class_hash(), - feature_contract.get_runnable_class(), + feature_contract.get_compiled_contract_class(), ) .unwrap(); safe_versioned_state.pin_version(i).apply_writes( @@ -569,8 +578,10 @@ fn test_delete_writes_completeness( )]), declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]), }; - let class_hash_to_class_writes = - HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]); + let class_hash_to_class_writes = HashMap::from([( + feature_contract.get_class_hash(), + feature_contract.get_compiled_contract_class(), + )]); let tx_index = 0; let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index); @@ -633,12 +644,15 @@ fn test_versioned_proxy_state_flow( transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap(); // Clients contract class values. - let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class(); - let contract_class_2 = - FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_runnable_class(); - - transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap(); - transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap(); + let contract_class_0 = + FeatureContract::TestContract(CairoVersion::Cairo0).get_compiled_contract_class(); + let contract_class_2 = FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1) + .get_compiled_contract_class(); + + transactional_states[0].set_compiled_contract_class(class_hash, contract_class_0).unwrap(); + transactional_states[2] + .set_compiled_contract_class(class_hash, contract_class_2.clone()) + .unwrap(); // Apply the changes. for (i, transactional_state) in transactional_states.iter_mut().enumerate() { @@ -658,5 +672,5 @@ fn test_versioned_proxy_state_flow( .commit_chunk_and_recover_block_state(4, HashMap::new()); assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3); - assert!(modified_block_state.get_compiled_class(class_hash).unwrap() == contract_class_2); + assert!(modified_block_state.get_compiled_contract_class(class_hash).unwrap() == contract_class_2); } diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index cf3d7289f9..061684e59f 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -21,7 +21,7 @@ use itertools::Itertools; use semver::Version; use serde::de::Error as DeserializationError; use serde::{Deserialize, Deserializer, Serialize}; -use starknet_api::contract_class::{ContractClass, EntryPointType}; +use starknet_api::contract_class::{ContractClass, EntryPointType, SierraVersion}; use starknet_api::core::EntryPointSelector; use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, @@ -68,6 +68,8 @@ pub enum RunnableCompiledClass { V1Native(NativeCompiledClassV1), } +pub type VersionedRunnableCompiledClass = (RunnableCompiledClass, SierraVersion); + impl TryFrom for RunnableCompiledClass { type Error = ProgramError; diff --git a/crates/blockifier/src/execution/deprecated_syscalls/mod.rs b/crates/blockifier/src/execution/deprecated_syscalls/mod.rs index 8a15ad9224..ff0741948f 100644 --- a/crates/blockifier/src/execution/deprecated_syscalls/mod.rs +++ b/crates/blockifier/src/execution/deprecated_syscalls/mod.rs @@ -653,7 +653,7 @@ pub fn replace_class( syscall_handler: &mut DeprecatedSyscallHintProcessor<'_>, ) -> DeprecatedSyscallResult { // Ensure the class is declared (by reading it). - syscall_handler.state.get_compiled_class(request.class_hash)?; + syscall_handler.state.get_compiled_contract_class(request.class_hash)?; syscall_handler.state.set_class_hash_at(syscall_handler.storage_address, request.class_hash)?; Ok(ReplaceClassResponse {}) diff --git a/crates/blockifier/src/execution/entry_point.rs b/crates/blockifier/src/execution/entry_point.rs index c89edd7265..12b2068363 100644 --- a/crates/blockifier/src/execution/entry_point.rs +++ b/crates/blockifier/src/execution/entry_point.rs @@ -147,7 +147,7 @@ impl CallEntryPoint { } // Add class hash to the call, that will appear in the output (call info). self.class_hash = Some(class_hash); - let compiled_class = state.get_compiled_class(class_hash)?; + let (runnable_compiled_class, _) = state.get_compiled_contract_class(class_hash)?; context.revert_infos.0.push(EntryPointRevertInfo::new( self.storage_address, @@ -157,7 +157,13 @@ impl CallEntryPoint { )); // This is the last operation of this function. - execute_entry_point_call_wrapper(self, compiled_class, state, context, remaining_gas) + execute_entry_point_call_wrapper( + self, + runnable_compiled_class, + state, + context, + remaining_gas, + ) } /// Similar to `execute`, but returns an error if the outer call is reverted. @@ -402,10 +408,11 @@ pub fn execute_constructor_entry_point( remaining_gas: &mut u64, ) -> ConstructorEntryPointExecutionResult { // Ensure the class is declared (by reading it). - let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| { - ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) - })?; - let Some(constructor_selector) = compiled_class.constructor_selector() else { + let (contract_class, _) = + state.get_compiled_contract_class(ctor_context.class_hash).map_err(|error| { + ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) + })?; + let Some(constructor_selector) = contract_class.constructor_selector() else { // Contract has no constructor. return handle_empty_constructor(&ctor_context, calldata, *remaining_gas) .map_err(|error| ConstructorEntryPointExecutionError::new(error, &ctor_context, None)); diff --git a/crates/blockifier/src/execution/native/syscall_handler.rs b/crates/blockifier/src/execution/native/syscall_handler.rs index 2b115f6cb2..54f3c675b7 100644 --- a/crates/blockifier/src/execution/native/syscall_handler.rs +++ b/crates/blockifier/src/execution/native/syscall_handler.rs @@ -296,7 +296,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { .map_err(|err| self.handle_error(remaining_gas, err))?; Ok(()) } - + fn library_call( &mut self, class_hash: Felt, diff --git a/crates/blockifier/src/execution/syscalls/syscall_base.rs b/crates/blockifier/src/execution/syscalls/syscall_base.rs index f664d2ab47..c2ef6ab369 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_base.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_base.rs @@ -164,7 +164,7 @@ impl<'state> SyscallHandlerBase<'state> { pub fn replace_class(&mut self, class_hash: ClassHash) -> SyscallResult<()> { // Ensure the class is declared (by reading it), and of type V1. - let compiled_class = self.state.get_compiled_class(class_hash)?; + let (compiled_class, _) = self.state.get_compiled_contract_class(class_hash)?; if !is_cairo1(&compiled_class) { return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash }); diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 356b7b8f89..852fb0daf1 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -8,7 +8,7 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use crate::context::TransactionContext; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, StateResult, UpdatableState}; use crate::transaction::objects::TransactionExecutionInfo; @@ -18,7 +18,7 @@ use crate::utils::{strict_subtract_mappings, subtract_mappings}; #[path = "cached_state_test.rs"] mod test; -pub type ContractClassMapping = HashMap; +pub type ContractClassMapping = HashMap; /// Caches read and write requests. /// @@ -173,14 +173,17 @@ impl StateReader for CachedState { Ok(*class_hash) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut cache = self.cache.borrow_mut(); let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); if let std::collections::hash_map::Entry::Vacant(vacant_entry) = class_hash_to_class.entry(class_hash) { - match self.state.get_compiled_class(class_hash) { + match self.state.get_compiled_contract_class(class_hash) { Err(StateError::UndeclaredClassHash(class_hash)) => { cache.set_declared_contract_initial_value(class_hash, false); cache.set_compiled_class_hash_initial_value( @@ -253,10 +256,10 @@ impl State for CachedState { Ok(()) } - fn set_contract_class( + fn set_compiled_contract_class( &mut self, class_hash: ClassHash, - contract_class: RunnableCompiledClass, + contract_class: VersionedRunnableCompiledClass, ) -> StateResult<()> { self.class_hash_to_class.get_mut().insert(class_hash, contract_class); let mut cache = self.cache.borrow_mut(); @@ -524,8 +527,11 @@ impl StateReader for MutRefState<'_, S> { self.0.get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - self.0.get_compiled_class(class_hash) + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + self.0.get_compiled_contract_class(class_hash) } fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 20f191a99c..ae8369a19c 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -108,14 +108,14 @@ fn declare_contract() { let mut state = CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_runnable_class(); + let contract_class = test_contract.get_compiled_contract_class(); assert_eq!(state.cache.borrow().writes.declared_contracts.get(&class_hash), None); assert_eq!(state.cache.borrow().initial_reads.declared_contracts.get(&class_hash), None); // Reading an undeclared contract class. assert_matches!( - state.get_compiled_class(class_hash).unwrap_err(), + state.get_compiled_contract_class(class_hash).unwrap_err(), StateError::UndeclaredClassHash(undeclared_class_hash) if undeclared_class_hash == class_hash ); @@ -124,7 +124,7 @@ fn declare_contract() { false ); - state.set_contract_class(class_hash, contract_class).unwrap(); + state.set_compiled_contract_class(class_hash, contract_class).unwrap(); assert_eq!(*state.cache.borrow().writes.declared_contracts.get(&class_hash).unwrap(), true); } @@ -161,19 +161,19 @@ fn get_and_increment_nonce() { } #[test] -fn get_contract_class() { +fn get_compiled_contract_class() { // Positive flow. let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 0)]); assert_eq!( - state.get_compiled_class(test_contract.get_class_hash()).unwrap(), - test_contract.get_runnable_class() + state.get_compiled_contract_class(test_contract.get_class_hash()).unwrap(), + test_contract.get_compiled_contract_class() ); // Negative flow. let missing_class_hash = class_hash!("0x101"); assert_matches!( - state.get_compiled_class(missing_class_hash).unwrap_err(), + state.get_compiled_contract_class(missing_class_hash).unwrap_err(), StateError::UndeclaredClassHash(undeclared) if undeclared == missing_class_hash ); } @@ -215,7 +215,7 @@ fn cached_state_state_diff_conversion() { let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let test_class_hash = test_contract.get_class_hash(); let class_hash_to_class = - HashMap::from([(test_class_hash, test_contract.get_runnable_class())]); + HashMap::from([(test_class_hash, test_contract.get_compiled_contract_class())]); let nonce_initial_values = HashMap::new(); @@ -261,7 +261,7 @@ fn cached_state_state_diff_conversion() { let class_hash = FeatureContract::Empty(CairoVersion::Cairo0).get_class_hash(); let compiled_class_hash = compiled_class_hash!(1_u8); // Cache the initial read value, as in regular declare flow. - state.get_compiled_class(class_hash).unwrap_err(); + state.get_compiled_contract_class(class_hash).unwrap_err(); state.set_compiled_class_hash(class_hash, compiled_class_hash).unwrap(); // Write the initial value using key contract_address1. @@ -309,7 +309,7 @@ fn create_state_changes_for_test( state.increment_nonce(contract_address2).unwrap(); // Fill the initial read value, as in regular flow. - state.get_compiled_class(class_hash).unwrap_err(); + state.get_compiled_contract_class(class_hash).unwrap_err(); state.set_compiled_class_hash(class_hash, compiled_class_hash).unwrap(); // Assign the existing value to the storage (this shouldn't be considered a change). @@ -491,7 +491,7 @@ fn test_contract_cache_is_used() { // cache. let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_runnable_class(); + let contract_class = test_contract.get_compiled_contract_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); let state = CachedState::new(reader); @@ -500,7 +500,7 @@ fn test_contract_cache_is_used() { assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); // Check state uses the cache. - assert_eq!(state.get_compiled_class(class_hash).unwrap(), contract_class); + assert_eq!(state.get_compiled_contract_class(class_hash).unwrap(), contract_class); assert_eq!(state.class_hash_to_class.borrow().get(&class_hash).unwrap(), &contract_class); } diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 96e24c1e1a..6fff52d3da 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -6,7 +6,7 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use super::cached_state::{ContractClassMapping, StateMaps}; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::errors::StateError; pub type StateResult = Result; @@ -39,8 +39,11 @@ pub trait StateReader { /// Default: 0 (uninitialized class hash) for an uninitialized contract address. fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; - /// Returns the compiled class of the given class hash. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult; + /// Returns the compiled contract class of the given class hash. + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult; /// Returns the compiled class hash of the given class hash. fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult; @@ -90,10 +93,10 @@ pub trait State: StateReader { ) -> StateResult<()>; /// Sets the given contract class under the given class hash. - fn set_contract_class( + fn set_compiled_contract_class( &mut self, class_hash: ClassHash, - contract_class: RunnableCompiledClass, + contract_class: VersionedRunnableCompiledClass, ) -> StateResult<()>; /// Sets the given compiled class hash under the given class hash. diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index c3d4ea4ce2..92584afe16 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -1,7 +1,7 @@ use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use starknet_api::abi::abi_utils::selector_from_name; use starknet_api::abi::constants::CONSTRUCTOR_ENTRY_POINT_NAME; -use starknet_api::contract_class::{ContractClass, EntryPointType}; +use starknet_api::contract_class::{ContractClass, EntryPointType, SierraVersion}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, EntryPointSelector}; use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, @@ -12,7 +12,7 @@ use starknet_types_core::felt::Felt; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::{VersionedRunnableCompiledClass, RunnableCompiledClass}; use crate::execution::entry_point::CallEntryPoint; #[cfg(feature = "cairo_native")] use crate::execution::native::contract_class::NativeCompiledClassV1; @@ -185,6 +185,10 @@ impl FeatureContract { } } + pub fn get_sierra_version(&self) -> SierraVersion { + todo!("Aviv 24/11/2024: Implement this function") + } + pub fn get_runnable_class(&self) -> RunnableCompiledClass { #[cfg(feature = "cairo_native")] if CairoVersion::Native == self.cairo_version() { @@ -196,6 +200,10 @@ impl FeatureContract { self.get_class().try_into().unwrap() } + pub fn get_compiled_contract_class(&self) -> VersionedRunnableCompiledClass { + (self.get_runnable_class(), self.get_sierra_version()) + } + pub fn get_raw_class(&self) -> String { get_raw_contract_class(&self.get_compiled_path()) } diff --git a/crates/blockifier/src/test_utils/dict_state_reader.rs b/crates/blockifier/src/test_utils/dict_state_reader.rs index 16e3b3ed83..9c9e678822 100644 --- a/crates/blockifier/src/test_utils/dict_state_reader.rs +++ b/crates/blockifier/src/test_utils/dict_state_reader.rs @@ -4,7 +4,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::cached_state::StorageEntry; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -15,7 +15,7 @@ pub struct DictStateReader { pub storage_view: HashMap, pub address_to_nonce: HashMap, pub address_to_class_hash: HashMap, - pub class_hash_to_class: HashMap, + pub class_hash_to_class: HashMap, pub class_hash_to_compiled_class_hash: HashMap, } @@ -35,7 +35,10 @@ impl StateReader for DictStateReader { Ok(nonce) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let contract_class = self.class_hash_to_class.get(&class_hash).cloned(); match contract_class { Some(contract_class) => Ok(contract_class), diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 0beeb1a36d..e85c81d216 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -49,7 +49,7 @@ pub fn test_state_inner( // Declare and deploy account and ERC20 contracts. let erc20 = FeatureContract::ERC20(erc20_contract_version); - class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_runnable_class()); + class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_compiled_contract_class()); address_to_class_hash .insert(chain_info.fee_token_address(&FeeType::Eth), erc20.get_class_hash()); address_to_class_hash @@ -58,7 +58,7 @@ pub fn test_state_inner( // Set up the rest of the requested contracts. for (contract, n_instances) in contract_instances.iter() { let class_hash = contract.get_class_hash(); - class_hash_to_class.insert(class_hash, contract.get_runnable_class()); + class_hash_to_class.insert(class_hash, contract.get_compiled_contract_class()); for instance in 0..*n_instances { let instance_address = contract.get_instance_address(instance); address_to_class_hash.insert(instance_address, class_hash); diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 3d87f93c85..a2f660a660 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -844,8 +844,8 @@ impl ValidatableTransaction for AccountTransaction { })?; // Validate return data. - let compiled_class = state.get_compiled_class(class_hash)?; - if is_cairo1(&compiled_class) { + let (contract_class,_) = state.get_compiled_contract_class(class_hash)?; + if is_cairo1(&contract_class) { // The account contract class is a Cairo 1.0 contract; the `validate` entry point should // return `VALID`. let expected_retdata = retdata![*constants::VALIDATE_RETDATA]; diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index d070ca9fa3..a420f920f3 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -765,6 +765,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { create_test_init_data(chain_info, CairoVersion::Cairo0); let class_hash = class_hash!(0xdeadeadeaf72_u128); let contract_class = FeatureContract::Empty(CairoVersion::Cairo1).get_class(); + let sierra_version = FeatureContract::TestContract(CairoVersion::Cairo1).get_sierra_version(); let next_nonce = nonce_manager.next(account_address); // Cannot fail executing a declare tx unless it's V2 or above, and already declared. @@ -774,7 +775,12 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { sender_address: account_address, ..Default::default() }; - state.set_contract_class(class_hash, contract_class.clone().try_into().unwrap()).unwrap(); + state + .set_compiled_contract_class( + class_hash, + (contract_class.clone().try_into().unwrap(), sierra_version), + ) + .unwrap(); state.set_compiled_class_hash(class_hash, declare_tx_v2.compiled_class_hash).unwrap(); let class_info = calculate_class_info_for_testing(contract_class); let executable_declare = ApiExecutableDeclareTransaction { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index 106423c1c3..ec374fba1d 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -208,7 +208,10 @@ impl Executable for DeclareTransaction { // We allow redeclaration of the class for backward compatibility. // In the past, we allowed redeclaration of Cairo 0 contracts since there was // no class commitment (so no need to check if the class is already declared). - state.set_contract_class(class_hash, self.contract_class().try_into()?)?; + state.set_compiled_contract_class( + class_hash, + (self.contract_class().try_into()?, self.class_info.sierra_version.clone()), + )?; } } starknet_api::transaction::DeclareTransaction::V2(DeclareTransactionV2 { @@ -418,10 +421,13 @@ fn try_declare( class_hash: ClassHash, compiled_class_hash: Option, ) -> TransactionExecutionResult<()> { - match state.get_compiled_class(class_hash) { + match state.get_compiled_contract_class(class_hash) { Err(StateError::UndeclaredClassHash(_)) => { // Class is undeclared; declare it. - state.set_contract_class(class_hash, tx.contract_class().try_into()?)?; + state.set_compiled_contract_class( + class_hash, + (tx.contract_class().try_into()?, tx.class_info.sierra_version.clone()), + )?; if let Some(compiled_class_hash) = compiled_class_hash { state.set_compiled_class_hash(class_hash, compiled_class_hash)?; } diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index f6fb7c0388..d81b333649 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -1531,7 +1531,7 @@ fn test_declare_tx( // Check state before transaction application. assert_matches!( - state.get_compiled_class(class_hash).unwrap_err(), + state.get_compiled_contract_class(class_hash).unwrap_err(), StateError::UndeclaredClassHash(undeclared_class_hash) if undeclared_class_hash == class_hash ); @@ -1633,7 +1633,7 @@ fn test_declare_tx( ); // Verify class declaration. - let contract_class_from_state = state.get_compiled_class(class_hash).unwrap(); + let (contract_class_from_state,_) = state.get_compiled_contract_class(class_hash).unwrap(); assert_eq!(contract_class_from_state, class_info.contract_class().try_into().unwrap()); // Checks that redeclaring the same contract fails. diff --git a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs index d3f9b3ce89..eb7d4f5dc9 100644 --- a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs @@ -5,7 +5,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::{CommitmentStateDiff, StateMaps}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -13,6 +16,7 @@ use blockifier::transaction::transaction_execution::Transaction as BlockifierTra use blockifier::versioned_constants::VersionedConstants; use serde::{Deserialize, Serialize}; use starknet_api::block::{BlockHash, BlockHashAndNumber, BlockInfo, BlockNumber, StarknetVersion}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ChainId, ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_api::transaction::{Transaction, TransactionHash}; @@ -154,13 +158,22 @@ impl StateReader for OfflineStateReader { )?) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { match self.get_contract_class(&class_hash)? { StarknetContractClass::Sierra(sierra) => { - Ok(sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap()) + let sierra_version = + SierraVersion::extract_from_program(&sierra.sierra_program).unwrap(); + let runnable_compiled_class = + sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap(); + Ok((runnable_compiled_class, sierra_version)) } StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let runnable_compiled_class: RunnableCompiledClass = + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(); + Ok((runnable_compiled_class, SierraVersion::zero())) } } } diff --git a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs index a17432379a..7a9bc78c80 100644 --- a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs @@ -7,7 +7,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::CommitmentStateDiff; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -16,6 +19,7 @@ use blockifier::versioned_constants::VersionedConstants; use serde::Serialize; use serde_json::{json, to_value}; use starknet_api::block::{BlockHash, BlockHashAndNumber, BlockInfo, BlockNumber, StarknetVersion}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ChainId, ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_api::transaction::{Transaction, TransactionHash}; @@ -125,16 +129,25 @@ impl StateReader for TestStateReader { /// Returns the contract class of the given class hash. /// Compile the contract class if it is Sierra. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let contract_class = retry_request!(self.retry_config, || self.get_contract_class(&class_hash))?; match contract_class { StarknetContractClass::Sierra(sierra) => { - Ok(sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap()) + let sierra_version = + SierraVersion::extract_from_program(&sierra.sierra_program).unwrap(); + let runnable_compiled_class = + sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap(); + Ok((runnable_compiled_class, sierra_version)) } StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let runnable_compiled_class: RunnableCompiledClass = + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(); + Ok((runnable_compiled_class, SierraVersion::zero())) } } } diff --git a/crates/blockifier_reexecution/src/state_reader/utils.rs b/crates/blockifier_reexecution/src/state_reader/utils.rs index 3742e1e033..54282a53fb 100644 --- a/crates/blockifier_reexecution/src/state_reader/utils.rs +++ b/crates/blockifier_reexecution/src/state_reader/utils.rs @@ -279,7 +279,10 @@ pub fn write_block_reexecution_data_to_file( let block_state = reexecute_and_verify_correctness(consecutive_state_readers).unwrap(); let serializable_data_prev_block = SerializableDataPrevBlock { state_maps: block_state.get_initial_reads().unwrap().into(), - contract_class_mapping: block_state.state.get_contract_class_mapping_dumper().unwrap(), + contract_class_mapping: block_state + .state + .get_contract_class_mapping_dumper() + .unwrap(), }; // Write the reexecution data to a json file. diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index d515e3a2da..af89ed8a5d 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -8,7 +8,7 @@ use blockifier::blockifier::transaction_executor::{TransactionExecutor, Transact use blockifier::bouncer::BouncerConfig; use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; use blockifier::execution::call_info::CallInfo; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::fee::receipt::TransactionReceipt; use blockifier::state::global_cache::GlobalContractCache; use blockifier::transaction::objects::{ExecutionResourcesTraits, TransactionExecutionInfo}; @@ -138,7 +138,7 @@ pub struct PyBlockExecutor { /// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe. pub storage: Box, pub contract_class_manager_config: ContractClassManagerConfig, - pub global_contract_cache: GlobalContractCache, + pub global_contract_cache: GlobalContractCache, } #[pymethods] diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index d223df4657..c5488a907d 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -71,12 +71,12 @@ fn global_contract_cache_update() { assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 0); - let queried_contract_class = block_executor + let (queried_contract_class, _) = block_executor .tx_executor() .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) - .get_compiled_class(class_hash) + .get_compiled_contract_class(class_hash) .unwrap(); assert_eq!(queried_contract_class, contract_class); diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 2ba5ac4d48..6f85d6576e 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -5,6 +5,7 @@ use blockifier::state::cached_state::CachedState; use blockifier::test_utils::dict_state_reader::DictStateReader; use blockifier::test_utils::struct_impls::LoadContractFromFile; use starknet_api::class_hash; +use starknet_api::contract_class::SierraVersion; use starknet_api::deprecated_contract_class::ContractClass; pub const TOKEN_FOR_TESTING_CLASS_HASH: &str = "0x30"; @@ -16,7 +17,10 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = pub fn create_py_test_state() -> CachedState { let contract_class: CompiledClassV0 = ContractClass::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).try_into().unwrap(); - let class_hash_to_class = - HashMap::from([(class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), contract_class.into())]); + let mut class_hash_to_class = HashMap::new(); + class_hash_to_class.insert( + class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), + (contract_class.into(), SierraVersion::zero()), + ); CachedState::from(DictStateReader { class_hash_to_class, ..Default::default() }) } diff --git a/crates/native_blockifier/src/state_readers/py_state_reader.rs b/crates/native_blockifier/src/state_readers/py_state_reader.rs index 35cdde1d91..58788d943f 100644 --- a/crates/native_blockifier/src/state_readers/py_state_reader.rs +++ b/crates/native_blockifier/src/state_readers/py_state_reader.rs @@ -1,11 +1,12 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, - RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; use pyo3::{FromPyObject, PyAny, PyErr, PyObject, PyResult, Python}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; @@ -68,8 +69,11 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - Python::with_gil(|py| -> Result { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + Python::with_gil(|py| -> Result { let args = (PyFelt::from(class_hash),); let py_raw_compiled_class: PyRawCompiledClass = self .state_reader_proxy @@ -77,7 +81,7 @@ impl StateReader for PyStateReader { .call_method1("get_raw_compiled_class", args)? .extract()?; - Ok(RunnableCompiledClass::try_from(py_raw_compiled_class)?) + Ok(VersionedRunnableCompiledClass::try_from(py_raw_compiled_class)?) }) .map_err(|err| { if Python::with_gil(|py| err.is_instance_of::(py)) { @@ -104,18 +108,30 @@ impl StateReader for PyStateReader { #[derive(FromPyObject)] pub struct PyRawCompiledClass { pub raw_compiled_class: String, + pub raw_sierra_version: (u64, u64, u64), pub version: usize, } -impl TryFrom for RunnableCompiledClass { +impl TryFrom for VersionedRunnableCompiledClass { type Error = NativeBlockifierError; fn try_from(raw_compiled_class: PyRawCompiledClass) -> NativeBlockifierResult { + let sierra_version = SierraVersion::new( + raw_compiled_class.raw_sierra_version.0, + raw_compiled_class.raw_sierra_version.1, + raw_compiled_class.raw_sierra_version.2, + ); match raw_compiled_class.version { - 0 => Ok(CompiledClassV0::try_from_json_string(&raw_compiled_class.raw_compiled_class)? - .into()), - 1 => Ok(CompiledClassV1::try_from_json_string(&raw_compiled_class.raw_compiled_class)? - .into()), + 0 => Ok(( + CompiledClassV0::try_from_json_string(&raw_compiled_class.raw_compiled_class)? + .into(), + sierra_version, + )), + 1 => Ok(( + CompiledClassV1::try_from_json_string(&raw_compiled_class.raw_compiled_class)? + .into(), + sierra_version, + )), _ => Err(NativeBlockifierInputError::UnsupportedContractClassVersion { version: raw_compiled_class.version, })?, diff --git a/crates/native_blockifier/src/storage.rs b/crates/native_blockifier/src/storage.rs index 5e028f14c1..4f93c52896 100644 --- a/crates/native_blockifier/src/storage.rs +++ b/crates/native_blockifier/src/storage.rs @@ -12,6 +12,7 @@ use papyrus_storage::header::{HeaderStorageReader, HeaderStorageWriter}; use papyrus_storage::state::{StateStorageReader, StateStorageWriter}; use pyo3::prelude::*; use starknet_api::block::{BlockHash, BlockHeader, BlockHeaderWithoutHash, BlockNumber}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ChainId, ClassHash, CompiledClassHash, ContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::hash::StarkHash; @@ -174,7 +175,8 @@ impl Storage for PapyrusStorage { let mut declared_classes = IndexMap::::new(); - let mut undeclared_casm_contracts = Vec::<(ClassHash, CasmContractClass)>::new(); + let mut undeclared_casm_contracts = + Vec::<(ClassHash, (CasmContractClass, SierraVersion))>::new(); for (class_hash, (raw_sierra, (compiled_class_hash, raw_casm))) in declared_class_hash_to_class { @@ -188,18 +190,20 @@ impl Storage for PapyrusStorage { if class_undeclared { let sierra_contract_class: SierraContractClass = serde_json::from_str(&raw_sierra)?; + let sierra_version = + SierraVersion::extract_from_program(&sierra_contract_class.sierra_program)?; declared_classes.insert( class_hash, (CompiledClassHash(compiled_class_hash.0), sierra_contract_class), ); let casm_contract_class: CasmContractClass = serde_json::from_str(&raw_casm)?; - undeclared_casm_contracts.push((class_hash, casm_contract_class)); + undeclared_casm_contracts.push((class_hash, (casm_contract_class, sierra_version))); } } let mut append_txn = self.writer().begin_rw_txn()?; - for (class_hash, contract_class) in undeclared_casm_contracts { - append_txn = append_txn.append_casm(&class_hash, &contract_class)?; + for (class_hash, (casm, sierra_version)) in undeclared_casm_contracts { + append_txn = append_txn.append_versioned_casm(&class_hash, &(&casm, sierra_version))?; } // Construct state diff; manually add declared classes. diff --git a/crates/papyrus_common/src/pending_classes.rs b/crates/papyrus_common/src/pending_classes.rs index 79de3849e8..5ad2a047f1 100644 --- a/crates/papyrus_common/src/pending_classes.rs +++ b/crates/papyrus_common/src/pending_classes.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::ClassHash; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::state::SierraContractClass; @@ -15,9 +16,9 @@ pub trait PendingClassesTrait { // TODO(shahak) Return an Arc to avoid cloning the class. This requires to re-implement // From/TryFrom for various structs in a way that the input is passed by reference. - fn get_compiled_class(&self, class_hash: ClassHash) -> Option; + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> Option<(CasmContractClass,SierraVersion)>; - fn add_compiled_class(&mut self, class_hash: ClassHash, compiled_class: CasmContractClass); + fn add_compiled_contract_class(&mut self, class_hash: ClassHash, compiled_class: (CasmContractClass, SierraVersion)); fn clear(&mut self); } @@ -27,7 +28,7 @@ pub struct PendingClasses { // Putting the contracts inside Arc so we won't have to clone them when we clone the entire // PendingClasses struct. pub classes: HashMap>, - pub compiled_classes: HashMap>, + pub compiled_classes: HashMap>, } #[derive(Debug, Eq, PartialEq, Clone)] @@ -61,11 +62,11 @@ impl PendingClassesTrait for PendingClasses { self.classes.insert(class_hash, Arc::new(class)); } - fn get_compiled_class(&self, class_hash: ClassHash) -> Option { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> Option<(CasmContractClass,SierraVersion)> { self.compiled_classes.get(&class_hash).map(|compiled_class| (**compiled_class).clone()) } - fn add_compiled_class(&mut self, class_hash: ClassHash, compiled_class: CasmContractClass) { + fn add_compiled_contract_class(&mut self, class_hash: ClassHash, compiled_class: (CasmContractClass,SierraVersion)) { self.compiled_classes.insert(class_hash, Arc::new(compiled_class)); } diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 9a6f34ffb7..bcf34203d8 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -19,6 +19,7 @@ use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageResult, StorageTxn}; // Expose the tool for creating entry point selectors from function names. pub use starknet_api::abi::abi_utils::selector_from_name; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; @@ -59,15 +60,18 @@ pub(crate) fn get_contract_class( txn: &StorageTxn<'_, RO>, class_hash: &ClassHash, state_number: StateNumber, -) -> Result, ExecutionUtilsError> { +) -> Result, ExecutionUtilsError> { match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? { Some(block_number) if state_number.is_before(block_number) => return Ok(None), Some(_block_number) => { - let Some(casm) = txn.get_casm(class_hash)? else { + let Some((casm, sierra_version)) = txn.get_versioned_casm(class_hash)? else { return Err(ExecutionUtilsError::CasmTableNotSynced); }; - return Ok(Some(RunnableCompiledClass::V1( - CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, + return Ok(Some(( + RunnableCompiledClass::V1( + CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, + ), + sierra_version, ))); } None => {} @@ -78,8 +82,12 @@ pub(crate) fn get_contract_class( else { return Ok(None); }; - Ok(Some(RunnableCompiledClass::V0( - CompiledClassV0::try_from(deprecated_class).map_err(ExecutionUtilsError::ProgramError)?, + Ok(Some(( + RunnableCompiledClass::V0( + CompiledClassV0::try_from(deprecated_class) + .map_err(ExecutionUtilsError::ProgramError)?, + ), + SierraVersion::zero(), ))) } diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index b67aaa170e..9bb5a452ad 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -8,6 +8,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; @@ -15,6 +16,7 @@ use papyrus_common::pending_classes::{ApiContractClass, PendingClassesTrait}; use papyrus_common::state::DeclaredClassHashEntry; use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageReader}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; @@ -75,14 +77,20 @@ impl BlockifierStateReader for ExecutionStateReader { .unwrap_or_default()) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - if let Some(pending_casm) = self + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + if let Some((pending_casm, sierra_version)) = self .maybe_pending_data .as_ref() - .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) + .and_then(|pending_data| pending_data.classes.get_compiled_contract_class(class_hash)) { - return Ok(RunnableCompiledClass::V1( - CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, + return Ok(( + RunnableCompiledClass::V1( + CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, + ), + sierra_version, )); } if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self @@ -90,9 +98,12 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_class(class_hash)) { - return Ok(RunnableCompiledClass::V0( - CompiledClassV0::try_from(pending_deprecated_class) - .map_err(StateError::ProgramError)?, + return Ok(( + RunnableCompiledClass::V0( + CompiledClassV0::try_from(pending_deprecated_class) + .map_err(StateError::ProgramError)?, + ), + SierraVersion::zero(), )); } match get_contract_class( diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558..427cc8d48b 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -24,6 +24,7 @@ use papyrus_storage::header::HeaderStorageWriter; use papyrus_storage::state::StateStorageWriter; use papyrus_storage::test_utils::get_test_storage; use starknet_api::block::{BlockBody, BlockHash, BlockHeader, BlockHeaderWithoutHash, BlockNumber}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, Nonce}; use starknet_api::hash::StarkHash; use starknet_api::state::{SierraContractClass, StateNumber, ThinStateDiff}; @@ -48,6 +49,7 @@ fn read_state() { let storage_value1 = felt!(888_u128); // The class is not used in the execution, so it can be default. let class0 = SierraContractClass::default(); + let sierra_version_0 = SierraVersion::extract_from_program(&class0.sierra_program).unwrap(); let casm0 = get_test_casm(); let blockifier_casm0 = RunnableCompiledClass::V1(CompiledClassV1::try_from(casm0.clone()).unwrap()); @@ -55,6 +57,7 @@ fn read_state() { let class_hash1 = ClassHash(1u128.into()); let class1 = get_test_deprecated_contract_class(); + let sierra_version_1 = SierraVersion::zero(); let address1 = contract_address!(DEPRECATED_CONTRACT_ADDRESS); let nonce0 = Nonce(felt!(1_u128)); @@ -127,7 +130,7 @@ fn read_state() { &[(class_hash1, &class1)], ) .unwrap() - .append_casm(&class_hash0, &casm0) + .append_versioned_casm(&class_hash0, &(&casm0, sierra_version_0.clone())) .unwrap() .append_header( BlockNumber(2), @@ -163,9 +166,10 @@ fn read_state() { assert_eq!(nonce_after_block_0, Nonce::default()); let class_hash_after_block_0 = state_reader0.get_class_hash_at(address0).unwrap(); assert_eq!(class_hash_after_block_0, ClassHash::default()); - let compiled_contract_class_after_block_0 = state_reader0.get_compiled_class(class_hash0); + let compiled_contract_class_after_block_0_result = + state_reader0.get_compiled_contract_class(class_hash0); assert_matches!( - compiled_contract_class_after_block_0, Err(StateError::UndeclaredClassHash(class_hash)) + compiled_contract_class_after_block_0_result, Err(StateError::UndeclaredClassHash(class_hash)) if class_hash == class_hash0 ); assert_eq!(state_reader0.get_compiled_class_hash(class_hash0).unwrap(), compiled_class_hash0); @@ -183,13 +187,13 @@ fn read_state() { assert_eq!(nonce_after_block_1, nonce0); let class_hash_after_block_1 = state_reader1.get_class_hash_at(address0).unwrap(); assert_eq!(class_hash_after_block_1, class_hash0); - let compiled_contract_class_after_block_1 = - state_reader1.get_compiled_class(class_hash0).unwrap(); + let (compiled_contract_class_after_block_1, _) = + state_reader1.get_compiled_contract_class(class_hash0).unwrap(); assert_eq!(compiled_contract_class_after_block_1, blockifier_casm0); // Test that an error is returned if we try to get a missing casm, and the field // `missing_compiled_class` is set to the missing casm's hash. - state_reader1.get_compiled_class(class_hash5).unwrap_err(); + state_reader1.get_compiled_contract_class(class_hash5).unwrap_err(); assert_eq!(state_reader1.missing_compiled_class.get().unwrap(), class_hash5); let state_number2 = StateNumber::unchecked_right_after_block(BlockNumber(2)); @@ -204,7 +208,7 @@ fn read_state() { // Test pending state diff let mut pending_classes = PendingClasses::default(); - pending_classes.add_compiled_class(class_hash2, casm1); + pending_classes.add_compiled_contract_class(class_hash2, (casm1, sierra_version_1.clone())); pending_classes.add_class(class_hash3, ApiContractClass::ContractClass(class0)); pending_classes .add_class(class_hash4, ApiContractClass::DeprecatedContractClass(class1.clone())); @@ -233,14 +237,23 @@ fn read_state() { assert_eq!(state_reader2.get_compiled_class_hash(class_hash2).unwrap(), compiled_class_hash2); assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); - assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!( + state_reader2.get_compiled_contract_class(class_hash0).unwrap(), + (blockifier_casm0, sierra_version_0) + ); + assert_eq!( + state_reader2.get_compiled_contract_class(class_hash2).unwrap(), + (blockifier_casm1, sierra_version_1) + ); // Test that an error is returned if we only got the class without the casm. - state_reader2.get_compiled_class(class_hash3).unwrap_err(); + state_reader2.get_compiled_contract_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned. assert_eq!( - state_reader2.get_compiled_class(class_hash4).unwrap(), - RunnableCompiledClass::V0(CompiledClassV0::try_from(class1).unwrap()) + state_reader2.get_compiled_contract_class(class_hash4).unwrap(), + ( + RunnableCompiledClass::V0(CompiledClassV0::try_from(class1).unwrap()), + SierraVersion::zero() + ) ); // Test get_class_hash_at when the class is replaced. diff --git a/crates/papyrus_execution/src/test_utils.rs b/crates/papyrus_execution/src/test_utils.rs index 377cc957a1..0319efa470 100644 --- a/crates/papyrus_execution/src/test_utils.rs +++ b/crates/papyrus_execution/src/test_utils.rs @@ -96,6 +96,10 @@ pub fn get_test_account_class() -> DeprecatedContractClass { get_test_instance("account_class.json") } +pub fn get_dummy_sierra_versaion() -> SierraVersion { + SierraVersion::latest() +} + pub fn prepare_storage(mut storage_writer: StorageWriter) { let class_hash0 = class_hash!("0x2"); let class_hash1 = class_hash!("0x1"); @@ -172,7 +176,7 @@ pub fn prepare_storage(mut storage_writer: StorageWriter) { ], ) .unwrap() - .append_casm(&class_hash0, &get_test_casm()) + .append_versioned_casm(&class_hash0, &(&get_test_casm(), get_dummy_sierra_versaion())) .unwrap() .append_header( BlockNumber(1), diff --git a/crates/papyrus_rpc/src/v0_8/api/api_impl.rs b/crates/papyrus_rpc/src/v0_8/api/api_impl.rs index ad9094da1a..b7ab8f45a3 100644 --- a/crates/papyrus_rpc/src/v0_8/api/api_impl.rs +++ b/crates/papyrus_rpc/src/v0_8/api/api_impl.rs @@ -1462,8 +1462,8 @@ impl JsonRpcServer for JsonRpcServerImpl { if class_definition_block_number > block_number { return Err(ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND)); } - let casm = storage_txn - .get_casm(&class_hash) + let (casm,_) = storage_txn + .get_versioned_casm(&class_hash) .map_err(internal_server_error)? .ok_or_else(|| ErrorObjectOwned::from(CLASS_HASH_NOT_FOUND))?; return Ok(CompiledContractClass::V1(casm)); diff --git a/crates/papyrus_rpc/src/v0_8/api/mod.rs b/crates/papyrus_rpc/src/v0_8/api/mod.rs index 5e45ff1386..fa10664fd8 100644 --- a/crates/papyrus_rpc/src/v0_8/api/mod.rs +++ b/crates/papyrus_rpc/src/v0_8/api/mod.rs @@ -371,8 +371,8 @@ pub(crate) fn stored_txn_to_executable_txn( starknet_api::transaction::Transaction::Declare( starknet_api::transaction::DeclareTransaction::V2(value), ) => { - let casm = storage_txn - .get_casm(&value.class_hash) + let versioned_casm = storage_txn + .get_versioned_casm(&value.class_hash) .map_err(internal_server_error)? .ok_or_else(|| { internal_server_error(format!( @@ -384,7 +384,7 @@ pub(crate) fn stored_txn_to_executable_txn( get_class_lengths(storage_txn, state_number, value.class_hash)?; Ok(ExecutableTransactionInput::DeclareV2( value, - casm, + versioned_casm.0, sierra_program_length, abi_length, false, @@ -394,8 +394,8 @@ pub(crate) fn stored_txn_to_executable_txn( starknet_api::transaction::Transaction::Declare( starknet_api::transaction::DeclareTransaction::V3(value), ) => { - let casm = storage_txn - .get_casm(&value.class_hash) + let versioned_casm = storage_txn + .get_versioned_casm(&value.class_hash) .map_err(internal_server_error)? .ok_or_else(|| { internal_server_error(format!( @@ -407,7 +407,7 @@ pub(crate) fn stored_txn_to_executable_txn( get_class_lengths(storage_txn, state_number, value.class_hash)?; Ok(ExecutableTransactionInput::DeclareV3( value, - casm, + versioned_casm.0, sierra_program_length, abi_length, false, diff --git a/crates/papyrus_rpc/src/v0_8/api/test.rs b/crates/papyrus_rpc/src/v0_8/api/test.rs index fe2b144322..a09eb14231 100644 --- a/crates/papyrus_rpc/src/v0_8/api/test.rs +++ b/crates/papyrus_rpc/src/v0_8/api/test.rs @@ -52,6 +52,7 @@ use starknet_api::block::{ GasPricePerToken, StarknetVersion, }; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ ClassHash, CompiledClassHash, @@ -3705,7 +3706,7 @@ async fn get_compiled_class() { }, ) .unwrap() - .append_casm(&cairo1_class_hash, &cairo1_contract_class) + .append_versioned_casm(&cairo1_class_hash, &(&cairo1_contract_class, SierraVersion::default())) .unwrap() // Note: there is no need to write the cairo1 contract class here because the // declared_classes_table is not used in the rpc method. diff --git a/crates/papyrus_rpc/src/v0_8/execution_test.rs b/crates/papyrus_rpc/src/v0_8/execution_test.rs index 732525a62f..75fe74e7ef 100644 --- a/crates/papyrus_rpc/src/v0_8/execution_test.rs +++ b/crates/papyrus_rpc/src/v0_8/execution_test.rs @@ -48,7 +48,7 @@ use starknet_api::block::{ BlockTimestamp, GasPricePerToken, }; -use starknet_api::contract_class::EntryPointType; +use starknet_api::contract_class::{EntryPointType, SierraVersion}; use starknet_api::core::{ ClassHash, CompiledClassHash, @@ -1563,6 +1563,7 @@ async fn write_block_0_as_pending( let class2 = starknet_api::state::SierraContractClass::default(); let casm = serde_json::from_value::(read_json_file("casm.json")).unwrap(); + let sierra_version = SierraVersion::default(); let class_hash2 = class_hash!("0x2"); let compiled_class_hash = CompiledClassHash(StarkHash::default()); @@ -1578,7 +1579,7 @@ async fn write_block_0_as_pending( let mut pending_classes_ref = pending_classes.write().await; pending_classes_ref.add_class(class_hash2, ApiContractClass::ContractClass(class2)); - pending_classes_ref.add_compiled_class(class_hash2, casm); + pending_classes_ref.add_compiled_contract_class(class_hash2, (casm, sierra_version)); pending_classes_ref.add_class(class_hash1, ApiContractClass::DeprecatedContractClass(class1)); pending_classes_ref .add_class(*ACCOUNT_CLASS_HASH, ApiContractClass::DeprecatedContractClass(account_class)); @@ -1655,6 +1656,7 @@ fn prepare_storage_for_execution(mut storage_writer: StorageWriter) -> StorageWr let class2 = starknet_api::state::SierraContractClass::default(); let casm = serde_json::from_value::(read_json_file("casm.json")).unwrap(); + let sierra_version = SierraVersion::default(); let class_hash2 = class_hash!("0x2"); let compiled_class_hash = CompiledClassHash(StarkHash::default()); @@ -1735,7 +1737,7 @@ fn prepare_storage_for_execution(mut storage_writer: StorageWriter) -> StorageWr ], ) .unwrap() - .append_casm(&class_hash2, &casm) + .append_versioned_casm(&class_hash2, &(&casm,sierra_version)) .unwrap() .append_header( BlockNumber(1), diff --git a/crates/papyrus_state_reader/src/papyrus_state.rs b/crates/papyrus_state_reader/src/papyrus_state.rs index 2af2651d10..ba544910f5 100644 --- a/crates/papyrus_state_reader/src/papyrus_state.rs +++ b/crates/papyrus_state_reader/src/papyrus_state.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::global_cache::GlobalContractCache; @@ -11,6 +12,7 @@ use papyrus_storage::db::RO; use papyrus_storage::state::StateStorageReader; use papyrus_storage::StorageReader; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; @@ -23,14 +25,14 @@ type RawPapyrusReader<'env> = papyrus_storage::StorageTxn<'env, RO>; pub struct PapyrusReader { storage_reader: StorageReader, latest_block: BlockNumber, - global_class_hash_to_class: GlobalContractCache, + global_class_hash_to_class: GlobalContractCache, } impl PapyrusReader { pub fn new( storage_reader: StorageReader, latest_block: BlockNumber, - global_class_hash_to_class: GlobalContractCache, + global_class_hash_to_class: GlobalContractCache, ) -> Self { Self { storage_reader, latest_block, global_class_hash_to_class } } @@ -46,7 +48,7 @@ impl PapyrusReader { fn get_compiled_class_inner( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { let state_number = StateNumber(self.latest_block); let class_declaration_block_number = self .reader()? @@ -57,16 +59,19 @@ impl PapyrusReader { Some(block_number) if block_number <= state_number.0); if class_is_declared { - let casm_compiled_class = self + let (casm_compiled_class, sierra_version) = self .reader()? - .get_casm(&class_hash) + .get_versioned_casm(&class_hash) .map_err(|err| StateError::StateReadError(err.to_string()))? .expect( "Should be able to fetch a Casm class if its definition exists, database is \ inconsistent.", ); - return Ok(RunnableCompiledClass::V1(CompiledClassV1::try_from(casm_compiled_class)?)); + return Ok(( + (RunnableCompiledClass::V1(CompiledClassV1::try_from(casm_compiled_class)?)), + sierra_version, + )); } let v0_compiled_class = self @@ -76,9 +81,10 @@ impl PapyrusReader { .map_err(|err| StateError::StateReadError(err.to_string()))?; match v0_compiled_class { - Some(starknet_api_contract_class) => { - Ok(CompiledClassV0::try_from(starknet_api_contract_class)?.into()) - } + Some(starknet_api_contract_class) => Ok(( + CompiledClassV0::try_from(starknet_api_contract_class)?.into(), + SierraVersion::zero(), + )), None => Err(StateError::UndeclaredClassHash(class_hash)), } } @@ -124,7 +130,10 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { // Assumption: the global cache is cleared upon reverted blocks. let contract_class = self.global_class_hash_to_class.get(&class_hash); diff --git a/crates/papyrus_storage/src/compiled_class.rs b/crates/papyrus_storage/src/compiled_class.rs index d0e2b656aa..d737db95f6 100644 --- a/crates/papyrus_storage/src/compiled_class.rs +++ b/crates/papyrus_storage/src/compiled_class.rs @@ -49,6 +49,7 @@ mod casm_test; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use papyrus_proc_macros::latency_histogram; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::ClassHash; use crate::db::serialization::VersionZeroWrapper; @@ -60,7 +61,7 @@ use crate::{FileHandlers, MarkerKind, MarkersTable, OffsetKind, StorageResult, S /// Interface for reading data related to the compiled classes. pub trait CasmStorageReader { /// Returns the Cairo assembly of a class given its Sierra class hash. - fn get_casm(&self, class_hash: &ClassHash) -> StorageResult>; + fn get_versioned_casm(&self, class_hash: &ClassHash) -> StorageResult>; /// The block marker is the first block number that doesn't exist yet. /// /// Note: If the last blocks don't contain any declared classes, the marker will point at the @@ -75,11 +76,11 @@ where { /// Stores the Cairo assembly of a class, mapped to its class hash. // To enforce that no commit happen after a failure, we consume and return Self on success. - fn append_casm(self, class_hash: &ClassHash, casm: &CasmContractClass) -> StorageResult; + fn append_versioned_casm(self, class_hash: &ClassHash, versioned_casm: &(&CasmContractClass, SierraVersion)) -> StorageResult; } -impl CasmStorageReader for StorageTxn<'_, Mode> { - fn get_casm(&self, class_hash: &ClassHash) -> StorageResult> { +impl<'env, Mode: TransactionKind> CasmStorageReader for StorageTxn<'env, Mode> { + fn get_versioned_casm(&self, class_hash: &ClassHash) -> StorageResult> { let casm_table = self.open_table(&self.tables.casms)?; let casm_location = casm_table.get(&self.txn, class_hash)?; casm_location.map(|location| self.file_handlers.get_casm_unchecked(location)).transpose() @@ -93,13 +94,13 @@ impl CasmStorageReader for StorageTxn<'_, Mode> { impl CasmStorageWriter for StorageTxn<'_, RW> { #[latency_histogram("storage_append_casm_latency_seconds", false)] - fn append_casm(self, class_hash: &ClassHash, casm: &CasmContractClass) -> StorageResult { + fn append_versioned_casm(self, class_hash: &ClassHash, versioned_casm: &(&CasmContractClass,SierraVersion)) -> StorageResult { let casm_table = self.open_table(&self.tables.casms)?; let markers_table = self.open_table(&self.tables.markers)?; let state_diff_table = self.open_table(&self.tables.state_diffs)?; let file_offset_table = self.txn.open_table(&self.tables.file_offsets)?; - let location = self.file_handlers.append_casm(casm); + let location = self.file_handlers.append_versioned_casm(versioned_casm); casm_table.insert(&self.txn, class_hash, &location)?; file_offset_table.upsert(&self.txn, &OffsetKind::Casm, &location.next_offset())?; update_marker( diff --git a/crates/papyrus_storage/src/compiled_class_test.rs b/crates/papyrus_storage/src/compiled_class_test.rs index 204cada574..bb6ec06e1c 100644 --- a/crates/papyrus_storage/src/compiled_class_test.rs +++ b/crates/papyrus_storage/src/compiled_class_test.rs @@ -1,6 +1,7 @@ use assert_matches::assert_matches; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use pretty_assertions::assert_eq; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::ClassHash; use starknet_api::test_utils::read_json_file; @@ -10,7 +11,7 @@ use crate::test_utils::get_test_storage; use crate::StorageError; #[test] -fn append_casm() { +fn append_versioned_casm() { let casm_json = read_json_file("compiled_class.json"); let expected_casm: CasmContractClass = serde_json::from_value(casm_json).unwrap(); let ((reader, mut writer), _temp_dir) = get_test_storage(); @@ -18,13 +19,14 @@ fn append_casm() { writer .begin_rw_txn() .unwrap() - .append_casm(&ClassHash::default(), &expected_casm) + .append_versioned_casm(&ClassHash::default(), &(&expected_casm, SierraVersion::default())) .unwrap() .commit() .unwrap(); - let casm = reader.begin_ro_txn().unwrap().get_casm(&ClassHash::default()).unwrap().unwrap(); - assert_eq!(casm, expected_casm); + let versioned_casm = + reader.begin_ro_txn().unwrap().get_versioned_casm(&ClassHash::default()).unwrap().unwrap(); + assert_eq!(versioned_casm, (expected_casm, SierraVersion::default())); } #[test] @@ -34,8 +36,28 @@ fn casm_rewrite() { writer .begin_rw_txn() .unwrap() - .append_casm( + .append_versioned_casm( &ClassHash::default(), + &( + &CasmContractClass { + prime: Default::default(), + compiler_version: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }, + SierraVersion::default(), + ), + ) + .unwrap() + .commit() + .unwrap(); + + let Err(err) = writer.begin_rw_txn().unwrap().append_versioned_casm( + &ClassHash::default(), + &( &CasmContractClass { prime: Default::default(), compiler_version: Default::default(), @@ -45,22 +67,8 @@ fn casm_rewrite() { pythonic_hints: Default::default(), entry_points_by_type: Default::default(), }, - ) - .unwrap() - .commit() - .unwrap(); - - let Err(err) = writer.begin_rw_txn().unwrap().append_casm( - &ClassHash::default(), - &CasmContractClass { - prime: Default::default(), - compiler_version: Default::default(), - bytecode: Default::default(), - bytecode_segment_lengths: Default::default(), - hints: Default::default(), - pythonic_hints: Default::default(), - entry_points_by_type: Default::default(), - }, + SierraVersion::default(), + ), ) else { panic!("Unexpected Ok."); }; diff --git a/crates/papyrus_storage/src/lib.rs b/crates/papyrus_storage/src/lib.rs index df7bdcecda..7746d64edf 100644 --- a/crates/papyrus_storage/src/lib.rs +++ b/crates/papyrus_storage/src/lib.rs @@ -123,6 +123,7 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use papyrus_proc_macros::latency_histogram; use serde::{Deserialize, Serialize}; use starknet_api::block::{BlockHash, BlockNumber, BlockSignature, StarknetVersion}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::state::{SierraContractClass, StateNumber, StorageKey, ThinStateDiff}; @@ -667,7 +668,7 @@ pub(crate) type MarkersTable<'env> = struct FileHandlers { thin_state_diff: FileHandler, Mode>, contract_class: FileHandler, Mode>, - casm: FileHandler, Mode>, + versioned_casm: FileHandler, Mode>, deprecated_contract_class: FileHandler, Mode>, transaction_output: FileHandler, Mode>, transaction: FileHandler, Mode>, @@ -686,8 +687,9 @@ impl FileHandlers { } // Appends a CASM to the corresponding file and returns its location. - fn append_casm(&self, casm: &CasmContractClass) -> LocationInFile { - self.clone().casm.append(casm) + fn append_versioned_casm(&self, versioned_casm: &(&CasmContractClass,SierraVersion)) -> LocationInFile { + let value = (versioned_casm.0.clone(), versioned_casm.1.clone()); + self.clone().versioned_casm.append(&value) } // Appends a deprecated contract class to the corresponding file and returns its location. @@ -714,7 +716,7 @@ impl FileHandlers { debug!("Flushing the mmap files."); self.thin_state_diff.flush(); self.contract_class.flush(); - self.casm.flush(); + self.versioned_casm.flush(); self.deprecated_contract_class.flush(); self.transaction_output.flush(); self.transaction.flush(); @@ -727,7 +729,7 @@ impl FileHandlers { HashMap::from_iter([ ("thin_state_diff".to_string(), self.thin_state_diff.stats()), ("contract_class".to_string(), self.contract_class.stats()), - ("casm".to_string(), self.casm.stats()), + ("versioned_casm".to_string(), self.versioned_casm.stats()), ("deprecated_contract_class".to_string(), self.deprecated_contract_class.stats()), ("transaction_output".to_string(), self.transaction_output.stats()), ("transaction".to_string(), self.transaction.stats()), @@ -755,8 +757,8 @@ impl FileHandlers { } // Returns the CASM at the given location or an error in case it doesn't exist. - fn get_casm_unchecked(&self, location: LocationInFile) -> StorageResult { - self.casm.get(location)?.ok_or(StorageError::DBInconsistency { + fn get_casm_unchecked(&self, location: LocationInFile) -> StorageResult<(CasmContractClass, SierraVersion)> { + self.versioned_casm.get(location)?.ok_or(StorageError::DBInconsistency { msg: format!("CasmContractClass at location {:?} not found.", location), }) } @@ -817,9 +819,9 @@ fn open_storage_files( contract_class_offset, )?; - let casm_offset = table.get(&db_transaction, &OffsetKind::Casm)?.unwrap_or_default(); - let (casm_writer, casm_reader) = - open_file(mmap_file_config.clone(), db_config.path().join("casm.dat"), casm_offset)?; + let versioned_casm_offset = table.get(&db_transaction, &OffsetKind::Casm)?.unwrap_or_default(); + let (versioned_casm_writer, versioned_casm_reader) = + open_file(mmap_file_config.clone(), db_config.path().join("casm.dat"), versioned_casm_offset)?; let deprecated_contract_class_offset = table.get(&db_transaction, &OffsetKind::DeprecatedContractClass)?.unwrap_or_default(); @@ -846,7 +848,7 @@ fn open_storage_files( FileHandlers { thin_state_diff: thin_state_diff_writer, contract_class: contract_class_writer, - casm: casm_writer, + versioned_casm: versioned_casm_writer, deprecated_contract_class: deprecated_contract_class_writer, transaction_output: transaction_output_writer, transaction: transaction_writer, @@ -854,7 +856,7 @@ fn open_storage_files( FileHandlers { thin_state_diff: thin_state_diff_reader, contract_class: contract_class_reader, - casm: casm_reader, + versioned_casm: versioned_casm_reader, deprecated_contract_class: deprecated_contract_class_reader, transaction_output: transaction_output_reader, transaction: transaction_reader, diff --git a/crates/papyrus_storage/src/serialization/serializers.rs b/crates/papyrus_storage/src/serialization/serializers.rs index 8daf2510c2..468db139b8 100644 --- a/crates/papyrus_storage/src/serialization/serializers.rs +++ b/crates/papyrus_storage/src/serialization/serializers.rs @@ -28,7 +28,7 @@ use starknet_api::block::{ GasPricePerToken, StarknetVersion, }; -use starknet_api::contract_class::EntryPointType; +use starknet_api::contract_class::{EntryPointType, SierraVersion}; use starknet_api::core::{ ClassHash, CompiledClassHash, @@ -512,6 +512,7 @@ auto_storage_serde! { ((ContractAddress, StorageKey), BlockNumber); (usize, Vec); (usize, Vec); + (CasmContractClass, SierraVersion); } //////////////////////////////////////////////////////////////////////// @@ -745,6 +746,16 @@ impl StorageSerde for String { } } +impl StorageSerde for SierraVersion { + fn serialize_into(&self, res: &mut impl std::io::Write) -> Result<(), StorageSerdeError> { + serde_json::to_value(self)?.serialize_into(res) + } + + fn deserialize_from(bytes: &mut impl std::io::Read) -> Option { + serde_json::from_value(serde_json::Value::deserialize_from(bytes)?).ok() + } +} + impl StorageSerde for Option { fn serialize_into(&self, res: &mut impl std::io::Write) -> Result<(), StorageSerdeError> { match self { diff --git a/crates/papyrus_storage/src/state/mod.rs b/crates/papyrus_storage/src/state/mod.rs index 6b8704eb41..56493dcf07 100644 --- a/crates/papyrus_storage/src/state/mod.rs +++ b/crates/papyrus_storage/src/state/mod.rs @@ -60,6 +60,7 @@ use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use indexmap::IndexMap; use papyrus_proc_macros::latency_histogram; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::state::{SierraContractClass, StateNumber, StorageKey, ThinStateDiff}; @@ -136,7 +137,7 @@ type RevertedStateDiff = ( ThinStateDiff, IndexMap, IndexMap, - IndexMap, + IndexMap, ); /// Interface for writing data related to the state. @@ -769,7 +770,7 @@ fn delete_compiled_classes<'a, 'env>( class_hashes: impl Iterator, compiled_classes_table: &'env CompiledClassesTable<'env>, file_handlers: &FileHandlers, -) -> StorageResult> { +) -> StorageResult> { let mut deleted_data = IndexMap::new(); for class_hash in class_hashes { let Some(compiled_class_location) = compiled_classes_table.get(txn, class_hash)? diff --git a/crates/papyrus_storage/src/state/state_test.rs b/crates/papyrus_storage/src/state/state_test.rs index 3f82f82f0f..5c440882f4 100644 --- a/crates/papyrus_storage/src/state/state_test.rs +++ b/crates/papyrus_storage/src/state/state_test.rs @@ -4,6 +4,7 @@ use indexmap::{indexmap, IndexMap}; use papyrus_test_utils::get_test_state_diff; use pretty_assertions::assert_eq; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::hash::StarkHash; @@ -464,6 +465,7 @@ fn revert_state() { pythonic_hints: Default::default(), entry_points_by_type: Default::default(), }; + let sierra_version_2 = SierraVersion::default(); let updated_storage_key = storage_key!("0x1"); let new_data = Felt::from(1_u8); let updated_storage = IndexMap::from([(updated_storage_key, new_data)]); @@ -500,7 +502,7 @@ fn revert_state() { &[(class1, &DeprecatedContractClass::default())], ) .unwrap() - .append_casm(&class2, &compiled_class2) + .append_versioned_casm(&class2, &(&compiled_class2, sierra_version_2)) .unwrap() .commit() .unwrap(); @@ -531,15 +533,18 @@ fn revert_state() { let expected_deleted_classes = IndexMap::from([(class2, SierraContractClass::default())]); let expected_deleted_compiled_classes = IndexMap::from([( class2, - CasmContractClass { - prime: Default::default(), - compiler_version: Default::default(), - bytecode: Default::default(), - bytecode_segment_lengths: Default::default(), - hints: Default::default(), - pythonic_hints: Default::default(), - entry_points_by_type: Default::default(), - }, + ( + CasmContractClass { + prime: Default::default(), + compiler_version: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }, + SierraVersion::default(), + ), )]); assert_matches!( deleted_data, @@ -565,7 +570,7 @@ fn revert_state() { state_reader.get_storage_at(state_number, contract0, &updated_storage_key).unwrap(), Felt::ZERO ); - assert!(txn.get_casm(&class2).unwrap().is_none()); + assert!(txn.get_versioned_casm(&class2).unwrap().is_none()); } #[test] diff --git a/crates/papyrus_sync/src/lib.rs b/crates/papyrus_sync/src/lib.rs index b1e1f4165c..4e5bd7810c 100644 --- a/crates/papyrus_sync/src/lib.rs +++ b/crates/papyrus_sync/src/lib.rs @@ -35,6 +35,7 @@ use papyrus_storage::{StorageError, StorageReader, StorageWriter}; use serde::{Deserialize, Serialize}; use sources::base_layer::BaseLayerSourceError; use starknet_api::block::{Block, BlockHash, BlockHashAndNumber, BlockNumber, BlockSignature}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, SequencerPublicKey}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::state::{StateDiff, ThinStateDiff}; @@ -219,6 +220,7 @@ pub enum SyncEvent { class_hash: ClassHash, compiled_class_hash: CompiledClassHash, compiled_class: CasmContractClass, + sierra_version: SierraVersion, }, NewBaseLayerBlock { block_number: BlockNumber, @@ -390,7 +392,13 @@ impl< class_hash, compiled_class_hash, compiled_class, - } => self.store_compiled_class(class_hash, compiled_class_hash, compiled_class), + sierra_version, + } => self.store_compiled_class( + class_hash, + compiled_class_hash, + compiled_class, + sierra_version, + ), SyncEvent::NewBaseLayerBlock { block_number, block_hash } => { self.store_base_layer_block(block_number, block_hash) } @@ -500,10 +508,11 @@ impl< class_hash: ClassHash, compiled_class_hash: CompiledClassHash, compiled_class: CasmContractClass, + sierra_version: SierraVersion, ) -> StateSyncResult { let txn = self.writer.begin_rw_txn()?; // TODO: verifications - verify casm corresponds to a class on storage. - match txn.append_casm(&class_hash, &compiled_class) { + match txn.append_versioned_casm(&class_hash, &(&compiled_class, sierra_version)) { #[allow(clippy::as_conversions)] // FIXME: use int metrics so `as f64` may be removed. Ok(txn) => { txn.commit()?; @@ -847,12 +856,13 @@ fn stream_new_compiled_classes central_source.stream_compiled_classes(from, up_to).fuse(); pin_mut!(compiled_classes_stream); - while let Some(maybe_compiled_class) = compiled_classes_stream.next().await { - let (class_hash, compiled_class_hash, compiled_class) = maybe_compiled_class?; + while let Some(maybe_versioned_compiled_class) = compiled_classes_stream.next().await { + let (class_hash, compiled_class_hash, (compiled_class,sierra_version)) = maybe_versioned_compiled_class?; yield SyncEvent::CompiledClassAvailable { class_hash, compiled_class_hash, compiled_class, + sierra_version }; } } diff --git a/crates/papyrus_sync/src/pending_sync.rs b/crates/papyrus_sync/src/pending_sync.rs index 0a6afbce3d..e4c11808ae 100644 --- a/crates/papyrus_sync/src/pending_sync.rs +++ b/crates/papyrus_sync/src/pending_sync.rs @@ -199,7 +199,7 @@ async fn get_pending_compiled_class, pending_classes: Arc>, ) -> Result { - let compiled_class = central_source.get_compiled_class(class_hash).await?; - pending_classes.write().await.add_compiled_class(class_hash, compiled_class); + let versioned_compiled_class = central_source.get_compiled_class(class_hash).await?; + pending_classes.write().await.add_compiled_contract_class(class_hash, versioned_compiled_class); Ok(PendingSyncTaskResult::DownloadedClassOrCompiledClass) } diff --git a/crates/papyrus_sync/src/sources/central.rs b/crates/papyrus_sync/src/sources/central.rs index a1fa122bf8..4f832fb8d6 100644 --- a/crates/papyrus_sync/src/sources/central.rs +++ b/crates/papyrus_sync/src/sources/central.rs @@ -25,6 +25,7 @@ use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageReader}; use serde::{Deserialize, Serialize}; use starknet_api::block::{Block, BlockHash, BlockHashAndNumber, BlockNumber, BlockSignature}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, SequencerPublicKey}; use starknet_api::crypto::utils::Signature; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; @@ -133,7 +134,7 @@ pub struct GenericCentralSource { pub storage_reader: StorageReader, pub state_update_stream_config: StateUpdateStreamConfig, pub(crate) class_cache: Arc>>, - compiled_class_cache: Arc>>, + versioned_compiled_class_cache: Arc>>, } #[derive(thiserror::Error, Debug)] @@ -195,7 +196,7 @@ pub trait CentralSourceTrait { async fn get_compiled_class( &self, class_hash: ClassHash, - ) -> Result; + ) -> Result<(CasmContractClass, SierraVersion), CentralError>; async fn get_sequencer_pub_key(&self) -> Result; } @@ -205,7 +206,7 @@ pub(crate) type BlocksStream<'a> = type CentralStateUpdate = (BlockNumber, BlockHash, StateDiff, IndexMap); pub(crate) type StateUpdatesStream<'a> = BoxStream<'a, CentralResult>; -type CentralCompiledClass = (ClassHash, CompiledClassHash, CasmContractClass); +type CentralCompiledClass = (ClassHash, CompiledClassHash, (CasmContractClass, SierraVersion)); pub(crate) type CompiledClassesStream<'a> = BoxStream<'a, CentralResult>; #[async_trait] @@ -369,20 +370,20 @@ impl CentralSourceTrait async fn get_compiled_class( &self, class_hash: ClassHash, - ) -> Result { + ) -> Result<(CasmContractClass, SierraVersion), CentralError> { { - let mut compiled_class_cache = - self.compiled_class_cache.lock().expect("Failed to lock class cache."); - if let Some(class) = compiled_class_cache.get(&class_hash) { - return Ok(class.clone()); + let mut versioned_compiled_class_cache = + self.versioned_compiled_class_cache.lock().expect("Failed to lock class cache."); + if let Some(versioned_class) = versioned_compiled_class_cache.get(&class_hash) { + return Ok(versioned_class.clone()); } } match self.starknet_client.compiled_class_by_hash(class_hash).await { - Ok(Some(compiled_class)) => { - let mut compiled_class_cache = - self.compiled_class_cache.lock().expect("Failed to lock class cache."); - compiled_class_cache.put(class_hash, compiled_class.clone()); - Ok(compiled_class) + Ok(Some(versioned_compiled_class)) => { + let mut versioned_compiled_class_cache = + self.versioned_compiled_class_cache.lock().expect("Failed to lock class cache."); + versioned_compiled_class_cache.put(class_hash, versioned_compiled_class.clone()); + Ok(versioned_compiled_class) } Ok(None) => Err(CentralError::CompiledClassNotFound { class_hash }), Err(err) => Err(CentralError::ClientError(Arc::new(err))), @@ -464,7 +465,7 @@ impl CentralSource { NonZeroUsize::new(config.class_cache_size) .expect("class_cache_size should be a positive integer."), ))), - compiled_class_cache: Arc::from(Mutex::new(LruCache::new( + versioned_compiled_class_cache: Arc::from(Mutex::new(LruCache::new( NonZeroUsize::new(config.class_cache_size) .expect("class_cache_size should be a positive integer."), ))), diff --git a/crates/papyrus_sync/src/sources/central_sync_test.rs b/crates/papyrus_sync/src/sources/central_sync_test.rs index 581de1263e..c1ed6479e3 100644 --- a/crates/papyrus_sync/src/sources/central_sync_test.rs +++ b/crates/papyrus_sync/src/sources/central_sync_test.rs @@ -23,6 +23,7 @@ use starknet_api::block::{ BlockNumber, BlockSignature, }; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, SequencerPublicKey}; use starknet_api::crypto::utils::PublicKey; use starknet_api::felt; @@ -589,7 +590,7 @@ async fn sync_with_revert() { async fn get_compiled_class( &self, _class_hash: ClassHash, - ) -> Result { + ) -> Result<(CasmContractClass, SierraVersion), CentralError> { unimplemented!(); } diff --git a/crates/papyrus_sync/src/sources/central_test.rs b/crates/papyrus_sync/src/sources/central_test.rs index c8c593ee10..a79f666ff2 100644 --- a/crates/papyrus_sync/src/sources/central_test.rs +++ b/crates/papyrus_sync/src/sources/central_test.rs @@ -13,6 +13,7 @@ use papyrus_storage::test_utils::get_test_storage; use pretty_assertions::assert_eq; use reqwest::StatusCode; use starknet_api::block::{BlockHash, BlockNumber}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, GlobalRoot, Nonce, SequencerPublicKey}; use starknet_api::crypto::utils::PublicKey; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; @@ -62,7 +63,7 @@ async fn last_block_number() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let last_block_number = central_source.get_latest_block().await.unwrap().unwrap().number; @@ -101,7 +102,7 @@ async fn stream_block_headers() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER); @@ -181,7 +182,7 @@ async fn stream_block_headers_some_are_missing() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER); @@ -245,7 +246,7 @@ async fn stream_block_headers_error() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let mut expected_block_num = BlockNumber(START_BLOCK_NUMBER); @@ -384,7 +385,7 @@ async fn stream_state_updates() { state_update_stream_config: state_update_stream_config_for_test(), // TODO(shahak): Check that downloaded classes appear in the cache. class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let initial_block_num = BlockNumber(START_BLOCK_NUMBER); @@ -518,15 +519,18 @@ async fn stream_compiled_classes() { .with(predicate::eq(ClassHash(felt))) .times(1) .returning(move |_x| { - Ok(Some(CasmContractClass { - prime: Default::default(), - compiler_version: Default::default(), - bytecode: Default::default(), - bytecode_segment_lengths: Default::default(), - hints: Default::default(), - pythonic_hints: Default::default(), - entry_points_by_type: Default::default(), - })) + Ok(Some(( + CasmContractClass { + prime: Default::default(), + compiler_version: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }, + SierraVersion::default(), + ))) }); } @@ -536,29 +540,32 @@ async fn stream_compiled_classes() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; let stream = central_source.stream_compiled_classes(BlockNumber(0), BlockNumber(2)); pin_mut!(stream); - let expected_compiled_class = CasmContractClass { - prime: Default::default(), - compiler_version: Default::default(), - bytecode: Default::default(), - bytecode_segment_lengths: Default::default(), - hints: Default::default(), - pythonic_hints: Default::default(), - entry_points_by_type: Default::default(), - }; + let expected_compiled_class = ( + CasmContractClass { + prime: Default::default(), + compiler_version: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }, + SierraVersion::default(), + ); for felt in felts { - let (class_hash, compiled_class_hash, compiled_class) = + let (class_hash, compiled_class_hash, versioned_compiled_class) = stream.next().await.unwrap().unwrap(); let expected_class_hash = ClassHash(felt); let expected_compiled_class_hash = CompiledClassHash(felt); assert_eq!(class_hash, expected_class_hash); assert_eq!(compiled_class_hash, expected_compiled_class_hash); - assert_eq!(compiled_class, expected_compiled_class); + assert_eq!(versioned_compiled_class, expected_compiled_class); } } @@ -590,7 +597,7 @@ async fn get_class() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; assert_eq!( @@ -613,20 +620,23 @@ async fn get_compiled_class() { let mut mock = MockStarknetReader::new(); let class_hash = ClassHash(StarkHash::ONE); - let compiled_class = CasmContractClass { - prime: Default::default(), - compiler_version: Default::default(), - bytecode: Default::default(), - bytecode_segment_lengths: Default::default(), - hints: Default::default(), - pythonic_hints: Default::default(), - entry_points_by_type: Default::default(), - }; - let compiled_class_clone = compiled_class.clone(); + let versioned_compiled_class = ( + CasmContractClass { + prime: Default::default(), + compiler_version: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }, + SierraVersion::default(), + ); + let versioned_compiled_class_clone = versioned_compiled_class.clone(); mock.expect_compiled_class_by_hash() .with(predicate::eq(class_hash)) .times(1) - .return_once(move |_x| Ok(Some(compiled_class_clone))); + .return_once(move |_x| Ok(Some(versioned_compiled_class_clone))); let ((reader, _), _temp_dir) = get_test_storage(); let central_source = GenericCentralSource { @@ -635,14 +645,20 @@ async fn get_compiled_class() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; - assert_eq!(central_source.get_compiled_class(class_hash).await.unwrap(), compiled_class); + assert_eq!( + central_source.get_compiled_class(class_hash).await.unwrap(), + versioned_compiled_class + ); // Repeating the call to see that source doesn't call the client and gets the result from // cache. - assert_eq!(central_source.get_compiled_class(class_hash).await.unwrap(), compiled_class); + assert_eq!( + central_source.get_compiled_class(class_hash).await.unwrap(), + versioned_compiled_class + ); } #[tokio::test] @@ -659,7 +675,7 @@ async fn get_sequencer_pub_key() { storage_reader: reader, state_update_stream_config: state_update_stream_config_for_test(), class_cache: get_test_class_cache(), - compiled_class_cache: get_test_compiled_class_cache(), + versioned_compiled_class_cache: get_test_versioned_compiled_class_cache(), }; assert_eq!(central_source.get_sequencer_pub_key().await.unwrap(), sequencer_pub_key); @@ -677,6 +693,7 @@ fn get_test_class_cache() -> Arc>> { Arc::from(Mutex::new(LruCache::new(NonZeroUsize::new(2).unwrap()))) } -fn get_test_compiled_class_cache() -> Arc>> { +fn get_test_versioned_compiled_class_cache() +-> Arc>> { Arc::from(Mutex::new(LruCache::new(NonZeroUsize::new(2).unwrap()))) } diff --git a/crates/papyrus_sync/src/sync_test.rs b/crates/papyrus_sync/src/sync_test.rs index 7a2796d2f0..713b9f6bfa 100644 --- a/crates/papyrus_sync/src/sync_test.rs +++ b/crates/papyrus_sync/src/sync_test.rs @@ -13,6 +13,7 @@ use papyrus_storage::{StorageReader, StorageWriter}; use papyrus_test_utils::{get_rng, GetTestInstance}; use pretty_assertions::assert_eq; use starknet_api::block::{BlockHash, BlockHeader, BlockHeaderWithoutHash, BlockNumber}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, Nonce}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::hash::StarkHash; @@ -251,7 +252,7 @@ async fn test_pending_sync( old_pending_classes_data: Option, // Verifies that the classes will be requested in the given order. new_pending_classes: Vec<(ClassHash, ApiContractClass)>, - new_pending_compiled_classes: Vec<(ClassHash, CasmContractClass)>, + new_pending_compiled_classes: Vec<(ClassHash, (CasmContractClass, SierraVersion))>, expected_pending_classes: Option, ) { let mut mock_pending_source = MockPendingSourceTrait::new(); @@ -683,7 +684,10 @@ async fn pending_sync_classes_request_only_new_classes() { let mut expected_pending_classes = PendingClasses::default(); expected_pending_classes.add_class(first_class_hash, first_class.clone()); expected_pending_classes.add_class(second_class_hash, second_class.clone()); - expected_pending_classes.add_compiled_class(first_class_hash, compiled_class.clone()); + expected_pending_classes.add_compiled_contract_class( + first_class_hash, + (compiled_class.clone(), SierraVersion::default()), + ); let old_pending_data = PendingData { block: PendingBlockOrDeprecated::Deprecated(DeprecatedPendingBlock { @@ -698,7 +702,8 @@ async fn pending_sync_classes_request_only_new_classes() { let old_pending_classes_data = PendingClasses::default(); let new_pending_classes = vec![(first_class_hash, first_class.clone()), (second_class_hash, second_class.clone())]; - let new_pending_compiled_classes = vec![(first_class_hash, compiled_class.clone())]; + let new_pending_compiled_classes = + vec![(first_class_hash, (compiled_class.clone(), SierraVersion::default()))]; test_pending_sync( reader, old_pending_data, @@ -767,9 +772,9 @@ async fn pending_sync_classes_are_cleaned_on_first_pending_data_from_latest_bloc &mut rng, )), ); - old_pending_classes_data.add_compiled_class( + old_pending_classes_data.add_compiled_contract_class( ClassHash(StarkHash::TWO), - CasmContractClass::get_test_instance(&mut rng), + (CasmContractClass::get_test_instance(&mut rng), SierraVersion::default()), ); let new_pending_datas = vec![new_pending_data.clone(), new_block_pending_data]; diff --git a/crates/starknet_api/src/contract_class.rs b/crates/starknet_api/src/contract_class.rs index 86445cd26f..195bdf9215 100644 --- a/crates/starknet_api/src/contract_class.rs +++ b/crates/starknet_api/src/contract_class.rs @@ -49,7 +49,7 @@ impl ContractClass { } #[derive(Deref, Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] -pub struct SierraVersion(Version); +pub struct SierraVersion(pub Version); impl SierraVersion { pub fn new(major: u64, minor: u64, patch: u64) -> Self { diff --git a/crates/starknet_batcher/src/block_builder.rs b/crates/starknet_batcher/src/block_builder.rs index b900f35d9c..b88eabad51 100644 --- a/crates/starknet_batcher/src/block_builder.rs +++ b/crates/starknet_batcher/src/block_builder.rs @@ -34,6 +34,7 @@ use starknet_api::block::{ BlockTimestamp, NonzeroGasPrice, }; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::ContractAddress; use starknet_api::executable_transaction::Transaction; use starknet_api::transaction::TransactionHash; @@ -317,7 +318,7 @@ impl SerializeConfig for BlockBuilderConfig { pub struct BlockBuilderFactory { pub block_builder_config: BlockBuilderConfig, pub storage_reader: StorageReader, - pub global_class_hash_to_class: GlobalContractCache, + pub global_class_hash_to_class: GlobalContractCache<(RunnableCompiledClass, SierraVersion)>, } impl BlockBuilderFactory { diff --git a/crates/starknet_client/src/reader/mod.rs b/crates/starknet_client/src/reader/mod.rs index 7dc4b31ff1..445ba5a508 100644 --- a/crates/starknet_client/src/reader/mod.rs +++ b/crates/starknet_client/src/reader/mod.rs @@ -18,6 +18,7 @@ use mockall::automock; use papyrus_common::pending_classes::ApiContractClass; use serde::{Deserialize, Serialize}; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, SequencerPublicKey}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::transaction::TransactionHash; @@ -93,7 +94,7 @@ pub trait StarknetReader { async fn compiled_class_by_hash( &self, class_hash: ClassHash, - ) -> ReaderClientResult>; + ) -> ReaderClientResult>; /// Returns a [`starknet_client`][`StateUpdate`] corresponding to `block_number`. async fn state_update( &self, @@ -276,7 +277,7 @@ impl StarknetReader for StarknetFeederGatewayClient { async fn compiled_class_by_hash( &self, class_hash: ClassHash, - ) -> ReaderClientResult> { + ) -> ReaderClientResult> { debug!("Got compiled_class_by_hash {} from starknet server.", class_hash); // FIXME: Remove the following default CasmContractClass once integration environment gets // regenesissed. @@ -302,15 +303,18 @@ impl StarknetReader for StarknetFeederGatewayClient { .contains(&class_hash) { debug!("Using default compiled class for class hash {}.", class_hash); - return Ok(Some(CasmContractClass { - prime: Default::default(), - compiler_version: String::default(), - bytecode: vec![], - bytecode_segment_lengths: None, - hints: vec![], - pythonic_hints: None, - entry_points_by_type: CasmContractEntryPoints::default(), - })); + return Ok(Some(( + CasmContractClass { + prime: Default::default(), + compiler_version: String::default(), + bytecode: vec![], + bytecode_segment_lengths: None, + hints: vec![], + pythonic_hints: None, + entry_points_by_type: CasmContractEntryPoints::default(), + }, + SierraVersion::default(), + ))); } let mut url = self.urls.get_compiled_class_by_class_hash.clone(); diff --git a/crates/starknet_client/src/reader/starknet_feeder_gateway_client_test.rs b/crates/starknet_client/src/reader/starknet_feeder_gateway_client_test.rs index a2cf5cf262..4c9ea56f66 100644 --- a/crates/starknet_client/src/reader/starknet_feeder_gateway_client_test.rs +++ b/crates/starknet_client/src/reader/starknet_feeder_gateway_client_test.rs @@ -404,7 +404,7 @@ async fn compiled_class_by_hash() { .with_status(200) .with_body(&raw_casm_contract_class) .create(); - let casm_contract_class = + let (casm_contract_class, _) = starknet_client.compiled_class_by_hash(class_hash!("0x7")).await.unwrap().unwrap(); mock_casm_contract_class.assert(); let expected_casm_contract_class: CasmContractClass = diff --git a/crates/starknet_gateway/src/rpc_state_reader.rs b/crates/starknet_gateway/src/rpc_state_reader.rs index a01ccdc703..848981006f 100644 --- a/crates/starknet_gateway/src/rpc_state_reader.rs +++ b/crates/starknet_gateway/src/rpc_state_reader.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; @@ -10,6 +11,7 @@ use reqwest::blocking::Client as BlockingClient; use serde::Serialize; use serde_json::{json, Value}; use starknet_api::block::{BlockInfo, BlockNumber}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; @@ -138,20 +140,31 @@ impl BlockifierStateReader for RpcStateReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let get_compiled_class_params = GetCompiledClassParams { class_hash, block_id: self.block_id }; let result = self.send_rpc_request("starknet_getCompiledContractClass", get_compiled_class_params)?; - let contract_class: CompiledContractClass = + let versioned_contract_class: (CompiledContractClass, SierraVersion) = serde_json::from_value(result).map_err(serde_err_to_state_err)?; - match contract_class { - CompiledContractClass::V1(contract_class_v1) => Ok(RunnableCompiledClass::V1( - CompiledClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?, + match versioned_contract_class.0 { + CompiledContractClass::V1(contract_class_v1) => Ok(( + RunnableCompiledClass::V1( + CompiledClassV1::try_from(contract_class_v1) + .map_err(StateError::ProgramError)?, + ), + versioned_contract_class.1, )), - CompiledContractClass::V0(contract_class_v0) => Ok(RunnableCompiledClass::V0( - CompiledClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?, + CompiledContractClass::V0(contract_class_v0) => Ok(( + RunnableCompiledClass::V0( + CompiledClassV0::try_from(contract_class_v0) + .map_err(StateError::ProgramError)?, + ), + SierraVersion::zero(), )), } } diff --git a/crates/starknet_gateway/src/rpc_state_reader_test.rs b/crates/starknet_gateway/src/rpc_state_reader_test.rs index 34756045fb..749db039ab 100644 --- a/crates/starknet_gateway/src/rpc_state_reader_test.rs +++ b/crates/starknet_gateway/src/rpc_state_reader_test.rs @@ -180,10 +180,11 @@ async fn test_get_compiled_class() { ); let client = RpcStateReader::from_latest(&config); - let result = tokio::task::spawn_blocking(move || client.get_compiled_class(class_hash!("0x1"))) - .await - .unwrap() - .unwrap(); + let (result, _) = + tokio::task::spawn_blocking(move || client.get_compiled_contract_class(class_hash!("0x1"))) + .await + .unwrap() + .unwrap(); assert_eq!(result, RunnableCompiledClass::V1(expected_result.try_into().unwrap())); mock.assert_async().await; } diff --git a/crates/starknet_gateway/src/state_reader.rs b/crates/starknet_gateway/src/state_reader.rs index b026bbe6ef..857f6ee3c8 100644 --- a/crates/starknet_gateway/src/state_reader.rs +++ b/crates/starknet_gateway/src/state_reader.rs @@ -1,4 +1,4 @@ -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; #[cfg(test)] @@ -44,8 +44,11 @@ impl BlockifierStateReader for Box { self.as_ref().get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - self.as_ref().get_compiled_class(class_hash) + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + self.as_ref().get_compiled_contract_class(class_hash) } fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { diff --git a/crates/starknet_gateway/src/state_reader_test_utils.rs b/crates/starknet_gateway/src/state_reader_test_utils.rs index 80c8fc8b51..b4ad87dede 100644 --- a/crates/starknet_gateway/src/state_reader_test_utils.rs +++ b/crates/starknet_gateway/src/state_reader_test_utils.rs @@ -1,5 +1,5 @@ use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use blockifier::test_utils::contracts::FeatureContract; @@ -43,8 +43,8 @@ impl BlockifierStateReader for TestStateReader { self.blockifier_state_reader.get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - self.blockifier_state_reader.get_compiled_class(class_hash) + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + self.blockifier_state_reader.get_compiled_contract_class(class_hash) } fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { diff --git a/crates/starknet_integration_tests/src/state_reader.rs b/crates/starknet_integration_tests/src/state_reader.rs index ca88fb9cb3..cfb22800a0 100644 --- a/crates/starknet_integration_tests/src/state_reader.rs +++ b/crates/starknet_integration_tests/src/state_reader.rs @@ -35,6 +35,7 @@ use starknet_api::block::{ FeeType, GasPricePerToken, }; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce, SequencerContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use starknet_api::state::{StorageKey, ThinStateDiff}; @@ -198,7 +199,8 @@ fn write_state_to_papyrus_storage( let mut write_txn = storage_writer.begin_rw_txn().unwrap(); for (class_hash, casm) in cairo1_contract_classes { - write_txn = write_txn.append_casm(class_hash, casm).unwrap(); + write_txn = + write_txn.append_versioned_casm(class_hash, &(casm, SierraVersion::latest())).unwrap(); } write_txn .append_header(block_number, &block_header)