From 589f1496cd108c1922a40881863e69e00576d038 Mon Sep 17 00:00:00 2001 From: Shahak Shama Date: Thu, 5 Dec 2024 15:03:18 +0200 Subject: [PATCH] refactor(papyrus_p2p_sync): convert class tests to use run_test --- .../papyrus_p2p_sync/src/client/class_test.rs | 205 +++++++++--------- .../src/client/state_diff_test.rs | 170 +-------------- .../papyrus_p2p_sync/src/client/test_utils.rs | 2 - 3 files changed, 106 insertions(+), 271 deletions(-) diff --git a/crates/papyrus_p2p_sync/src/client/class_test.rs b/crates/papyrus_p2p_sync/src/client/class_test.rs index 6e9e6629bd..85b49d931e 100644 --- a/crates/papyrus_p2p_sync/src/client/class_test.rs +++ b/crates/papyrus_p2p_sync/src/client/class_test.rs @@ -1,10 +1,9 @@ -use std::cmp::min; +use std::collections::HashMap; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use papyrus_common::pending_classes::ApiContractClass; use papyrus_protobuf::sync::{ BlockHashOrNumber, - ClassQuery, DataOrFin, DeclaredClass, DeprecatedDeclaredClass, @@ -22,129 +21,131 @@ use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContract use starknet_api::state::SierraContractClass; use super::test_utils::{ - setup, + random_header, + run_test, wait_for_marker, + Action, DataType, - TestArgs, - CLASS_DIFF_QUERY_LENGTH, - HEADER_QUERY_LENGTH, SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, }; -use crate::client::state_diff_test::run_state_diff_sync; #[tokio::test] async fn class_basic_flow() { - let TestArgs { - p2p_sync, - storage_reader, - mut mock_state_diff_response_manager, - mut mock_header_response_manager, - mut mock_class_response_manager, - // The test will fail if we drop this - mock_transaction_response_manager: _mock_transaction_responses_manager, - .. - } = setup(); - let mut rng = get_rng(); - // TODO(noamsp): Add multiple state diffs per header - let (class_state_diffs, api_contract_classes): (Vec<_>, Vec<_>) = (0..HEADER_QUERY_LENGTH) - .map(|_| create_random_state_diff_chunk_with_class(&mut rng)) - .unzip(); - let header_state_diff_lengths = - class_state_diffs.iter().map(|class_state_diff| class_state_diff.len()).collect::>(); - - // Create a future that will receive queries, send responses and validate the results - let parse_queries_future = async move { - // Check that before we send state diffs there is no class query. - assert!(mock_class_response_manager.next().now_or_never().is_none()); - run_state_diff_sync( - p2p_sync.config, - &mut mock_header_response_manager, - &mut mock_state_diff_response_manager, - header_state_diff_lengths.clone(), - class_state_diffs.clone().into_iter().map(Some).collect(), - ) - .await; - - let num_declare_class_state_diff_headers = - u64::try_from(header_state_diff_lengths.len()).unwrap(); - let num_class_queries = - num_declare_class_state_diff_headers.div_ceil(CLASS_DIFF_QUERY_LENGTH); - for i in 0..num_class_queries { - let start_block_number = i * CLASS_DIFF_QUERY_LENGTH; - let limit = min( - num_declare_class_state_diff_headers - start_block_number, - CLASS_DIFF_QUERY_LENGTH, - ); + let state_diffs_and_classes_of_blocks = [ + vec![ + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + ], + vec![ + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + create_random_state_diff_chunk_with_class(&mut rng), + ], + ]; + + let mut actions = vec![ + // We already validate the header query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header), + ]; + + // Send headers with corresponding state diff length. + for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.iter().enumerate() { + actions.push(Action::SendHeader(DataOrFin(Some(random_header( + &mut rng, + BlockNumber(i.try_into().unwrap()), + Some(state_diffs_and_classes.len()), + None, + ))))); + } + actions.push(Action::SendHeader(DataOrFin(None))); + + // Send state diffs. + actions.push( + // We already validate the state diff query content in other tests. + Action::ReceiveQuery(Box::new(|_query| ()), DataType::StateDiff), + ); + for state_diffs_and_classes in &state_diffs_and_classes_of_blocks { + for (state_diff, _) in state_diffs_and_classes { + actions.push(Action::SendStateDiff(DataOrFin(Some(state_diff.clone())))); + } + } - // Get a class query and validate it - let mut mock_class_responses_manager = - mock_class_response_manager.next().await.unwrap(); + let len = state_diffs_and_classes_of_blocks.len(); + actions.push(Action::ReceiveQuery( + Box::new(move |query| { assert_eq!( - *mock_class_responses_manager.query(), - Ok(ClassQuery(Query { - start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)), + query, + Query { + start_block: BlockHashOrNumber::Number(BlockNumber(0)), direction: Direction::Forward, - limit, + limit: len.try_into().unwrap(), step: 1, - })), - "If the limit of the query is too low, try to increase \ - SLEEP_DURATION_TO_LET_SYNC_ADVANCE", - ); - - for block_number in start_block_number..(start_block_number + limit) { - let class_hash = - class_state_diffs[usize::try_from(block_number).unwrap()].get_class_hash(); - let expected_class = - api_contract_classes[usize::try_from(block_number).unwrap()].clone(); - - let block_number = BlockNumber(block_number); - - // Check that before we've sent all parts the contract class wasn't written yet - let txn = storage_reader.begin_ro_txn().unwrap(); - assert_eq!(block_number, txn.get_class_marker().unwrap()); - - mock_class_responses_manager - .send_response(DataOrFin(Some((expected_class.clone(), class_hash)))) - .await - .unwrap(); - + } + ) + }), + DataType::Class, + )); + for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.into_iter().enumerate() { + for (state_diff, class) in &state_diffs_and_classes { + let class_hash = state_diff.get_class_hash(); + + // Check that before the last class was sent, the classes aren't written. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + assert_eq!( + u64::try_from(i).unwrap(), + reader.begin_ro_txn().unwrap().get_class_marker().unwrap().0 + ); + } + .boxed() + }))); + actions.push(Action::SendClass(DataOrFin(Some((class.clone(), class_hash))))); + } + // Check that a block's classes are written before the entire query finished. + actions.push(Action::CheckStorage(Box::new(move |reader| { + async move { + let block_number = BlockNumber(i.try_into().unwrap()); wait_for_marker( DataType::Class, - &storage_reader, + &reader, block_number.unchecked_next(), SLEEP_DURATION_TO_LET_SYNC_ADVANCE, TIMEOUT_FOR_TEST, ) .await; - let txn = storage_reader.begin_ro_txn().unwrap(); - let actual_class = match expected_class { - ApiContractClass::ContractClass(_) => ApiContractClass::ContractClass( - txn.get_class(&class_hash).unwrap().unwrap(), - ), - ApiContractClass::DeprecatedContractClass(_) => { - ApiContractClass::DeprecatedContractClass( - txn.get_deprecated_class(&class_hash).unwrap().unwrap(), - ) + let txn = reader.begin_ro_txn().unwrap(); + for (state_diff, expected_class) in state_diffs_and_classes { + let class_hash = state_diff.get_class_hash(); + match expected_class { + ApiContractClass::ContractClass(expected_class) => { + let actual_class = txn.get_class(&class_hash).unwrap().unwrap(); + assert_eq!(actual_class, expected_class.clone()); + } + ApiContractClass::DeprecatedContractClass(expected_class) => { + let actual_class = + txn.get_deprecated_class(&class_hash).unwrap().unwrap(); + assert_eq!(actual_class, expected_class.clone()); + } } - }; - assert_eq!(expected_class, actual_class); + } } - - mock_class_responses_manager.send_response(DataOrFin(None)).await.unwrap(); - } - }; - - tokio::select! { - sync_result = p2p_sync.run() => { - sync_result.unwrap(); - panic!("P2P sync aborted with no failure."); - } - _ = parse_queries_future => {} + .boxed() + }))); } + + run_test( + HashMap::from([ + (DataType::Header, len.try_into().unwrap()), + (DataType::StateDiff, len.try_into().unwrap()), + (DataType::Class, len.try_into().unwrap()), + ]), + actions, + ) + .await; } // We define this new trait here so we can use the get_class_hash function in the test. @@ -176,6 +177,8 @@ fn create_random_state_diff_chunk_with_class( }; ( StateDiffChunk::DeclaredClass(declared_class), + // TODO(noamsp): get_test_instance on these types returns the same value, making this + // test redundant. Fix this. ApiContractClass::ContractClass(SierraContractClass::get_test_instance(rng)), ) } else { diff --git a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs index d4e86f259e..eb8c1528fa 100644 --- a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs +++ b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs @@ -1,9 +1,7 @@ -use std::cmp::min; use std::collections::HashMap; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use indexmap::indexmap; -use papyrus_network::network_manager::GenericReceiver; use papyrus_protobuf::sync::{ BlockHashOrNumber, ContractDiff, @@ -12,33 +10,24 @@ use papyrus_protobuf::sync::{ DeprecatedDeclaredClass, Direction, Query, - SignedBlockHeader, StateDiffChunk, }; use papyrus_storage::state::StateStorageReader; use papyrus_test_utils::get_rng; -use starknet_api::block::{BlockHeader, BlockHeaderWithoutHash, BlockNumber}; +use starknet_api::block::BlockNumber; use starknet_api::core::{ascii_as_felt, ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::{StorageKey, ThinStateDiff}; use starknet_types_core::felt::Felt; -use tokio::sync::mpsc::{channel, Receiver}; use super::test_utils::{ - create_block_hashes_and_signatures, random_header, run_test, wait_for_marker, Action, DataType, - HeaderTestPayload, - StateDiffTestPayload, - HEADER_QUERY_LENGTH, SLEEP_DURATION_TO_LET_SYNC_ADVANCE, - STATE_DIFF_QUERY_LENGTH, TIMEOUT_FOR_TEST, - WAIT_PERIOD_FOR_NEW_DATA, }; -use super::{P2PSyncClientConfig, StateDiffQuery}; #[tokio::test] async fn state_diff_basic_flow() { @@ -353,158 +342,3 @@ async fn validate_state_diff_fails( ) .await; } - -// Advances the header sync with associated header state diffs. -// The receiver waits for external sender to provide the state diff chunks. -async fn run_state_diff_sync_through_channel( - mock_header_response_manager: &mut GenericReceiver, - mock_state_diff_response_manager: &mut GenericReceiver, - header_state_diff_lengths: Vec, - state_diff_chunk_receiver: &mut Receiver>, - should_assert_reported: bool, -) { - // We wait for the state diff sync to see that there are no headers and start sleeping - tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await; - - // Check that before we send headers there is no state diff query. - assert!(mock_state_diff_response_manager.next().now_or_never().is_none()); - - let num_headers = header_state_diff_lengths.len(); - let block_hashes_and_signatures = - create_block_hashes_and_signatures(num_headers.try_into().unwrap()); - - // split the headers into queries of size HEADER_QUERY_LENGTH and send headers for each query - for headers_for_current_query in block_hashes_and_signatures - .into_iter() - .zip(header_state_diff_lengths.clone().into_iter()) - .enumerate() - .collect::>() - .chunks(HEADER_QUERY_LENGTH.try_into().unwrap()) - .map(Vec::from) - { - // Receive the next query from header sync - let mut mock_header_responses_manager = mock_header_response_manager.next().await.unwrap(); - - for (i, ((block_hash, block_signature), header_state_diff_length)) in - headers_for_current_query - { - // Send header responses - mock_header_responses_manager - .send_response(DataOrFin(Some(SignedBlockHeader { - block_header: BlockHeader { - block_hash, - block_header_without_hash: BlockHeaderWithoutHash { - block_number: BlockNumber(u64::try_from(i).unwrap()), - ..Default::default() - }, - state_diff_length: Some(header_state_diff_length), - ..Default::default() - }, - signatures: vec![block_signature], - }))) - .await - .unwrap(); - } - - mock_header_responses_manager.send_response(DataOrFin(None)).await.unwrap(); - } - - // TODO(noamsp): remove sleep and wait until header marker writes the new headers. remove the - // comment from the StateDiffQuery about the limit being too low. We wait for the header - // sync to write the new headers. - tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await; - - // Simulate time has passed so that state diff sync will resend query after it waited for - // new header - tokio::time::pause(); - tokio::time::advance(WAIT_PERIOD_FOR_NEW_DATA).await; - tokio::time::resume(); - - let num_state_diff_headers = u64::try_from(num_headers).unwrap(); - let num_state_diff_queries = num_state_diff_headers.div_ceil(STATE_DIFF_QUERY_LENGTH); - - for i in 0..num_state_diff_queries { - let start_block_number = i * STATE_DIFF_QUERY_LENGTH; - let limit = min(num_state_diff_headers - start_block_number, STATE_DIFF_QUERY_LENGTH); - - // Get a state diff query and validate it - let mut mock_state_diff_responses_manager = - mock_state_diff_response_manager.next().await.unwrap(); - assert_eq!( - *mock_state_diff_responses_manager.query(), - Ok(StateDiffQuery(Query { - start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)), - direction: Direction::Forward, - limit, - step: 1, - })), - "If the limit of the query is too low, try to increase \ - SLEEP_DURATION_TO_LET_SYNC_ADVANCE", - ); - - let mut current_state_diff_length = 0; - let destination_state_diff_length = - header_state_diff_lengths[start_block_number.try_into().unwrap() - ..(start_block_number + limit).try_into().unwrap()] - .iter() - .sum(); - - while current_state_diff_length < destination_state_diff_length { - let state_diff_chunk = state_diff_chunk_receiver.recv().await.unwrap(); - - mock_state_diff_responses_manager - .send_response(DataOrFin(state_diff_chunk.clone())) - .await - .unwrap(); - - if let Some(state_diff_chunk) = state_diff_chunk { - if !state_diff_chunk.is_empty() { - current_state_diff_length += state_diff_chunk.len(); - continue; - } - } - - break; - } - - if should_assert_reported { - mock_state_diff_responses_manager.assert_reported(TIMEOUT_FOR_TEST).await; - continue; - } - - assert_eq!(current_state_diff_length, destination_state_diff_length); - let state_diff_chunk = state_diff_chunk_receiver.recv().await.unwrap(); - mock_state_diff_responses_manager - .send_response(DataOrFin(state_diff_chunk.clone())) - .await - .unwrap(); - } -} - -pub(crate) async fn run_state_diff_sync( - config: P2PSyncClientConfig, - mock_header_response_manager: &mut GenericReceiver, - mock_state_diff_response_manager: &mut GenericReceiver, - header_state_diff_lengths: Vec, - state_diff_chunks: Vec>, -) { - let (state_diff_sender, mut state_diff_receiver) = channel(config.buffer_size); - tokio::join! { - run_state_diff_sync_through_channel( - mock_header_response_manager, - mock_state_diff_response_manager, - header_state_diff_lengths, - &mut state_diff_receiver, - false, - ), - async { - for state_diff in state_diff_chunks.chunks(STATE_DIFF_QUERY_LENGTH.try_into().unwrap()) { - for state_diff_chunk in state_diff { - state_diff_sender.send(state_diff_chunk.clone()).await.unwrap(); - } - - state_diff_sender.send(None).await.unwrap(); - } - } - }; -} diff --git a/crates/papyrus_p2p_sync/src/client/test_utils.rs b/crates/papyrus_p2p_sync/src/client/test_utils.rs index 006ce510f9..e6c19cecf8 100644 --- a/crates/papyrus_p2p_sync/src/client/test_utils.rs +++ b/crates/papyrus_p2p_sync/src/client/test_utils.rs @@ -141,7 +141,6 @@ pub enum Action { SendHeader(DataOrFin), /// Send a state diff as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::StateDiff before. - #[allow(dead_code)] SendStateDiff(DataOrFin), /// Send a transaction as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::Transaction before. @@ -149,7 +148,6 @@ pub enum Action { SendTransaction(DataOrFin), /// Send a class as a response to a query we got from ReceiveQuery. Will panic if didn't /// call ReceiveQuery with DataType::Class before. - #[allow(dead_code)] SendClass(DataOrFin<(ApiContractClass, ClassHash)>), /// Perform custom validations on the storage. Returns back the storage reader it received as /// input