diff --git a/Cargo.lock b/Cargo.lock index 849e52446a..3c86b6de9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7716,6 +7716,7 @@ dependencies = [ "prometheus-parse", "rand 0.8.5", "rand_chacha 0.3.1", + "rstest", "schemars", "serde", "serde_json", diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index 56b2ed1f9b..85c135f7a3 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -21,7 +21,7 @@ use itertools::Itertools; use semver::Version; use serde::de::Error as DeserializationError; use serde::{Deserialize, Deserializer, Serialize}; -use starknet_api::contract_class::{ContractClass, EntryPointType}; +use starknet_api::contract_class::{ContractClass, EntryPointType, SierraVersion}; use starknet_api::core::EntryPointSelector; use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, @@ -67,6 +67,25 @@ pub enum RunnableCompiledClass { V1Native(NativeCompiledClassV1), } +/// Represents a runnable compiled class for Cairo, with the Sierra version (for Cairo 1). +pub enum VersionedRunnableCompiledClass { + Cairo0(RunnableCompiledClass), + Cairo1((RunnableCompiledClass, SierraVersion)), +} + +impl From for RunnableCompiledClass { + fn from( + versioned_runnable_compiled_class: VersionedRunnableCompiledClass, + ) -> RunnableCompiledClass { + match versioned_runnable_compiled_class { + VersionedRunnableCompiledClass::Cairo0(runnable_compiled_class) + | VersionedRunnableCompiledClass::Cairo1((runnable_compiled_class, _)) => { + runnable_compiled_class + } + } + } +} + impl TryFrom for RunnableCompiledClass { type Error = ProgramError; diff --git a/crates/blockifier/src/state/errors.rs b/crates/blockifier/src/state/errors.rs index 1eba9b18bf..c3ad9fbf0d 100644 --- a/crates/blockifier/src/state/errors.rs +++ b/crates/blockifier/src/state/errors.rs @@ -1,6 +1,8 @@ +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use cairo_vm::types::errors::program_errors::ProgramError; use num_bigint::{BigUint, TryFromBigIntError}; use starknet_api::core::{ClassHash, ContractAddress}; +use starknet_api::state::SierraContractClass; use starknet_api::StarknetApiError; use thiserror::Error; @@ -8,6 +10,8 @@ use crate::abi::constants; #[derive(Debug, Error)] pub enum StateError { + #[error("CASM and Sierra mismatch for class hash {:#064x}: {message}.", class_hash.0)] + CasmAndSierraMismatch { class_hash: ClassHash, message: String }, #[error(transparent)] FromBigUint(#[from] TryFromBigIntError), #[error( @@ -29,3 +33,25 @@ pub enum StateError { #[error("Failed to read from state: {0}.")] StateReadError(String), } + +/// Ensures that the CASM and Sierra classes are coupled - Meaning that they both exist or are +/// missing. Returns a `CasmAndSierraMismatch` error when there is an inconsistency in their +/// existence. +pub fn couple_casm_and_sierra( + class_hash: ClassHash, + option_casm: Option, + option_sierra: Option, +) -> Result, StateError> { + match (option_casm, option_sierra) { + (Some(casm), Some(sierra)) => Ok(Some((casm, sierra))), + (Some(_), None) => Err(StateError::CasmAndSierraMismatch { + class_hash, + message: "Class exists in CASM but not in Sierra".to_string(), + }), + (None, Some(_)) => Err(StateError::CasmAndSierraMismatch { + class_hash, + message: "Class exists in Sierra but not in CASM".to_string(), + }), + (None, None) => Ok(None), + } +} diff --git a/crates/papyrus_state_reader/src/papyrus_state.rs b/crates/papyrus_state_reader/src/papyrus_state.rs index 2af2651d10..0f9fe0b74c 100644 --- a/crates/papyrus_state_reader/src/papyrus_state.rs +++ b/crates/papyrus_state_reader/src/papyrus_state.rs @@ -2,8 +2,9 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; -use blockifier::state::errors::StateError; +use blockifier::state::errors::{couple_casm_and_sierra, StateError}; use blockifier::state::global_cache::GlobalContractCache; use blockifier::state::state_api::{StateReader, StateResult}; use papyrus_storage::compiled_class::CasmStorageReader; @@ -11,6 +12,7 @@ use papyrus_storage::db::RO; use papyrus_storage::state::StateStorageReader; use papyrus_storage::StorageReader; use starknet_api::block::BlockNumber; +use starknet_api::contract_class::SierraVersion; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StateNumber, StorageKey}; use starknet_types_core::felt::Felt; @@ -46,7 +48,7 @@ impl PapyrusReader { fn get_compiled_class_inner( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { let state_number = StateNumber(self.latest_block); let class_declaration_block_number = self .reader()? @@ -57,16 +59,20 @@ impl PapyrusReader { Some(block_number) if block_number <= state_number.0); if class_is_declared { - let casm_compiled_class = self + let (option_casm, option_sierra) = self .reader()? - .get_casm(&class_hash) - .map_err(|err| StateError::StateReadError(err.to_string()))? - .expect( - "Should be able to fetch a Casm class if its definition exists, database is \ - inconsistent.", + .get_casm_and_sierra(&class_hash) + .map_err(|err| StateError::StateReadError(err.to_string()))?; + let (casm_compiled_class, sierra) = + couple_casm_and_sierra(class_hash, option_casm, option_sierra)?.expect( + "Should be able to fetch a Casm and Sierra class if its definition exists, \ + database is inconsistent.", ); + let sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; + let runnable_compiled = + RunnableCompiledClass::V1(CompiledClassV1::try_from(casm_compiled_class)?); - return Ok(RunnableCompiledClass::V1(CompiledClassV1::try_from(casm_compiled_class)?)); + return Ok(VersionedRunnableCompiledClass::Cairo1((runnable_compiled, sierra_version))); } let v0_compiled_class = self @@ -76,9 +82,9 @@ impl PapyrusReader { .map_err(|err| StateError::StateReadError(err.to_string()))?; match v0_compiled_class { - Some(starknet_api_contract_class) => { - Ok(CompiledClassV0::try_from(starknet_api_contract_class)?.into()) - } + Some(starknet_api_contract_class) => Ok(VersionedRunnableCompiledClass::Cairo0( + CompiledClassV0::try_from(starknet_api_contract_class)?.into(), + )), None => Err(StateError::UndeclaredClassHash(class_hash)), } } @@ -131,7 +137,8 @@ impl StateReader for PapyrusReader { match contract_class { Some(contract_class) => Ok(contract_class), None => { - let contract_class_from_db = self.get_compiled_class_inner(class_hash)?; + let contract_class_from_db = + RunnableCompiledClass::from(self.get_compiled_class_inner(class_hash)?); // The class was declared in a previous (finalized) state; update the global cache. self.global_class_hash_to_class.set(class_hash, contract_class_from_db.clone()); Ok(contract_class_from_db) diff --git a/crates/papyrus_storage/Cargo.toml b/crates/papyrus_storage/Cargo.toml index 925261d1f8..4320f91e12 100644 --- a/crates/papyrus_storage/Cargo.toml +++ b/crates/papyrus_storage/Cargo.toml @@ -62,6 +62,7 @@ pretty_assertions.workspace = true prometheus-parse.workspace = true rand.workspace = true rand_chacha.workspace = true +rstest.workspace = true schemars = { workspace = true, features = ["preserve_order"] } simple_logger.workspace = true tempfile = { workspace = true } diff --git a/crates/papyrus_storage/src/compiled_class.rs b/crates/papyrus_storage/src/compiled_class.rs index d0e2b656aa..1ee8c2d583 100644 --- a/crates/papyrus_storage/src/compiled_class.rs +++ b/crates/papyrus_storage/src/compiled_class.rs @@ -50,7 +50,9 @@ use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use papyrus_proc_macros::latency_histogram; use starknet_api::block::BlockNumber; use starknet_api::core::ClassHash; +use starknet_api::state::SierraContractClass; +use crate::class::ClassStorageReader; use crate::db::serialization::VersionZeroWrapper; use crate::db::table_types::{SimpleTable, Table}; use crate::db::{DbTransaction, TableHandle, TransactionKind, RW}; @@ -61,6 +63,15 @@ use crate::{FileHandlers, MarkerKind, MarkersTable, OffsetKind, StorageResult, S pub trait CasmStorageReader { /// Returns the Cairo assembly of a class given its Sierra class hash. fn get_casm(&self, class_hash: &ClassHash) -> StorageResult>; + + /// Returns the CASM and Sierra contract classes for the given hash. + /// If both exist, returns `(Some(casm), Some(sierra))`. + /// If neither, returns `(None, None)`. + /// If only one exists, returns `(Some, None)` or `(None, Some)`. + fn get_casm_and_sierra( + &self, + class_hash: &ClassHash, + ) -> StorageResult<(Option, Option)>; /// The block marker is the first block number that doesn't exist yet. /// /// Note: If the last blocks don't contain any declared classes, the marker will point at the @@ -85,6 +96,13 @@ impl CasmStorageReader for StorageTxn<'_, Mode> { casm_location.map(|location| self.file_handlers.get_casm_unchecked(location)).transpose() } + fn get_casm_and_sierra( + &self, + class_hash: &ClassHash, + ) -> StorageResult<(Option, Option)> { + Ok((self.get_casm(class_hash)?, self.get_class(class_hash)?)) + } + fn get_compiled_class_marker(&self) -> StorageResult { let markers_table = self.open_table(&self.tables.markers)?; Ok(markers_table.get(&self.txn, &MarkerKind::CompiledClass)?.unwrap_or_default()) diff --git a/crates/papyrus_storage/src/compiled_class_test.rs b/crates/papyrus_storage/src/compiled_class_test.rs index 204cada574..2e2aec5c01 100644 --- a/crates/papyrus_storage/src/compiled_class_test.rs +++ b/crates/papyrus_storage/src/compiled_class_test.rs @@ -1,9 +1,14 @@ use assert_matches::assert_matches; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use papyrus_test_utils::{get_rng, GetTestInstance}; use pretty_assertions::assert_eq; +use rstest::rstest; +use starknet_api::block::BlockNumber; use starknet_api::core::ClassHash; +use starknet_api::state::SierraContractClass; use starknet_api::test_utils::read_json_file; +use crate::class::ClassStorageWriter; use crate::compiled_class::{CasmStorageReader, CasmStorageWriter}; use crate::db::{DbError, KeyAlreadyExistsError}; use crate::test_utils::get_test_storage; @@ -27,6 +32,46 @@ fn append_casm() { assert_eq!(casm, expected_casm); } +#[rstest] +fn test_casm_and_sierra( + #[values(true, false)] has_casm: bool, + #[values(true, false)] has_sierra: bool, +) { + let test_class_hash = ClassHash::default(); + let mut rng = get_rng(); + + // Setup storage. + let ((reader, mut writer), _temp_dir) = get_test_storage(); + let expected_casm = CasmContractClass::get_test_instance(&mut rng); + let expected_sierra = ::get_test_instance(&mut rng); + + if has_casm { + writer + .begin_rw_txn() + .unwrap() + .append_casm(&test_class_hash, &expected_casm) + .unwrap() + .commit() + .unwrap(); + } + if has_sierra { + writer + .begin_rw_txn() + .unwrap() + .append_classes(BlockNumber::default(), &[(test_class_hash, &expected_sierra)], &[]) + .unwrap() + .commit() + .unwrap(); + } + + let result = reader.begin_ro_txn().unwrap().get_casm_and_sierra(&test_class_hash); + + assert_eq!( + result.unwrap(), + (has_casm.then_some(expected_casm), has_sierra.then_some(expected_sierra)) + ); +} + #[test] fn casm_rewrite() { let ((_, mut writer), _temp_dir) = get_test_storage();