From 511eafb5ae2973757b660f83e458037b636c5bf1 Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Thu, 5 Dec 2024 11:36:10 +0200 Subject: [PATCH] refactor(starknet_integration_tests): add sierra contract class to dummy state --- crates/blockifier/src/test_utils.rs | 9 +++++++ .../src/starknet_api_test_utils.rs | 4 ++++ .../src/state_reader.rs | 24 +++++++++++++++++-- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/crates/blockifier/src/test_utils.rs b/crates/blockifier/src/test_utils.rs index a2d8164767..374aacbfff 100644 --- a/crates/blockifier/src/test_utils.rs +++ b/crates/blockifier/src/test_utils.rs @@ -95,6 +95,15 @@ impl CairoVersion { Self::Native => panic!("There is no other version for native"), } } + + pub fn is_cairo0(&self) -> bool { + match self { + Self::Cairo0 => true, + Self::Cairo1 => false, + #[cfg(feature = "cairo_native")] + Self::Native => false, + } + } } #[derive(Clone, Copy, PartialEq, Eq, Debug)] diff --git a/crates/mempool_test_utils/src/starknet_api_test_utils.rs b/crates/mempool_test_utils/src/starknet_api_test_utils.rs index 105e06c29a..e6338ec67f 100644 --- a/crates/mempool_test_utils/src/starknet_api_test_utils.rs +++ b/crates/mempool_test_utils/src/starknet_api_test_utils.rs @@ -334,6 +334,10 @@ impl Contract { self.contract.cairo_version() } + pub fn sierra(&self) -> SierraContractClass { + self.contract.get_sierra() + } + pub fn raw_class(&self) -> String { self.contract.get_raw_class() } diff --git a/crates/starknet_integration_tests/src/state_reader.rs b/crates/starknet_integration_tests/src/state_reader.rs index 6db10bc02e..0dedd6055f 100644 --- a/crates/starknet_integration_tests/src/state_reader.rs +++ b/crates/starknet_integration_tests/src/state_reader.rs @@ -37,7 +37,7 @@ use starknet_api::block::{ }; use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce, SequencerContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; -use starknet_api::state::{StorageKey, ThinStateDiff}; +use starknet_api::state::{SierraContractClass, StorageKey, ThinStateDiff}; use starknet_api::transaction::fields::Fee; use starknet_api::{contract_address, felt}; use starknet_client::reader::PendingData; @@ -123,6 +123,7 @@ fn initialize_papyrus_test_state( let contract_classes_to_retrieve = test_defined_accounts.into_iter().chain(default_test_contracts).chain([erc20_contract]); + let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone()); let (cairo0_contract_classes, cairo1_contract_classes) = prepare_compiled_contract_classes(contract_classes_to_retrieve); @@ -131,6 +132,7 @@ fn initialize_papyrus_test_state( state_diff, &cairo0_contract_classes, &cairo1_contract_classes, + &sierra_vec, ) } @@ -159,6 +161,15 @@ fn prepare_state_diff( state_diff_builder.build() } +fn prepare_sierra_classes( + contract_classes_to_retrieve: impl Iterator, +) -> Vec<(ClassHash, SierraContractClass)> { + contract_classes_to_retrieve + .filter(|contract| !contract.cairo_version().is_cairo0()) + .map(|contract| (contract.class_hash(), contract.sierra())) + .collect() +} + fn prepare_compiled_contract_classes( contract_classes_to_retrieve: impl Iterator, ) -> ContractClassesMap { @@ -191,6 +202,7 @@ fn write_state_to_papyrus_storage( state_diff: ThinStateDiff, cairo0_contract_classes: &[(ClassHash, DeprecatedContractClass)], cairo1_contract_classes: &[(ClassHash, CasmContractClass)], + cairo1_sierra: &[(ClassHash, SierraContractClass)], ) { let block_number = BlockNumber(0); let block_header = test_block_header(block_number); @@ -202,6 +214,7 @@ fn write_state_to_papyrus_storage( for (class_hash, casm) in cairo1_contract_classes { write_txn = write_txn.append_casm(class_hash, casm).unwrap(); } + write_txn .append_header(block_number, &block_header) .unwrap() @@ -209,7 +222,14 @@ fn write_state_to_papyrus_storage( .unwrap() .append_state_diff(block_number, state_diff) .unwrap() - .append_classes(block_number, &[], &cairo0_contract_classes) + .append_classes( + block_number, + &(cairo1_sierra + .iter() + .map(|(class_hash, sierra)| (*class_hash, sierra)) + .collect::>()), + &cairo0_contract_classes, + ) .unwrap() .commit() .unwrap();