diff --git a/crates/mempool_test_utils/src/starknet_api_test_utils.rs b/crates/mempool_test_utils/src/starknet_api_test_utils.rs index 666cb1720c3..99a9686539a 100644 --- a/crates/mempool_test_utils/src/starknet_api_test_utils.rs +++ b/crates/mempool_test_utils/src/starknet_api_test_utils.rs @@ -336,6 +336,10 @@ impl Contract { self.contract.cairo_version() } + pub fn sierra(&self) -> SierraContractClass { + self.contract.get_sierra() + } + pub fn raw_class(&self) -> String { self.contract.get_raw_class() } diff --git a/crates/starknet_integration_tests/src/state_reader.rs b/crates/starknet_integration_tests/src/state_reader.rs index ca88fb9cb3b..9c5eb085d67 100644 --- a/crates/starknet_integration_tests/src/state_reader.rs +++ b/crates/starknet_integration_tests/src/state_reader.rs @@ -37,7 +37,7 @@ use starknet_api::block::{ }; use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce, SequencerContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; -use starknet_api::state::{StorageKey, ThinStateDiff}; +use starknet_api::state::{SierraContractClass, StorageKey, ThinStateDiff}; use starknet_api::transaction::fields::Fee; use starknet_api::{contract_address, felt}; use starknet_client::reader::PendingData; @@ -121,6 +121,7 @@ fn initialize_papyrus_test_state( let contract_classes_to_retrieve = test_defined_accounts.into_iter().chain(default_test_contracts).chain([erc20_contract]); + let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone()); let (cairo0_contract_classes, cairo1_contract_classes) = prepare_compiled_contract_classes(contract_classes_to_retrieve); @@ -129,6 +130,7 @@ fn initialize_papyrus_test_state( state_diff, &cairo0_contract_classes, &cairo1_contract_classes, + &sierra_vec, ) } @@ -157,6 +159,18 @@ fn prepare_state_diff( state_diff_builder.build() } +fn prepare_sierra_classes( + contract_classes_to_retrieve: impl Iterator, +) -> Vec<(ClassHash, SierraContractClass)> { + let mut sierra_contract_classes = Vec::new(); + for contract in contract_classes_to_retrieve { + if contract.cairo_version() == CairoVersion::Cairo1 { + sierra_contract_classes.push((contract.class_hash(), contract.sierra())); + } + } + sierra_contract_classes +} + fn prepare_compiled_contract_classes( contract_classes_to_retrieve: impl Iterator, ) -> ContractClassesMap { @@ -189,6 +203,7 @@ fn write_state_to_papyrus_storage( state_diff: ThinStateDiff, cairo0_contract_classes: &[(ClassHash, DeprecatedContractClass)], cairo1_contract_classes: &[(ClassHash, CasmContractClass)], + cairo1_sierra: &[(ClassHash, SierraContractClass)], ) { let block_number = BlockNumber(0); let block_header = test_block_header(block_number); @@ -200,6 +215,7 @@ fn write_state_to_papyrus_storage( for (class_hash, casm) in cairo1_contract_classes { write_txn = write_txn.append_casm(class_hash, casm).unwrap(); } + write_txn .append_header(block_number, &block_header) .unwrap() @@ -207,7 +223,14 @@ fn write_state_to_papyrus_storage( .unwrap() .append_state_diff(block_number, state_diff) .unwrap() - .append_classes(block_number, &[], &cairo0_contract_classes) + .append_classes( + block_number, + &(cairo1_sierra + .iter() + .map(|(class_hash, sierra)| (*class_hash, sierra)) + .collect::>()), + &cairo0_contract_classes, + ) .unwrap() .commit() .unwrap();