diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 9a6f34ffb7..ac1d986419 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -6,6 +6,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::cached_state::{CachedState, CommitmentStateDiff, MutRefState}; use blockifier::state::state_api::StateReader; @@ -19,6 +20,7 @@ use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageResult, StorageTxn}; // Expose the tool for creating entry point selectors from function names. pub use starknet_api::abi::abi_utils::selector_from_name; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; @@ -43,6 +45,8 @@ pub(crate) enum ExecutionUtilsError { StorageError(#[from] StorageError), #[error("Casm table not fully synced")] CasmTableNotSynced, + #[error(transparent)] + SierraValidationError(#[from] starknet_api::StarknetApiError), } /// Returns the execution config from the config file. @@ -55,20 +59,25 @@ impl TryFrom for ExecutionConfig { } } -pub(crate) fn get_contract_class( +pub(crate) fn get_versioned_contract_class( txn: &StorageTxn<'_, RO>, class_hash: &ClassHash, state_number: StateNumber, -) -> Result, ExecutionUtilsError> { +) -> Result, ExecutionUtilsError> { match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? { Some(block_number) if state_number.is_before(block_number) => return Ok(None), Some(_block_number) => { - let Some(casm) = txn.get_casm(class_hash)? else { + let (Some(casm), Some(sierra)) = txn.get_casm_and_sierra(class_hash)? else { return Err(ExecutionUtilsError::CasmTableNotSynced); }; - return Ok(Some(RunnableCompiledClass::V1( + let runnable_compiled_class = RunnableCompiledClass::V1( CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, - ))); + ); + let sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; + return Ok(Some(VersionedRunnableCompiledClass::Cairo1(( + runnable_compiled_class, + sierra_version, + )))); } None => {} }; @@ -78,9 +87,9 @@ pub(crate) fn get_contract_class( else { return Ok(None); }; - Ok(Some(RunnableCompiledClass::V0( + Ok(Some(VersionedRunnableCompiledClass::Cairo0(RunnableCompiledClass::V0( CompiledClassV0::try_from(deprecated_class).map_err(ExecutionUtilsError::ProgramError)?, - ))) + )))) } /// Given an ExecutableTransactionInput, returns a function that will convert the corresponding diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index b67aaa170e..da3377d71c 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -15,12 +15,13 @@ use papyrus_common::pending_classes::{ApiContractClass, PendingClassesTrait}; use papyrus_common::state::DeclaredClassHashEntry; use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageReader}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; use crate::execution_utils; -use crate::execution_utils::{get_contract_class, ExecutionUtilsError}; +use crate::execution_utils::{get_versioned_contract_class, ExecutionUtilsError}; use crate::objects::PendingData; /// A view into the state at a specific state number. @@ -81,10 +82,20 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) { - return Ok(RunnableCompiledClass::V1( - CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, - )); + if let Some(ApiContractClass::ContractClass(sierra)) = self + .maybe_pending_data + .as_ref() + .and_then(|pending_data| pending_data.classes.get_class(class_hash)) + { + let runnable_compiled_class = RunnableCompiledClass::V1( + CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, + ); + let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; + // TODO(AVIV): Use the sierra version when the return type is updated. + return Ok(runnable_compiled_class); + } } + if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self .maybe_pending_data .as_ref() @@ -95,12 +106,12 @@ impl BlockifierStateReader for ExecutionStateReader { .map_err(StateError::ProgramError)?, )); } - match get_contract_class( + match get_versioned_contract_class( &self.storage_reader.begin_ro_txn().map_err(storage_err_to_state_err)?, &class_hash, self.state_number, ) { - Ok(Some(contract_class)) => Ok(contract_class), + Ok(Some(versioned_contract_class)) => Ok(versioned_contract_class.into()), Ok(None) => Err(StateError::UndeclaredClassHash(class_hash)), Err(ExecutionUtilsError::CasmTableNotSynced) => { self.missing_compiled_class.set(Some(class_hash)); @@ -108,6 +119,9 @@ impl BlockifierStateReader for ExecutionStateReader { } Err(ExecutionUtilsError::ProgramError(err)) => Err(StateError::ProgramError(err)), Err(ExecutionUtilsError::StorageError(err)) => Err(storage_err_to_state_err(err)), + Err(ExecutionUtilsError::SierraValidationError(err)) => { + Err(StateError::StarknetApiError(err)) + } } } diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558..93a7124965 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -62,10 +62,11 @@ fn read_state() { let storage_value2 = felt!(999_u128); let class_hash2 = ClassHash(1234u128.into()); let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO); - let mut casm1 = get_test_casm(); - casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; - let blockifier_casm1 = - RunnableCompiledClass::V1(CompiledClassV1::try_from(casm1.clone()).unwrap()); + let mut casm2 = get_test_casm(); + casm2.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; + let class2 = SierraContractClass::default(); + let blockifier_casm2 = + RunnableCompiledClass::V1(CompiledClassV1::try_from(casm2.clone()).unwrap()); let nonce1 = Nonce(felt!(2_u128)); let class_hash3 = ClassHash(567_u128.into()); let class_hash4 = ClassHash(89_u128.into()); @@ -204,7 +205,8 @@ fn read_state() { // Test pending state diff let mut pending_classes = PendingClasses::default(); - pending_classes.add_compiled_class(class_hash2, casm1); + pending_classes.add_compiled_class(class_hash2, casm2); + pending_classes.add_class(class_hash2, ApiContractClass::ContractClass(class2)); pending_classes.add_class(class_hash3, ApiContractClass::ContractClass(class0)); pending_classes .add_class(class_hash4, ApiContractClass::DeprecatedContractClass(class1.clone())); @@ -234,7 +236,7 @@ fn read_state() { assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm2); // Test that an error is returned if we only got the class without the casm. state_reader2.get_compiled_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned.