diff --git a/crates/sequencing/papyrus_consensus/Cargo.toml b/crates/sequencing/papyrus_consensus/Cargo.toml index 5694a205fe..95104b1963 100644 --- a/crates/sequencing/papyrus_consensus/Cargo.toml +++ b/crates/sequencing/papyrus_consensus/Cargo.toml @@ -24,7 +24,7 @@ serde = { workspace = true, features = ["derive"] } starknet-types-core.workspace = true starknet_api.workspace = true thiserror.workspace = true -tokio.workspace = true +tokio = { workspace = true, features = ["sync"] } tracing.workspace = true [dev-dependencies] diff --git a/crates/sequencing/papyrus_consensus/src/manager_test.rs b/crates/sequencing/papyrus_consensus/src/manager_test.rs index e186c23f0e..c7f2ddb9cd 100644 --- a/crates/sequencing/papyrus_consensus/src/manager_test.rs +++ b/crates/sequencing/papyrus_consensus/src/manager_test.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Duration; use std::vec; @@ -24,6 +25,7 @@ use papyrus_test_utils::{get_rng, GetTestInstance}; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::transaction::Transaction; use starknet_types_core::felt::Felt; +use tokio::sync::Notify; use super::{run_consensus, MultiHeightManager}; use crate::config::TimeoutsConfig; @@ -42,7 +44,11 @@ lazy_static! { static ref VALIDATOR_ID: ValidatorId = (DEFAULT_VALIDATOR_ID + 1).into(); static ref VALIDATOR_ID_2: ValidatorId = (DEFAULT_VALIDATOR_ID + 2).into(); static ref VALIDATOR_ID_3: ValidatorId = (DEFAULT_VALIDATOR_ID + 3).into(); - static ref TIMEOUTS: TimeoutsConfig = TimeoutsConfig::default(); + static ref TIMEOUTS: TimeoutsConfig = TimeoutsConfig { + prevote_timeout: Duration::from_millis(100), + precommit_timeout: Duration::from_millis(100), + proposal_timeout: Duration::from_millis(100), + }; } const CHANNEL_SIZE: usize = 10; @@ -263,7 +269,7 @@ async fn run_consensus_sync() { #[tokio::test] async fn run_consensus_sync_cancellation_safety() { let mut context = MockTestContext::new(); - let (proposal_handled_tx, proposal_handled_rx) = oneshot::channel(); + let proposal_handled = Arc::new(Notify::new()); let (decision_tx, decision_rx) = oneshot::channel(); // TODO(guyn): refactor this test to pass proposals through the correct channels. @@ -273,12 +279,15 @@ async fn run_consensus_sync_cancellation_safety() { context.expect_validators().returning(move |_| vec![*PROPOSER_ID, *VALIDATOR_ID]); context.expect_proposer().returning(move |_, _| *PROPOSER_ID); context.expect_set_height_and_round().returning(move |_, _| ()); - context.expect_broadcast().with(eq(prevote(Some(Felt::ONE), 1, 0, *VALIDATOR_ID))).return_once( - move |_| { - proposal_handled_tx.send(()).unwrap(); + let proposal_handled_clone = Arc::clone(&proposal_handled); + context + .expect_broadcast() + .with(eq(prevote(Some(Felt::ONE), 1, 0, *VALIDATOR_ID))) + // May occur repeatedly due to re-broadcasting. + .returning(move |_| { + proposal_handled_clone.notify_one(); Ok(()) - }, - ); + }); context.expect_broadcast().returning(move |_| Ok(())); context.expect_decision_reached().return_once(|block, votes| { assert_eq!(block, BlockHash(Felt::ONE)); @@ -313,7 +322,7 @@ async fn run_consensus_sync_cancellation_safety() { vec![ProposalPart::Init(proposal_init(1, 0, *PROPOSER_ID))], ) .await; - proposal_handled_rx.await.unwrap(); + proposal_handled.notified().await; // Send an old sync. This should not cancel the current height. sync_sender.send(BlockNumber(0)).await.unwrap();