From 397223064609d3c50642ae93fad84edc230e71a6 Mon Sep 17 00:00:00 2001 From: ShahakShama <70578257+ShahakShama@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:44:23 +0200 Subject: [PATCH] refactor(papyrus_p2p_sync): convert state diff tests to use run_test (#2511) --- .../src/client/state_diff_test.rs | 346 +++++++++--------- 1 file changed, 175 insertions(+), 171 deletions(-) diff --git a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs index 04175a7568..d4e86f259e 100644 --- a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs +++ b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs @@ -1,6 +1,6 @@ use std::cmp::min; +use std::collections::HashMap; -use futures::future::join; use futures::{FutureExt, StreamExt}; use indexmap::indexmap; use papyrus_network::network_manager::GenericReceiver; @@ -16,24 +16,22 @@ use papyrus_protobuf::sync::{ StateDiffChunk, }; use papyrus_storage::state::StateStorageReader; -use papyrus_test_utils::{get_rng, GetTestInstance}; -use rand::RngCore; -use rand_chacha::ChaCha8Rng; +use papyrus_test_utils::get_rng; use starknet_api::block::{BlockHeader, BlockHeaderWithoutHash, BlockNumber}; -use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use starknet_api::core::{ascii_as_felt, ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; -use static_assertions::const_assert; use tokio::sync::mpsc::{channel, Receiver}; use super::test_utils::{ create_block_hashes_and_signatures, - setup, + random_header, + run_test, wait_for_marker, + Action, DataType, HeaderTestPayload, StateDiffTestPayload, - TestArgs, HEADER_QUERY_LENGTH, SLEEP_DURATION_TO_LET_SYNC_ADVANCE, STATE_DIFF_QUERY_LENGTH, @@ -44,124 +42,152 @@ use super::{P2PSyncClientConfig, StateDiffQuery}; #[tokio::test] async fn state_diff_basic_flow() { - // Asserting the constants so the test can assume there will be 2 state diff queries for a - // single header query and the second will be smaller than the first. - const_assert!(STATE_DIFF_QUERY_LENGTH < HEADER_QUERY_LENGTH); - const_assert!(HEADER_QUERY_LENGTH < 2 * STATE_DIFF_QUERY_LENGTH); - - let TestArgs { - p2p_sync, - storage_reader, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - // The test will fail if we drop these - mock_transaction_response_manager: _mock_transaction_responses_manager, - mock_class_response_manager: _mock_class_responses_manager, - .. - } = setup(); - let mut rng = get_rng(); - // TODO(eitan): Add a 3rd constant for NUM_CHUNKS_PER_BLOCK so that ThinStateDiff is made from - // multiple StateDiffChunks - let (state_diffs, header_state_diff_lengths): (Vec<_>, Vec<_>) = (0..HEADER_QUERY_LENGTH) - .map(|_| { - let diff = create_random_state_diff_chunk(&mut rng); - let length = diff.len(); - (diff, length) - }) - .unzip(); - - let (state_diff_sender, mut state_diff_receiver) = channel(p2p_sync.config.buffer_size); - - // Create a future that will receive send responses and validate the results. - let test_future = async move { - for (start_block_number, num_blocks) in [ - (0u64, STATE_DIFF_QUERY_LENGTH), - (STATE_DIFF_QUERY_LENGTH, HEADER_QUERY_LENGTH - STATE_DIFF_QUERY_LENGTH), - ] { - for block_number in start_block_number..(start_block_number + num_blocks) { - let state_diff_chunk = state_diffs[usize::try_from(block_number).unwrap()].clone(); - - let block_number = BlockNumber(block_number); - - // Check that before we've sent all parts the state diff wasn't written yet. - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(block_number, txn.get_state_marker().unwrap()); - - state_diff_sender.send(Some(state_diff_chunk.clone())).await.unwrap(); - - // Check state diff was written to the storage. This way we make sure that the sync - // writes to the storage each block's state diff before receiving all query - // responses. + let class_hash0 = ClassHash(ascii_as_felt("class_hash0").unwrap()); + let class_hash1 = ClassHash(ascii_as_felt("class_hash1").unwrap()); + let casm_hash0 = CompiledClassHash(ascii_as_felt("casm_hash0").unwrap()); + let address0 = ContractAddress(ascii_as_felt("address0").unwrap().try_into().unwrap()); + let address1 = ContractAddress(ascii_as_felt("address1").unwrap().try_into().unwrap()); + let address2 = ContractAddress(ascii_as_felt("address2").unwrap().try_into().unwrap()); + let key0 = StorageKey(ascii_as_felt("key0").unwrap().try_into().unwrap()); + let key1 = StorageKey(ascii_as_felt("key1").unwrap().try_into().unwrap()); + let value0 = ascii_as_felt("value0").unwrap(); + let value1 = ascii_as_felt("value1").unwrap(); + let nonce0 = Nonce(ascii_as_felt("nonce0").unwrap()); + + let state_diffs_and_chunks = vec![ + ( + ThinStateDiff { + deployed_contracts: indexmap!(address0 => class_hash0), + storage_diffs: indexmap!(address0 => indexmap!(key0 => value0, key1 => value1)), + declared_classes: indexmap!(class_hash0 => casm_hash0), + deprecated_declared_classes: vec![class_hash1], + nonces: indexmap!(address0 => nonce0), + replaced_classes: Default::default(), + }, + vec![ + StateDiffChunk::DeclaredClass(DeclaredClass { + class_hash: class_hash0, + compiled_class_hash: casm_hash0, + }), + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address0, + class_hash: Some(class_hash0), + nonce: Some(nonce0), + storage_diffs: indexmap!(key0 => value0, key1 => value1), + }), + StateDiffChunk::DeprecatedDeclaredClass(DeprecatedDeclaredClass { + class_hash: class_hash1, + }), + ], + ), + ( + ThinStateDiff { + deployed_contracts: indexmap!(address1 => class_hash1), + storage_diffs: indexmap!( + address1 => indexmap!(key0 => value0), + address2 => indexmap!(key1 => value1) + ), + nonces: indexmap!(address2 => nonce0), + ..Default::default() + }, + vec![ + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address1, + class_hash: Some(class_hash1), + nonce: None, + storage_diffs: indexmap!(key0 => value0), + }), + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address2, + class_hash: None, + nonce: Some(nonce0), + storage_diffs: indexmap!(key1 => value1), + }), + ], + ), + ]; + + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length + for (i, (state_diff, _)) in state_diffs_and_chunks.iter().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diff.len()), + None, + ))))); + } + actions.push(Action::SendHeader(DataOrFin(None))); + + let len = state_diffs_and_chunks.len(); + actions.push(Action::ReceiveQuery( + Box::new(move |query| { + assert_eq!( + query, + Query { + start_block: BlockHashOrNumber::Number(BlockNumber(0)), + direction: Direction::Forward, + limit: len.try_into().unwrap(), + step: 1, + } + ) + }), + DataType::StateDiff, + )); + // Send state diff chunks and check storage + for (i, (expected_state_diff, state_diff_chunks)) in + state_diffs_and_chunks.iter().cloned().enumerate() + { + for state_diff_chunk in state_diff_chunks { + // Check that before the last chunk was sent, the state diff isn't written. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + assert_eq!( + u64::try_from(i).unwrap(), + reader.begin_ro_txn().unwrap().get_state_marker().unwrap().0 + ); + } + .boxed() + }))); + + actions.push(Action::SendStateDiff(DataOrFin(Some(state_diff_chunk)))); + } + // Check that a block's state diff is written before the entire query finished. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + let block_number = BlockNumber(i.try_into().unwrap()); wait_for_marker( DataType::StateDiff, - &storage_reader, + &reader, block_number.unchecked_next(), SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, ) .await; - let txn = storage_reader.begin_ro_txn().unwrap(); - let state_diff = txn.get_state_diff(block_number).unwrap().unwrap(); - // TODO(noamsp): refactor test so that we treat multiple state diff chunks as a - // single state diff - let expected_state_diff = match state_diff_chunk { - StateDiffChunk::ContractDiff(contract_diff) => { - let mut deployed_contracts = indexmap! {}; - if let Some(class_hash) = contract_diff.class_hash { - deployed_contracts.insert(contract_diff.contract_address, class_hash); - }; - let mut nonces = indexmap! {}; - if let Some(nonce) = contract_diff.nonce { - nonces.insert(contract_diff.contract_address, nonce); - } - ThinStateDiff { - deployed_contracts, - nonces, - storage_diffs: indexmap! { - contract_diff.contract_address => contract_diff.storage_diffs - }, - ..Default::default() - } - } - StateDiffChunk::DeclaredClass(declared_class) => ThinStateDiff { - declared_classes: indexmap! { - declared_class.class_hash => declared_class.compiled_class_hash - }, - ..Default::default() - }, - StateDiffChunk::DeprecatedDeclaredClass(deprecated_declared_class) => { - ThinStateDiff { - deprecated_declared_classes: vec![deprecated_declared_class.class_hash], - ..Default::default() - } - } - }; - assert_eq!(state_diff, expected_state_diff); + let txn = reader.begin_ro_txn().unwrap(); + let actual_state_diff = txn.get_state_diff(block_number).unwrap().unwrap(); + assert_eq!(actual_state_diff, expected_state_diff); } - - state_diff_sender.send(None).await.unwrap(); - } - }; - - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = join( - run_state_diff_sync_through_channel( - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - false, - ), - test_future, - ) => {} + .boxed() + }))); } + actions.push(Action::SendStateDiff(DataOrFin(None))); + + run_test( + HashMap::from([ + (DataType::Header, state_diffs_and_chunks.len().try_into().unwrap()), + (DataType::StateDiff, state_diffs_and_chunks.len().try_into().unwrap()), + ]), + actions, + ) + .await; } // TODO(noamsp): Consider verifying that ParseDataError::BadPeerError(EmptyStateDiffPart) was @@ -288,46 +314,44 @@ async fn validate_state_diff_fails( header_state_diff_lengths: Vec, state_diff_chunks: Vec>, ) { - let TestArgs { - storage_reader, - p2p_sync, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - // The test will fail if we drop these - mock_transaction_response_manager: _mock_transaction_responses_manager, - mock_class_response_manager: _mock_class_responses_manager, - .. - } = setup(); - - let (state_diff_sender, mut state_diff_receiver) = channel(p2p_sync.config.buffer_size); - - // Create a future that will send responses and validate the results. - let test_future = async move { - for state_diff_chunk in state_diff_chunks { - // Check that before we've sent all parts the state diff wasn't written yet. - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(0, txn.get_state_marker().unwrap().0); + let mut rng = get_rng(); - state_diff_sender.send(state_diff_chunk).await.unwrap(); - } - }; + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length + for (i, state_diff_length) in header_state_diff_lengths.iter().copied().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diff_length), + None, + ))))); + } + actions.push(Action::SendHeader(DataOrFin(None))); - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = join( - run_state_diff_sync_through_channel( - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - true, - ), - test_future - ) => {} + actions.push( + // We already validate the state diff query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::StateDiff), + ); + + // Send state diff chunks. + for state_diff_chunk in state_diff_chunks { + actions.push(Action::SendStateDiff(DataOrFin(state_diff_chunk))); } + + actions.push(Action::ValidateReportSent(DataType::StateDiff)); + + run_test( + HashMap::from([ + (DataType::Header, header_state_diff_lengths.len().try_into().unwrap()), + (DataType::StateDiff, header_state_diff_lengths.len().try_into().unwrap()), + ]), + actions, + ) + .await; } // Advances the header sync with associated header state diffs. @@ -484,23 +508,3 @@ pub(crate) async fn run_state_diff_sync( } }; } - -fn create_random_state_diff_chunk(rng: &mut ChaCha8Rng) -> StateDiffChunk { - let mut state_diff_chunk = StateDiffChunk::get_test_instance(rng); - let contract_address = ContractAddress::from(rng.next_u64()); - let class_hash = ClassHash(rng.next_u64().into()); - match &mut state_diff_chunk { - StateDiffChunk::ContractDiff(contract_diff) => { - contract_diff.contract_address = contract_address; - contract_diff.class_hash = Some(class_hash); - } - StateDiffChunk::DeclaredClass(declared_class) => { - declared_class.class_hash = class_hash; - declared_class.compiled_class_hash = CompiledClassHash(rng.next_u64().into()); - } - StateDiffChunk::DeprecatedDeclaredClass(deprecated_declared_class) => { - deprecated_declared_class.class_hash = class_hash; - } - } - state_diff_chunk -}