Skip to content

Commit

Permalink
refactor(papyrus_p2p_sync): convert class tests to use run_test
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama committed Dec 17, 2024
1 parent 9287a94 commit 589f149
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 271 deletions.
205 changes: 104 additions & 101 deletions crates/papyrus_p2p_sync/src/client/class_test.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::cmp::min;
use std::collections::HashMap;

use futures::{FutureExt, StreamExt};
use futures::FutureExt;
use papyrus_common::pending_classes::ApiContractClass;
use papyrus_protobuf::sync::{
BlockHashOrNumber,
ClassQuery,
DataOrFin,
DeclaredClass,
DeprecatedDeclaredClass,
Expand All @@ -22,129 +21,131 @@ use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContract
use starknet_api::state::SierraContractClass;

use super::test_utils::{
setup,
random_header,
run_test,
wait_for_marker,
Action,
DataType,
TestArgs,
CLASS_DIFF_QUERY_LENGTH,
HEADER_QUERY_LENGTH,
SLEEP_DURATION_TO_LET_SYNC_ADVANCE,
TIMEOUT_FOR_TEST,
};
use crate::client::state_diff_test::run_state_diff_sync;

#[tokio::test]
async fn class_basic_flow() {
let TestArgs {
p2p_sync,
storage_reader,
mut mock_state_diff_response_manager,
mut mock_header_response_manager,
mut mock_class_response_manager,
// The test will fail if we drop this
mock_transaction_response_manager: _mock_transaction_responses_manager,
..
} = setup();

let mut rng = get_rng();
// TODO(noamsp): Add multiple state diffs per header
let (class_state_diffs, api_contract_classes): (Vec<_>, Vec<_>) = (0..HEADER_QUERY_LENGTH)
.map(|_| create_random_state_diff_chunk_with_class(&mut rng))
.unzip();
let header_state_diff_lengths =
class_state_diffs.iter().map(|class_state_diff| class_state_diff.len()).collect::<Vec<_>>();

// Create a future that will receive queries, send responses and validate the results
let parse_queries_future = async move {
// Check that before we send state diffs there is no class query.
assert!(mock_class_response_manager.next().now_or_never().is_none());

run_state_diff_sync(
p2p_sync.config,
&mut mock_header_response_manager,
&mut mock_state_diff_response_manager,
header_state_diff_lengths.clone(),
class_state_diffs.clone().into_iter().map(Some).collect(),
)
.await;

let num_declare_class_state_diff_headers =
u64::try_from(header_state_diff_lengths.len()).unwrap();
let num_class_queries =
num_declare_class_state_diff_headers.div_ceil(CLASS_DIFF_QUERY_LENGTH);
for i in 0..num_class_queries {
let start_block_number = i * CLASS_DIFF_QUERY_LENGTH;
let limit = min(
num_declare_class_state_diff_headers - start_block_number,
CLASS_DIFF_QUERY_LENGTH,
);
let state_diffs_and_classes_of_blocks = [
vec![
create_random_state_diff_chunk_with_class(&mut rng),
create_random_state_diff_chunk_with_class(&mut rng),
],
vec![
create_random_state_diff_chunk_with_class(&mut rng),
create_random_state_diff_chunk_with_class(&mut rng),
create_random_state_diff_chunk_with_class(&mut rng),
],
];

let mut actions = vec![
// We already validate the header query content in other tests.
Action::ReceiveQuery(Box::new(|_query| ()), DataType::Header),
];

// Send headers with corresponding state diff length.
for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.iter().enumerate() {
actions.push(Action::SendHeader(DataOrFin(Some(random_header(
&mut rng,
BlockNumber(i.try_into().unwrap()),
Some(state_diffs_and_classes.len()),
None,
)))));
}
actions.push(Action::SendHeader(DataOrFin(None)));

// Send state diffs.
actions.push(
// We already validate the state diff query content in other tests.
Action::ReceiveQuery(Box::new(|_query| ()), DataType::StateDiff),
);
for state_diffs_and_classes in &state_diffs_and_classes_of_blocks {
for (state_diff, _) in state_diffs_and_classes {
actions.push(Action::SendStateDiff(DataOrFin(Some(state_diff.clone()))));
}
}

// Get a class query and validate it
let mut mock_class_responses_manager =
mock_class_response_manager.next().await.unwrap();
let len = state_diffs_and_classes_of_blocks.len();
actions.push(Action::ReceiveQuery(
Box::new(move |query| {
assert_eq!(
*mock_class_responses_manager.query(),
Ok(ClassQuery(Query {
start_block: BlockHashOrNumber::Number(BlockNumber(start_block_number)),
query,
Query {
start_block: BlockHashOrNumber::Number(BlockNumber(0)),
direction: Direction::Forward,
limit,
limit: len.try_into().unwrap(),
step: 1,
})),
"If the limit of the query is too low, try to increase \
SLEEP_DURATION_TO_LET_SYNC_ADVANCE",
);

for block_number in start_block_number..(start_block_number + limit) {
let class_hash =
class_state_diffs[usize::try_from(block_number).unwrap()].get_class_hash();
let expected_class =
api_contract_classes[usize::try_from(block_number).unwrap()].clone();

let block_number = BlockNumber(block_number);

// Check that before we've sent all parts the contract class wasn't written yet
let txn = storage_reader.begin_ro_txn().unwrap();
assert_eq!(block_number, txn.get_class_marker().unwrap());

mock_class_responses_manager
.send_response(DataOrFin(Some((expected_class.clone(), class_hash))))
.await
.unwrap();

}
)
}),
DataType::Class,
));
for (i, state_diffs_and_classes) in state_diffs_and_classes_of_blocks.into_iter().enumerate() {
for (state_diff, class) in &state_diffs_and_classes {
let class_hash = state_diff.get_class_hash();

// Check that before the last class was sent, the classes aren't written.
actions.push(Action::CheckStorage(Box::new(move |reader| {
async move {
assert_eq!(
u64::try_from(i).unwrap(),
reader.begin_ro_txn().unwrap().get_class_marker().unwrap().0
);
}
.boxed()
})));
actions.push(Action::SendClass(DataOrFin(Some((class.clone(), class_hash)))));
}
// Check that a block's classes are written before the entire query finished.
actions.push(Action::CheckStorage(Box::new(move |reader| {
async move {
let block_number = BlockNumber(i.try_into().unwrap());
wait_for_marker(
DataType::Class,
&storage_reader,
&reader,
block_number.unchecked_next(),
SLEEP_DURATION_TO_LET_SYNC_ADVANCE,
TIMEOUT_FOR_TEST,
)
.await;

let txn = storage_reader.begin_ro_txn().unwrap();
let actual_class = match expected_class {
ApiContractClass::ContractClass(_) => ApiContractClass::ContractClass(
txn.get_class(&class_hash).unwrap().unwrap(),
),
ApiContractClass::DeprecatedContractClass(_) => {
ApiContractClass::DeprecatedContractClass(
txn.get_deprecated_class(&class_hash).unwrap().unwrap(),
)
let txn = reader.begin_ro_txn().unwrap();
for (state_diff, expected_class) in state_diffs_and_classes {
let class_hash = state_diff.get_class_hash();
match expected_class {
ApiContractClass::ContractClass(expected_class) => {
let actual_class = txn.get_class(&class_hash).unwrap().unwrap();
assert_eq!(actual_class, expected_class.clone());
}
ApiContractClass::DeprecatedContractClass(expected_class) => {
let actual_class =
txn.get_deprecated_class(&class_hash).unwrap().unwrap();
assert_eq!(actual_class, expected_class.clone());
}
}
};
assert_eq!(expected_class, actual_class);
}
}

mock_class_responses_manager.send_response(DataOrFin(None)).await.unwrap();
}
};

tokio::select! {
sync_result = p2p_sync.run() => {
sync_result.unwrap();
panic!("P2P sync aborted with no failure.");
}
_ = parse_queries_future => {}
.boxed()
})));
}

run_test(
HashMap::from([
(DataType::Header, len.try_into().unwrap()),
(DataType::StateDiff, len.try_into().unwrap()),
(DataType::Class, len.try_into().unwrap()),
]),
actions,
)
.await;
}

// We define this new trait here so we can use the get_class_hash function in the test.
Expand Down Expand Up @@ -176,6 +177,8 @@ fn create_random_state_diff_chunk_with_class(
};
(
StateDiffChunk::DeclaredClass(declared_class),
// TODO(noamsp): get_test_instance on these types returns the same value, making this
// test redundant. Fix this.
ApiContractClass::ContractClass(SierraContractClass::get_test_instance(rng)),
)
} else {
Expand Down
Loading

0 comments on commit 589f149

Please sign in to comment.