diff --git a/crates/papyrus_p2p_sync/src/client/class_test.rs b/crates/papyrus_p2p_sync/src/client/class_test.rs index 6e9e6629bd..85b49d931e 100644 --- a/crates/papyrus_p2p_sync/src/client/class_test.rs +++ b/crates/papyrus_p2p_sync/src/client/class_test.rs @@ -1,10 +1,9 @@ -use std::cmp::min; +use std::collections::HashMap; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use papyrus_common::pending_classes::ApiContractClass; use papyrus_protobuf::sync::{ BlockHashOrNumber, - ClassQuery, DataOrFin, DeclaredClass, DeprecatedDeclaredClass, @@ -22,129 +21,131 @@ use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContract use starknet_api::state::SierraContractClass; use super::test_utils::{ - setup, + random_header, + run_test, wait_for_marker, + Action, DataType, - TestArgs, - CLASS_DIFF_QUERY_LENGTH, - HEADER_QUERY_LENGTH, SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, }; -use crate::client::state_diff_test::run_state_diff_sync; #[tokio::test] async fn class_basic_flow() { - let TestArgs { - p2p_sync, - storage_reader, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - mut mock_class_response_manager, - // The test will fail if we drop this - mock_transaction_response_manager: _mock_transaction_responses_manager, - .. - } = setup(); - let mut rng = get_rng(); - // TODO(noamsp): Add multiple state diffs per header - let (class_state_diffs, api_contract_classes): (Vec<_>, Vec<_>) = (0..HEADER_QUERY_LENGTH) - .map(|_| create_random_state_diff_chunk_with_class(&mut rng)) - .unzip(); - let header_state_diff_lengths = - class_state_diffs.iter().map(|class_state_diff| class_state_diff.len()).collect::>(); - - // Create a future that will receive queries, send responses and validate the results - let parse_queries_future = async move { - // Check that before we send state diffs there is no class query. - assert!(mock_class_response_manager.next().now_or_never().is_none()); - run_state_diff_sync( - p2p_sync.config, - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths.clone(), - class_state_diffs.clone().into_iter().map(Some).collect(), - ) - .await; - - let num_declare_class_state_diff_headers = - u64::try_from(header_state_diff_lengths.len()).unwrap(); - let num_class_queries = - num_declare_class_state_diff_headers.div_ceil(CLASS_DIFF_QUERY_LENGTH); - for i in 0..num_class_queries { - let start_block_number = i * CLASS_DIFF_QUERY_LENGTH; - let limit = min( - num_declare_class_state_diff_headers - start_block_number, - CLASS_DIFF_QUERY_LENGTH, - ); + let state_diffs_and_classes_of_blocks = [ + vec![ + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + ], + vec![ + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + ], + ]; + + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length. + for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.iter().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diffs_and_classes.len()), + None, + ))))); + } + actions.push(Action::SendHeader(DataOrFin(None))); + + // Send state diffs. + actions.push( + // We already validate the state diff query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::StateDiff), + ); + for state_diffs_and_classes in &state_diffs_and_classes_of_blocks { + for (state_diff, _) in state_diffs_and_classes { + actions.push(Action::SendStateDiff(DataOrFin(Some(state_diff.clone())))); + } + } - // Get a class query and validate it - let mut mock_class_responses_manager = - mock_class_response_manager.next().await.unwrap(); + let len = state_diffs_and_classes_of_blocks.len(); + actions.push(Action::ReceiveQuery( + Box::new(move |query| { assert_eq!( - *mock_class_responses_manager.query(), - Ok(ClassQuery(Query { - start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)), + query, + Query { + start_block: BlockHashOrNumber::Number(BlockNumber(0)), direction: Direction::Forward, - limit, + limit: len.try_into().unwrap(), step: 1, - })), - "If the limit of the query is too low, try to increase \ - SLEEP_DURATION_TO_LET_SYNC_ADVANCE", - ); - - for block_number in start_block_number..(start_block_number + limit) { - let class_hash = - class_state_diffs[usize::try_from(block_number).unwrap()].get_class_hash(); - let expected_class = - api_contract_classes[usize::try_from(block_number).unwrap()].clone(); - - let block_number = BlockNumber(block_number); - - // Check that before we've sent all parts the contract class wasn't written yet - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(block_number, txn.get_class_marker().unwrap()); - - mock_class_responses_manager - .send_response(DataOrFin(Some((expected_class.clone(), class_hash)))) - .await - .unwrap(); - + } + ) + }), + DataType::Class, + )); + for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.into_iter().enumerate() { + for (state_diff, class) in &state_diffs_and_classes { + let class_hash = state_diff.get_class_hash(); + + // Check that before the last class was sent, the classes aren't written. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + assert_eq!( + u64::try_from(i).unwrap(), + reader.begin_ro_txn().unwrap().get_class_marker().unwrap().0 + ); + } + .boxed() + }))); + actions.push(Action::SendClass(DataOrFin(Some((class.clone(), class_hash))))); + } + // Check that a block's classes are written before the entire query finished. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + let block_number = BlockNumber(i.try_into().unwrap()); wait_for_marker( DataType::Class, - &storage_reader, + &reader, block_number.unchecked_next(), SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, ) .await; - let txn = storage_reader.begin_ro_txn().unwrap(); - let actual_class = match expected_class { - ApiContractClass::ContractClass(_) => ApiContractClass::ContractClass( - txn.get_class(&class_hash).unwrap().unwrap(), - ), - ApiContractClass::DeprecatedContractClass(_) => { - ApiContractClass::DeprecatedContractClass( - txn.get_deprecated_class(&class_hash).unwrap().unwrap(), - ) + let txn = reader.begin_ro_txn().unwrap(); + for (state_diff, expected_class) in state_diffs_and_classes { + let class_hash = state_diff.get_class_hash(); + match expected_class { + ApiContractClass::ContractClass(expected_class) => { + let actual_class = txn.get_class(&class_hash).unwrap().unwrap(); + assert_eq!(actual_class, expected_class.clone()); + } + ApiContractClass::DeprecatedContractClass(expected_class) => { + let actual_class = + txn.get_deprecated_class(&class_hash).unwrap().unwrap(); + assert_eq!(actual_class, expected_class.clone()); + } } - }; - assert_eq!(expected_class, actual_class); + } } - - mock_class_responses_manager.send_response(DataOrFin(None)).await.unwrap(); - } - }; - - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = parse_queries_future => {} + .boxed() + }))); } + + run_test( + HashMap::from([ + (DataType::Header, len.try_into().unwrap()), + (DataType::StateDiff, len.try_into().unwrap()), + (DataType::Class, len.try_into().unwrap()), + ]), + actions, + ) + .await; } // We define this new trait here so we can use the get_class_hash function in the test. @@ -176,6 +177,8 @@ fn create_random_state_diff_chunk_with_class( }; ( StateDiffChunk::DeclaredClass(declared_class), + // TODO(noamsp): get_test_instance on these types returns the same value, making this + // test redundant. Fix this. ApiContractClass::ContractClass(SierraContractClass::get_test_instance(rng)), ) } else { diff --git a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs index 04175a7568..eb8c1528fa 100644 --- a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs +++ b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs @@ -1,9 +1,7 @@ -use std::cmp::min; +use std::collections::HashMap; -use futures::future::join; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use indexmap::indexmap; -use papyrus_network::network_manager::GenericReceiver; use papyrus_protobuf::sync::{ BlockHashOrNumber, ContractDiff, @@ -12,156 +10,173 @@ use papyrus_protobuf::sync::{ DeprecatedDeclaredClass, Direction, Query, - SignedBlockHeader, StateDiffChunk, }; use papyrus_storage::state::StateStorageReader; -use papyrus_test_utils::{get_rng, GetTestInstance}; -use rand::RngCore; -use rand_chacha::ChaCha8Rng; -use starknet_api::block::{BlockHeader, BlockHeaderWithoutHash, BlockNumber}; -use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use papyrus_test_utils::get_rng; +use starknet_api::block::BlockNumber; +use starknet_api::core::{ascii_as_felt, ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; -use static_assertions::const_assert; -use tokio::sync::mpsc::{channel, Receiver}; use super::test_utils::{ - create_block_hashes_and_signatures, - setup, + random_header, + run_test, wait_for_marker, + Action, DataType, - HeaderTestPayload, - StateDiffTestPayload, - TestArgs, - HEADER_QUERY_LENGTH, SLEEP_DURATION_TO_LET_SYNC_ADVANCE, - STATE_DIFF_QUERY_LENGTH, TIMEOUT_FOR_TEST, - WAIT_PERIOD_FOR_NEW_DATA, }; -use super::{P2PSyncClientConfig, StateDiffQuery}; #[tokio::test] async fn state_diff_basic_flow() { - // Asserting the constants so the test can assume there will be 2 state diff queries for a - // single header query and the second will be smaller than the first. - const_assert!(STATE_DIFF_QUERY_LENGTH < HEADER_QUERY_LENGTH); - const_assert!(HEADER_QUERY_LENGTH < 2 * STATE_DIFF_QUERY_LENGTH); - - let TestArgs { - p2p_sync, - storage_reader, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - // The test will fail if we drop these - mock_transaction_response_manager: _mock_transaction_responses_manager, - mock_class_response_manager: _mock_class_responses_manager, - .. - } = setup(); - let mut rng = get_rng(); - // TODO(eitan): Add a 3rd constant for NUM_CHUNKS_PER_BLOCK so that ThinStateDiff is made from - // multiple StateDiffChunks - let (state_diffs, header_state_diff_lengths): (Vec<_>, Vec<_>) = (0..HEADER_QUERY_LENGTH) - .map(|_| { - let diff = create_random_state_diff_chunk(&mut rng); - let length = diff.len(); - (diff, length) - }) - .unzip(); - - let (state_diff_sender, mut state_diff_receiver) = channel(p2p_sync.config.buffer_size); - - // Create a future that will receive send responses and validate the results. - let test_future = async move { - for (start_block_number, num_blocks) in [ - (0u64, STATE_DIFF_QUERY_LENGTH), - (STATE_DIFF_QUERY_LENGTH, HEADER_QUERY_LENGTH - STATE_DIFF_QUERY_LENGTH), - ] { - for block_number in start_block_number..(start_block_number + num_blocks) { - let state_diff_chunk = state_diffs[usize::try_from(block_number).unwrap()].clone(); - - let block_number = BlockNumber(block_number); - // Check that before we've sent all parts the state diff wasn't written yet. - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(block_number, txn.get_state_marker().unwrap()); - - state_diff_sender.send(Some(state_diff_chunk.clone())).await.unwrap(); - - // Check state diff was written to the storage. This way we make sure that the sync - // writes to the storage each block's state diff before receiving all query - // responses. + let class_hash0 = ClassHash(ascii_as_felt("class_hash0").unwrap()); + let class_hash1 = ClassHash(ascii_as_felt("class_hash1").unwrap()); + let casm_hash0 = CompiledClassHash(ascii_as_felt("casm_hash0").unwrap()); + let address0 = ContractAddress(ascii_as_felt("address0").unwrap().try_into().unwrap()); + let address1 = ContractAddress(ascii_as_felt("address1").unwrap().try_into().unwrap()); + let address2 = ContractAddress(ascii_as_felt("address2").unwrap().try_into().unwrap()); + let key0 = StorageKey(ascii_as_felt("key0").unwrap().try_into().unwrap()); + let key1 = StorageKey(ascii_as_felt("key1").unwrap().try_into().unwrap()); + let value0 = ascii_as_felt("value0").unwrap(); + let value1 = ascii_as_felt("value1").unwrap(); + let nonce0 = Nonce(ascii_as_felt("nonce0").unwrap()); + + let state_diffs_and_chunks = vec![ + ( + ThinStateDiff { + deployed_contracts: indexmap!(address0 => class_hash0), + storage_diffs: indexmap!(address0 => indexmap!(key0 => value0, key1 => value1)), + declared_classes: indexmap!(class_hash0 => casm_hash0), + deprecated_declared_classes: vec![class_hash1], + nonces: indexmap!(address0 => nonce0), + replaced_classes: Default::default(), + }, + vec![ + StateDiffChunk::DeclaredClass(DeclaredClass { + class_hash: class_hash0, + compiled_class_hash: casm_hash0, + }), + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address0, + class_hash: Some(class_hash0), + nonce: Some(nonce0), + storage_diffs: indexmap!(key0 => value0, key1 => value1), + }), + StateDiffChunk::DeprecatedDeclaredClass(DeprecatedDeclaredClass { + class_hash: class_hash1, + }), + ], + ), + ( + ThinStateDiff { + deployed_contracts: indexmap!(address1 => class_hash1), + storage_diffs: indexmap!( + address1 => indexmap!(key0 => value0), + address2 => indexmap!(key1 => value1) + ), + nonces: indexmap!(address2 => nonce0), + ..Default::default() + }, + vec![ + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address1, + class_hash: Some(class_hash1), + nonce: None, + storage_diffs: indexmap!(key0 => value0), + }), + StateDiffChunk::ContractDiff(ContractDiff { + contract_address: address2, + class_hash: None, + nonce: Some(nonce0), + storage_diffs: indexmap!(key1 => value1), + }), + ], + ), + ]; + + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length + for (i, (state_diff, _)) in state_diffs_and_chunks.iter().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diff.len()), + None, + ))))); + } + actions.push(Action::SendHeader(DataOrFin(None))); + + let len = state_diffs_and_chunks.len(); + actions.push(Action::ReceiveQuery( + Box::new(move |query| { + assert_eq!( + query, + Query { + start_block: BlockHashOrNumber::Number(BlockNumber(0)), + direction: Direction::Forward, + limit: len.try_into().unwrap(), + step: 1, + } + ) + }), + DataType::StateDiff, + )); + // Send state diff chunks and check storage + for (i, (expected_state_diff, state_diff_chunks)) in + state_diffs_and_chunks.iter().cloned().enumerate() + { + for state_diff_chunk in state_diff_chunks { + // Check that before the last chunk was sent, the state diff isn't written. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + assert_eq!( + u64::try_from(i).unwrap(), + reader.begin_ro_txn().unwrap().get_state_marker().unwrap().0 + ); + } + .boxed() + }))); + actions.push(Action::SendStateDiff(DataOrFin(Some(state_diff_chunk)))); + } + // Check that a block's state diff is written before the entire query finished. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + let block_number = BlockNumber(i.try_into().unwrap()); wait_for_marker( DataType::StateDiff, - &storage_reader, + &reader, block_number.unchecked_next(), SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, ) .await; - let txn = storage_reader.begin_ro_txn().unwrap(); - let state_diff = txn.get_state_diff(block_number).unwrap().unwrap(); - // TODO(noamsp): refactor test so that we treat multiple state diff chunks as a - // single state diff - let expected_state_diff = match state_diff_chunk { - StateDiffChunk::ContractDiff(contract_diff) => { - let mut deployed_contracts = indexmap! {}; - if let Some(class_hash) = contract_diff.class_hash { - deployed_contracts.insert(contract_diff.contract_address, class_hash); - }; - let mut nonces = indexmap! {}; - if let Some(nonce) = contract_diff.nonce { - nonces.insert(contract_diff.contract_address, nonce); - } - ThinStateDiff { - deployed_contracts, - nonces, - storage_diffs: indexmap! { - contract_diff.contract_address => contract_diff.storage_diffs - }, - ..Default::default() - } - } - StateDiffChunk::DeclaredClass(declared_class) => ThinStateDiff { - declared_classes: indexmap! { - declared_class.class_hash => declared_class.compiled_class_hash - }, - ..Default::default() - }, - StateDiffChunk::DeprecatedDeclaredClass(deprecated_declared_class) => { - ThinStateDiff { - deprecated_declared_classes: vec![deprecated_declared_class.class_hash], - ..Default::default() - } - } - }; - assert_eq!(state_diff, expected_state_diff); + let txn = reader.begin_ro_txn().unwrap(); + let actual_state_diff = txn.get_state_diff(block_number).unwrap().unwrap(); + assert_eq!(actual_state_diff, expected_state_diff); } - - state_diff_sender.send(None).await.unwrap(); - } - }; - - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = join( - run_state_diff_sync_through_channel( - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - false, - ), - test_future, - ) => {} + .boxed() + }))); } + actions.push(Action::SendStateDiff(DataOrFin(None))); + + run_test( + HashMap::from([ + (DataType::Header, state_diffs_and_chunks.len().try_into().unwrap()), + (DataType::StateDiff, state_diffs_and_chunks.len().try_into().unwrap()), + ]), + actions, + ) + .await; } // TODO(noamsp): Consider verifying that ParseDataError::BadPeerError(EmptyStateDiffPart) was @@ -288,219 +303,42 @@ async fn validate_state_diff_fails( header_state_diff_lengths: Vec, state_diff_chunks: Vec>, ) { - let TestArgs { - storage_reader, - p2p_sync, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - // The test will fail if we drop these - mock_transaction_response_manager: _mock_transaction_responses_manager, - mock_class_response_manager: _mock_class_responses_manager, - .. - } = setup(); - - let (state_diff_sender, mut state_diff_receiver) = channel(p2p_sync.config.buffer_size); - - // Create a future that will send responses and validate the results. - let test_future = async move { - for state_diff_chunk in state_diff_chunks { - // Check that before we've sent all parts the state diff wasn't written yet. - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(0, txn.get_state_marker().unwrap().0); - - state_diff_sender.send(state_diff_chunk).await.unwrap(); - } - }; - - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = join( - run_state_diff_sync_through_channel( - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - true, - ), - test_future - ) => {} - } -} - -// Advances the header sync with associated header state diffs. -// The receiver waits for external sender to provide the state diff chunks. -async fn run_state_diff_sync_through_channel( - mock_header_response_manager: &mut GenericReceiver, - mock_state_diff_response_manager: &mut GenericReceiver, - header_state_diff_lengths: Vec, - state_diff_chunk_receiver: &mut Receiver>, - should_assert_reported: bool, -) { - // We wait for the state diff sync to see that there are no headers and start sleeping - tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await; - - // Check that before we send headers there is no state diff query. - assert!(mock_state_diff_response_manager.next().now_or_never().is_none()); - - let num_headers = header_state_diff_lengths.len(); - let block_hashes_and_signatures = - create_block_hashes_and_signatures(num_headers.try_into().unwrap()); - - // split the headers into queries of size HEADER_QUERY_LENGTH and send headers for each query - for headers_for_current_query in block_hashes_and_signatures - .into_iter() - .zip(header_state_diff_lengths.clone().into_iter()) - .enumerate() - .collect::>() - .chunks(HEADER_QUERY_LENGTH.try_into().unwrap()) - .map(Vec::from) - { - // Receive the next query from header sync - let mut mock_header_responses_manager = mock_header_response_manager.next().await.unwrap(); - - for (i, ((block_hash, block_signature), header_state_diff_length)) in - headers_for_current_query - { - // Send header responses - mock_header_responses_manager - .send_response(DataOrFin(Some(SignedBlockHeader { - block_header: BlockHeader { - block_hash, - block_header_without_hash: BlockHeaderWithoutHash { - block_number: BlockNumber(u64::try_from(i).unwrap()), - ..Default::default() - }, - state_diff_length: Some(header_state_diff_length), - ..Default::default() - }, - signatures: vec![block_signature], - }))) - .await - .unwrap(); - } + let mut rng = get_rng(); - mock_header_responses_manager.send_response(DataOrFin(None)).await.unwrap(); + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length + for (i, state_diff_length) in header_state_diff_lengths.iter().copied().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diff_length), + None, + ))))); } + actions.push(Action::SendHeader(DataOrFin(None))); - // TODO(noamsp): remove sleep and wait until header marker writes the new headers. remove the - // comment from the StateDiffQuery about the limit being too low. We wait for the header - // sync to write the new headers. - tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await; - - // Simulate time has passed so that state diff sync will resend query after it waited for - // new header - tokio::time::pause(); - tokio::time::advance(WAIT_PERIOD_FOR_NEW_DATA).await; - tokio::time::resume(); - - let num_state_diff_headers = u64::try_from(num_headers).unwrap(); - let num_state_diff_queries = num_state_diff_headers.div_ceil(STATE_DIFF_QUERY_LENGTH); - - for i in 0..num_state_diff_queries { - let start_block_number = i * STATE_DIFF_QUERY_LENGTH; - let limit = min(num_state_diff_headers - start_block_number, STATE_DIFF_QUERY_LENGTH); - - // Get a state diff query and validate it - let mut mock_state_diff_responses_manager = - mock_state_diff_response_manager.next().await.unwrap(); - assert_eq!( - *mock_state_diff_responses_manager.query(), - Ok(StateDiffQuery(Query { - start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)), - direction: Direction::Forward, - limit, - step: 1, - })), - "If the limit of the query is too low, try to increase \ - SLEEP_DURATION_TO_LET_SYNC_ADVANCE", - ); - - let mut current_state_diff_length = 0; - let destination_state_diff_length = - header_state_diff_lengths[start_block_number.try_into().unwrap() - ..(start_block_number + limit).try_into().unwrap()] - .iter() - .sum(); - - while current_state_diff_length < destination_state_diff_length { - let state_diff_chunk = state_diff_chunk_receiver.recv().await.unwrap(); - - mock_state_diff_responses_manager - .send_response(DataOrFin(state_diff_chunk.clone())) - .await - .unwrap(); - - if let Some(state_diff_chunk) = state_diff_chunk { - if !state_diff_chunk.is_empty() { - current_state_diff_length += state_diff_chunk.len(); - continue; - } - } - - break; - } - - if should_assert_reported { - mock_state_diff_responses_manager.assert_reported(TIMEOUT_FOR_TEST).await; - continue; - } + actions.push( + // We already validate the state diff query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::StateDiff), + ); - assert_eq!(current_state_diff_length, destination_state_diff_length); - let state_diff_chunk = state_diff_chunk_receiver.recv().await.unwrap(); - mock_state_diff_responses_manager - .send_response(DataOrFin(state_diff_chunk.clone())) - .await - .unwrap(); + // Send state diff chunks. + for state_diff_chunk in state_diff_chunks { + actions.push(Action::SendStateDiff(DataOrFin(state_diff_chunk))); } -} -pub(crate) async fn run_state_diff_sync( - config: P2PSyncClientConfig, - mock_header_response_manager: &mut GenericReceiver, - mock_state_diff_response_manager: &mut GenericReceiver, - header_state_diff_lengths: Vec, - state_diff_chunks: Vec>, -) { - let (state_diff_sender, mut state_diff_receiver) = channel(config.buffer_size); - tokio::join! { - run_state_diff_sync_through_channel( - mock_header_response_manager, - mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - false, - ), - async { - for state_diff in state_diff_chunks.chunks(STATE_DIFF_QUERY_LENGTH.try_into().unwrap()) { - for state_diff_chunk in state_diff { - state_diff_sender.send(state_diff_chunk.clone()).await.unwrap(); - } + actions.push(Action::ValidateReportSent(DataType::StateDiff)); - state_diff_sender.send(None).await.unwrap(); - } - } - }; -} - -fn create_random_state_diff_chunk(rng: &mut ChaCha8Rng) -> StateDiffChunk { - let mut state_diff_chunk = StateDiffChunk::get_test_instance(rng); - let contract_address = ContractAddress::from(rng.next_u64()); - let class_hash = ClassHash(rng.next_u64().into()); - match &mut state_diff_chunk { - StateDiffChunk::ContractDiff(contract_diff) => { - contract_diff.contract_address = contract_address; - contract_diff.class_hash = Some(class_hash); - } - StateDiffChunk::DeclaredClass(declared_class) => { - declared_class.class_hash = class_hash; - declared_class.compiled_class_hash = CompiledClassHash(rng.next_u64().into()); - } - StateDiffChunk::DeprecatedDeclaredClass(deprecated_declared_class) => { - deprecated_declared_class.class_hash = class_hash; - } - } - state_diff_chunk + run_test( + HashMap::from([ + (DataType::Header, header_state_diff_lengths.len().try_into().unwrap()), + (DataType::StateDiff, header_state_diff_lengths.len().try_into().unwrap()), + ]), + actions, + ) + .await; } diff --git a/crates/papyrus_p2p_sync/src/client/test_utils.rs b/crates/papyrus_p2p_sync/src/client/test_utils.rs index 006ce510f9..e6c19cecf8 100644 --- a/crates/papyrus_p2p_sync/src/client/test_utils.rs +++ b/crates/papyrus_p2p_sync/src/client/test_utils.rs @@ -141,7 +141,6 @@ pub enum Action { SendHeader(DataOrFin), /// Send a state diff as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::StateDiff before. - #[allow(dead_code)] SendStateDiff(DataOrFin), /// Send a transaction as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::Transaction before. @@ -149,7 +148,6 @@ pub enum Action { SendTransaction(DataOrFin), /// Send a class as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::Class before. - #[allow(dead_code)] SendClass(DataOrFin<(ApiContractClass, ClassHash)>), /// Perform custom validations on the storage. Returns back the storage reader it received as /// input