Skip to content

Commit

Permalink
refactor(starknet_integration_test): add sierra contract class to dum…
Browse files Browse the repository at this point in the history
…my state
  • Loading branch information
AvivYossef-starkware committed Dec 5, 2024
1 parent 7045fbb commit 833b635
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
4 changes: 4 additions & 0 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ impl Contract {
self.contract.cairo_version()
}

pub fn sierra(&self) -> SierraContractClass {
self.contract.get_sierra()
}

pub fn raw_class(&self) -> String {
self.contract.get_raw_class()
}
Expand Down
27 changes: 25 additions & 2 deletions crates/starknet_integration_tests/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use starknet_api::block::{
};
use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce, SequencerContractAddress};
use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass;
use starknet_api::state::{StorageKey, ThinStateDiff};
use starknet_api::state::{SierraContractClass, StorageKey, ThinStateDiff};
use starknet_api::transaction::fields::Fee;
use starknet_api::{contract_address, felt};
use starknet_client::reader::PendingData;
Expand Down Expand Up @@ -121,6 +121,7 @@ fn initialize_papyrus_test_state(

let contract_classes_to_retrieve =
test_defined_accounts.into_iter().chain(default_test_contracts).chain([erc20_contract]);
let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone());
let (cairo0_contract_classes, cairo1_contract_classes) =
prepare_compiled_contract_classes(contract_classes_to_retrieve);

Expand All @@ -129,6 +130,7 @@ fn initialize_papyrus_test_state(
state_diff,
&cairo0_contract_classes,
&cairo1_contract_classes,
&sierra_vec,
)
}

Expand Down Expand Up @@ -157,6 +159,18 @@ fn prepare_state_diff(
state_diff_builder.build()
}

fn prepare_sierra_classes(
contract_classes_to_retrieve: impl Iterator<Item = Contract>,
) -> Vec<(ClassHash, SierraContractClass)> {
let mut sierra_contract_classes = Vec::new();
for contract in contract_classes_to_retrieve {
if contract.cairo_version() == CairoVersion::Cairo1 {
sierra_contract_classes.push((contract.class_hash(), contract.sierra()));
}
}
sierra_contract_classes
}

fn prepare_compiled_contract_classes(
contract_classes_to_retrieve: impl Iterator<Item = Contract>,
) -> ContractClassesMap {
Expand Down Expand Up @@ -189,6 +203,7 @@ fn write_state_to_papyrus_storage(
state_diff: ThinStateDiff,
cairo0_contract_classes: &[(ClassHash, DeprecatedContractClass)],
cairo1_contract_classes: &[(ClassHash, CasmContractClass)],
cairo1_sierra: &[(ClassHash, SierraContractClass)],
) {
let block_number = BlockNumber(0);
let block_header = test_block_header(block_number);
Expand All @@ -200,14 +215,22 @@ fn write_state_to_papyrus_storage(
for (class_hash, casm) in cairo1_contract_classes {
write_txn = write_txn.append_casm(class_hash, casm).unwrap();
}

write_txn
.append_header(block_number, &block_header)
.unwrap()
.append_body(block_number, BlockBody::default())
.unwrap()
.append_state_diff(block_number, state_diff)
.unwrap()
.append_classes(block_number, &[], &cairo0_contract_classes)
.append_classes(
block_number,
&(cairo1_sierra
.iter()
.map(|(class_hash, sierra)| (*class_hash, sierra))
.collect::<Vec<(ClassHash, &SierraContractClass)>>()),
&cairo0_contract_classes,
)
.unwrap()
.commit()
.unwrap();
Expand Down

0 comments on commit 833b635

Please sign in to comment.