From b38374197ad0ee7603d42124da4fabafcaca3027 Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Thu, 5 Dec 2024 11:36:10 +0200 Subject: [PATCH] refactor(starknet_integration_tests): add sierra contract class to dummy state --- crates/blockifier/src/test_utils.rs | 4 ++++ .../src/starknet_api_test_utils.rs | 4 ++++ .../src/state_reader.rs | 24 +++++++++++++++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/crates/blockifier/src/test_utils.rs b/crates/blockifier/src/test_utils.rs index 94ac064489..17ec1f0b6a 100644 --- a/crates/blockifier/src/test_utils.rs +++ b/crates/blockifier/src/test_utils.rs @@ -97,6 +97,10 @@ impl CairoVersion { Self::Native => panic!("There is no other version for native"), } } + + pub fn is_cairo0(&self) -> bool { + matches!(self, Self::Cairo0) + } } #[derive(Clone, Copy, PartialEq, Eq, Debug)] diff --git a/crates/mempool_test_utils/src/starknet_api_test_utils.rs b/crates/mempool_test_utils/src/starknet_api_test_utils.rs index 666cb1720c..99a9686539 100644 --- a/crates/mempool_test_utils/src/starknet_api_test_utils.rs +++ b/crates/mempool_test_utils/src/starknet_api_test_utils.rs @@ -336,6 +336,10 @@ impl Contract { self.contract.cairo_version() } + pub fn sierra(&self) -> SierraContractClass { + self.contract.get_sierra() + } + pub fn raw_class(&self) -> String { self.contract.get_raw_class() } diff --git a/crates/starknet_integration_tests/src/state_reader.rs b/crates/starknet_integration_tests/src/state_reader.rs index ca88fb9cb3..c0f75af3d1 100644 --- a/crates/starknet_integration_tests/src/state_reader.rs +++ b/crates/starknet_integration_tests/src/state_reader.rs @@ -37,7 +37,7 @@ use starknet_api::block::{ }; use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce, SequencerContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; -use starknet_api::state::{StorageKey, ThinStateDiff}; +use starknet_api::state::{SierraContractClass, StorageKey, ThinStateDiff}; use starknet_api::transaction::fields::Fee; use starknet_api::{contract_address, felt}; use starknet_client::reader::PendingData; @@ -121,6 +121,7 @@ fn initialize_papyrus_test_state( let contract_classes_to_retrieve = test_defined_accounts.into_iter().chain(default_test_contracts).chain([erc20_contract]); + let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone()); let (cairo0_contract_classes, cairo1_contract_classes) = prepare_compiled_contract_classes(contract_classes_to_retrieve); @@ -129,6 +130,7 @@ fn initialize_papyrus_test_state( state_diff, &cairo0_contract_classes, &cairo1_contract_classes, + &sierra_vec, ) } @@ -157,6 +159,15 @@ fn prepare_state_diff( state_diff_builder.build() } +fn prepare_sierra_classes( + contract_classes_to_retrieve: impl Iterator, +) -> Vec<(ClassHash, SierraContractClass)> { + contract_classes_to_retrieve + .filter(|contract| !contract.cairo_version().is_cairo0()) + .map(|contract| (contract.class_hash(), contract.sierra())) + .collect() +} + fn prepare_compiled_contract_classes( contract_classes_to_retrieve: impl Iterator, ) -> ContractClassesMap { @@ -189,6 +200,7 @@ fn write_state_to_papyrus_storage( state_diff: ThinStateDiff, cairo0_contract_classes: &[(ClassHash, DeprecatedContractClass)], cairo1_contract_classes: &[(ClassHash, CasmContractClass)], + cairo1_sierra: &[(ClassHash, SierraContractClass)], ) { let block_number = BlockNumber(0); let block_header = test_block_header(block_number); @@ -200,6 +212,7 @@ fn write_state_to_papyrus_storage( for (class_hash, casm) in cairo1_contract_classes { write_txn = write_txn.append_casm(class_hash, casm).unwrap(); } + write_txn .append_header(block_number, &block_header) .unwrap() @@ -207,7 +220,14 @@ fn write_state_to_papyrus_storage( .unwrap() .append_state_diff(block_number, state_diff) .unwrap() - .append_classes(block_number, &[], &cairo0_contract_classes) + .append_classes( + block_number, + &(cairo1_sierra + .iter() + .map(|(class_hash, sierra)| (*class_hash, sierra)) + .collect::>()), + &cairo0_contract_classes, + ) .unwrap() .commit() .unwrap();