From f34812f52de0c84ef0f98170610217c37a95c013 Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Sun, 15 Dec 2024 17:11:41 +0200 Subject: [PATCH] chore(papyrus_execution): get versioned contract class --- .../papyrus_execution/src/execution_utils.rs | 7 ++++++- crates/papyrus_execution/src/state_reader.rs | 20 ++++++++++++++++--- .../src/state_reader_test.rs | 14 +++++++------ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 9a6f34ffb79..43f941bc1c9 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -6,6 +6,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::cached_state::{CachedState, CommitmentStateDiff, MutRefState}; use blockifier::state::state_api::StateReader; @@ -19,6 +20,7 @@ use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageResult, StorageTxn}; // Expose the tool for creating entry point selectors from function names. pub use starknet_api::abi::abi_utils::selector_from_name; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; @@ -43,6 +45,8 @@ pub(crate) enum ExecutionUtilsError { StorageError(#[from] StorageError), #[error("Casm table not fully synced")] CasmTableNotSynced, + #[error(transparent)] + SierraValidationError(#[from] starknet_api::StarknetApiError), } /// Returns the execution config from the config file. @@ -63,9 +67,10 @@ pub(crate) fn get_contract_class( match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? { Some(block_number) if state_number.is_before(block_number) => return Ok(None), Some(_block_number) => { - let Some(casm) = txn.get_casm(class_hash)? else { + let (Some(casm), Some(sierra)) = txn.get_casm_and_sierra(class_hash)? else { return Err(ExecutionUtilsError::CasmTableNotSynced); }; + let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; return Ok(Some(RunnableCompiledClass::V1( CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, ))); diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index b67aaa170e3..aa499a8385c 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -15,6 +15,7 @@ use papyrus_common::pending_classes::{ApiContractClass, PendingClassesTrait}; use papyrus_common::state::DeclaredClassHashEntry; use papyrus_storage::state::StateStorageReader; use papyrus_storage::{StorageError, StorageReader}; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; @@ -81,10 +82,20 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) { - return Ok(RunnableCompiledClass::V1( - CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, - )); + if let Some(ApiContractClass::ContractClass(sierra)) = self + .maybe_pending_data + .as_ref() + .and_then(|pending_data| pending_data.classes.get_class(class_hash)) + { + let runnable_compiled_class = RunnableCompiledClass::V1( + CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, + ); + let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; + // TODO(AVIV): Use the sierra version when the return type is updated. + return Ok(runnable_compiled_class); + } } + if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self .maybe_pending_data .as_ref() @@ -108,6 +119,9 @@ impl BlockifierStateReader for ExecutionStateReader { } Err(ExecutionUtilsError::ProgramError(err)) => Err(StateError::ProgramError(err)), Err(ExecutionUtilsError::StorageError(err)) => Err(storage_err_to_state_err(err)), + Err(ExecutionUtilsError::SierraValidationError(err)) => { + Err(StateError::StarknetApiError(err)) + } } } diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558d..93a71249655 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -62,10 +62,11 @@ fn read_state() { let storage_value2 = felt!(999_u128); let class_hash2 = ClassHash(1234u128.into()); let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO); - let mut casm1 = get_test_casm(); - casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; - let blockifier_casm1 = - RunnableCompiledClass::V1(CompiledClassV1::try_from(casm1.clone()).unwrap()); + let mut casm2 = get_test_casm(); + casm2.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; + let class2 = SierraContractClass::default(); + let blockifier_casm2 = + RunnableCompiledClass::V1(CompiledClassV1::try_from(casm2.clone()).unwrap()); let nonce1 = Nonce(felt!(2_u128)); let class_hash3 = ClassHash(567_u128.into()); let class_hash4 = ClassHash(89_u128.into()); @@ -204,7 +205,8 @@ fn read_state() { // Test pending state diff let mut pending_classes = PendingClasses::default(); - pending_classes.add_compiled_class(class_hash2, casm1); + pending_classes.add_compiled_class(class_hash2, casm2); + pending_classes.add_class(class_hash2, ApiContractClass::ContractClass(class2)); pending_classes.add_class(class_hash3, ApiContractClass::ContractClass(class0)); pending_classes .add_class(class_hash4, ApiContractClass::DeprecatedContractClass(class1.clone())); @@ -234,7 +236,7 @@ fn read_state() { assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm2); // Test that an error is returned if we only got the class without the casm. state_reader2.get_compiled_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned.