Skip to content

Commit

Permalink
chore(papyrus_execution): get versioned contract class
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivYossef-starkware committed Dec 19, 2024
1 parent 2393f51 commit a9e3e0b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
7 changes: 6 additions & 1 deletion crates/papyrus_execution/src/execution_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use papyrus_storage::state::StateStorageReader;
use papyrus_storage::{StorageError, StorageResult, StorageTxn};
// Expose the tool for creating entry point selectors from function names.
pub use starknet_api::abi::abi_utils::selector_from_name;
use starknet_api::contract_class::SierraVersion;
use starknet_api::core::{ClassHash, ContractAddress, Nonce};
use starknet_api::state::{StateNumber, StorageKey, ThinStateDiff};
use starknet_types_core::felt::Felt;
Expand All @@ -43,6 +44,8 @@ pub(crate) enum ExecutionUtilsError {
StorageError(#[from] StorageError),
#[error("Casm table not fully synced")]
CasmTableNotSynced,
#[error(transparent)]
SierraValidationError(starknet_api::StarknetApiError),
}

/// Returns the execution config from the config file.
Expand All @@ -63,9 +66,11 @@ pub(crate) fn get_contract_class(
match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? {
Some(block_number) if state_number.is_before(block_number) => return Ok(None),
Some(_block_number) => {
let Some(casm) = txn.get_casm(class_hash)? else {
let (Some(casm), Some(sierra)) = txn.get_casm_and_sierra(class_hash)? else {
return Err(ExecutionUtilsError::CasmTableNotSynced);
};
let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)
.map_err(ExecutionUtilsError::SierraValidationError);
return Ok(Some(RunnableCompiledClass::V1(
CompiledClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?,
)));
Expand Down
45 changes: 28 additions & 17 deletions crates/papyrus_execution/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use papyrus_common::pending_classes::{ApiContractClass, PendingClassesTrait};
use papyrus_common::state::DeclaredClassHashEntry;
use papyrus_storage::state::StateStorageReader;
use papyrus_storage::{StorageError, StorageReader};
use starknet_api::contract_class::SierraVersion;
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::{StateNumber, StorageKey};
use starknet_types_core::felt::Felt;
Expand Down Expand Up @@ -76,24 +77,31 @@ impl BlockifierStateReader for ExecutionStateReader {
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
if let Some(pending_casm) = self
.maybe_pending_data
.as_ref()
.and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash))
if let Some(pending_classes) =
self.maybe_pending_data.as_ref().map(|pending_data| &pending_data.classes)
{
return Ok(RunnableCompiledClass::V1(
CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?,
));
}
if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self
.maybe_pending_data
.as_ref()
.and_then(|pending_data| pending_data.classes.get_class(class_hash))
{
return Ok(RunnableCompiledClass::V0(
CompiledClassV0::try_from(pending_deprecated_class)
.map_err(StateError::ProgramError)?,
));
if let Some(api_contract_class) = pending_classes.get_class(class_hash) {
match api_contract_class {
ApiContractClass::ContractClass(sierra) => {
if let Some(pending_casm) = pending_classes.get_compiled_class(class_hash) {
let runnable_compiled_class = RunnableCompiledClass::V1(
CompiledClassV1::try_from(pending_casm)
.map_err(StateError::ProgramError)?,
);
let _sierra_version =
SierraVersion::extract_from_program(&sierra.sierra_program)?;
// TODO: Use the Sierra version when the return type is updated.
return Ok(runnable_compiled_class);
}
}
ApiContractClass::DeprecatedContractClass(pending_deprecated_class) => {
return Ok(RunnableCompiledClass::V0(
CompiledClassV0::try_from(pending_deprecated_class)
.map_err(StateError::ProgramError)?,
));
}
}
}
}
match get_contract_class(
&self.storage_reader.begin_ro_txn().map_err(storage_err_to_state_err)?,
Expand All @@ -108,6 +116,9 @@ impl BlockifierStateReader for ExecutionStateReader {
}
Err(ExecutionUtilsError::ProgramError(err)) => Err(StateError::ProgramError(err)),
Err(ExecutionUtilsError::StorageError(err)) => Err(storage_err_to_state_err(err)),
Err(ExecutionUtilsError::SierraValidationError(err)) => {
Err(StateError::StarknetApiError(err))
}
}
}

Expand Down
14 changes: 8 additions & 6 deletions crates/papyrus_execution/src/state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ fn read_state() {
let storage_value2 = felt!(999_u128);
let class_hash2 = ClassHash(1234u128.into());
let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO);
let mut casm1 = get_test_casm();
casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() };
let blockifier_casm1 =
RunnableCompiledClass::V1(CompiledClassV1::try_from(casm1.clone()).unwrap());
let mut casm2 = get_test_casm();
casm2.bytecode[0] = BigUintAsHex { value: 12345u32.into() };
let class2 = SierraContractClass::default();
let blockifier_casm2 =
RunnableCompiledClass::V1(CompiledClassV1::try_from(casm2.clone()).unwrap());
let nonce1 = Nonce(felt!(2_u128));
let class_hash3 = ClassHash(567_u128.into());
let class_hash4 = ClassHash(89_u128.into());
Expand Down Expand Up @@ -204,7 +205,8 @@ fn read_state() {

// Test pending state diff
let mut pending_classes = PendingClasses::default();
pending_classes.add_compiled_class(class_hash2, casm1);
pending_classes.add_compiled_class(class_hash2, casm2);
pending_classes.add_class(class_hash2, ApiContractClass::ContractClass(class2));
pending_classes.add_class(class_hash3, ApiContractClass::ContractClass(class0));
pending_classes
.add_class(class_hash4, ApiContractClass::DeprecatedContractClass(class1.clone()));
Expand Down Expand Up @@ -234,7 +236,7 @@ fn read_state() {
assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0);
assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1);
assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0);
assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1);
assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm2);
// Test that an error is returned if we only got the class without the casm.
state_reader2.get_compiled_class(class_hash3).unwrap_err();
// Test that if the class is deprecated it is returned.
Expand Down

0 comments on commit a9e3e0b

Please sign in to comment.