From a9e3e0b07510ebb82f419408eea81bf6e25db89b Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Sun, 15 Dec 2024 17:11:41 +0200 Subject: [PATCH] chore(papyrus_execution): get versioned contract class --- .../papyrus_execution/src/execution_utils.rs | 7 ++- crates/papyrus_execution/src/state_reader.rs | 45 ++++++++++++------- .../src/state_reader_test.rs | 14 +++--- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 9a6f34ffb7..7ea11e1909 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -19,6 +19,7 @@ use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageResult, StorageTxn}; // Expose the tool for creating entry point selectors from function names. pub use starknet_api::abi::abi_utils::selector_from_name; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; @@ -43,6 +44,8 @@ pub(crate) enum ExecutionUtilsError { StorageError(#[from] StorageError), #[error("Casm table not fully synced")] CasmTableNotSynced, + #[error(transparent)] + SierraValidationError(starknet_api::StarknetApiError), } /// Returns the execution config from the config file. @@ -63,9 +66,11 @@ pub(crate) fn get_contract_class( match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? { Some(block_number) if state_number.is_before(block_number) => return Ok(None), Some(_block_number) => { - let Some(casm) = txn.get_casm(class_hash)? else { + let (Some(casm), Some(sierra)) = txn.get_casm_and_sierra(class_hash)? else { return Err(ExecutionUtilsError::CasmTableNotSynced); }; + let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program) + .map_err(ExecutionUtilsError::SierraValidationError); return Ok(Some(RunnableCompiledClass::V1( CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, ))); diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index b67aaa170e..ee905ea05f 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -15,6 +15,7 @@ use papyrus_common::pending_classes::{ApiContractClass, PendingClassesTrait}; use papyrus_common::state::DeclaredClassHashEntry; use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageReader}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; @@ -76,24 +77,31 @@ impl BlockifierStateReader for ExecutionStateReader { } fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - if let Some(pending_casm) = self - .maybe_pending_data - .as_ref() - .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) + if let Some(pending_classes) = + self.maybe_pending_data.as_ref().map(|pending_data| &pending_data.classes) { - return Ok(RunnableCompiledClass::V1( - CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, - )); - } - if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self - .maybe_pending_data - .as_ref() - .and_then(|pending_data| pending_data.classes.get_class(class_hash)) - { - return Ok(RunnableCompiledClass::V0( - CompiledClassV0::try_from(pending_deprecated_class) - .map_err(StateError::ProgramError)?, - )); + if let Some(api_contract_class) = pending_classes.get_class(class_hash) { + match api_contract_class { + ApiContractClass::ContractClass(sierra) => { + if let Some(pending_casm) = pending_classes.get_compiled_class(class_hash) { + let runnable_compiled_class = RunnableCompiledClass::V1( + CompiledClassV1::try_from(pending_casm) + .map_err(StateError::ProgramError)?, + ); + let _sierra_version = + SierraVersion::extract_from_program(&sierra.sierra_program)?; + // TODO: Use the Sierra version when the return type is updated. + return Ok(runnable_compiled_class); + } + } + ApiContractClass::DeprecatedContractClass(pending_deprecated_class) => { + return Ok(RunnableCompiledClass::V0( + CompiledClassV0::try_from(pending_deprecated_class) + .map_err(StateError::ProgramError)?, + )); + } + } + } } match get_contract_class( &self.storage_reader.begin_ro_txn().map_err(storage_err_to_state_err)?, @@ -108,6 +116,9 @@ impl BlockifierStateReader for ExecutionStateReader { } Err(ExecutionUtilsError::ProgramError(err)) => Err(StateError::ProgramError(err)), Err(ExecutionUtilsError::StorageError(err)) => Err(storage_err_to_state_err(err)), + Err(ExecutionUtilsError::SierraValidationError(err)) => { + Err(StateError::StarknetApiError(err)) + } } } diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558..93a7124965 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -62,10 +62,11 @@ fn read_state() { let storage_value2 = felt!(999_u128); let class_hash2 = ClassHash(1234u128.into()); let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO); - let mut casm1 = get_test_casm(); - casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; - let blockifier_casm1 = - RunnableCompiledClass::V1(CompiledClassV1::try_from(casm1.clone()).unwrap()); + let mut casm2 = get_test_casm(); + casm2.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; + let class2 = SierraContractClass::default(); + let blockifier_casm2 = + RunnableCompiledClass::V1(CompiledClassV1::try_from(casm2.clone()).unwrap()); let nonce1 = Nonce(felt!(2_u128)); let class_hash3 = ClassHash(567_u128.into()); let class_hash4 = ClassHash(89_u128.into()); @@ -204,7 +205,8 @@ fn read_state() { // Test pending state diff let mut pending_classes = PendingClasses::default(); - pending_classes.add_compiled_class(class_hash2, casm1); + pending_classes.add_compiled_class(class_hash2, casm2); + pending_classes.add_class(class_hash2, ApiContractClass::ContractClass(class2)); pending_classes.add_class(class_hash3, ApiContractClass::ContractClass(class0)); pending_classes .add_class(class_hash4, ApiContractClass::DeprecatedContractClass(class1.clone())); @@ -234,7 +236,7 @@ fn read_state() { assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm2); // Test that an error is returned if we only got the class without the casm. state_reader2.get_compiled_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned.