From 8ed234ee759650c94d0e95d5a73b7ec7cccf090b Mon Sep 17 00:00:00 2001 From: guy-starkware Date: Mon, 7 Oct 2024 00:40:27 -0700 Subject: [PATCH] feat(consensus): use network channel instead of regular channel as input for StreamHandler (#1084) --- .../papyrus_consensus/src/stream_handler.rs | 26 ++++++-- .../src/stream_handler_test.rs | 65 +++++++++++++------ 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index 2cf763b5c4..50d3ca5f6f 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -4,6 +4,8 @@ use std::collections::{BTreeMap, HashMap}; use futures::channel::mpsc; use futures::StreamExt; +use papyrus_network::network_manager::BroadcastTopicServer; +use papyrus_network_types::network_types::{BroadcastedMessageManager, OpaquePeerId}; use papyrus_protobuf::consensus::StreamMessage; use papyrus_protobuf::converters::ProtobufConversionError; use tracing::{instrument, warn}; @@ -12,11 +14,16 @@ use tracing::{instrument, warn}; #[path = "stream_handler_test.rs"] mod stream_handler_test; +type PeerId = OpaquePeerId; type StreamId = u64; type MessageId = u64; const CHANNEL_BUFFER_LENGTH: usize = 100; +fn get_metadata_peer_id(metadata: BroadcastedMessageManager) -> PeerId { + metadata.originator_id +} + #[derive(Debug, Clone)] struct StreamData> + TryFrom, Error = ProtobufConversionError>> { // The next message_id that is expected. @@ -53,7 +60,7 @@ pub struct StreamHandler< // An end of a channel used to send out receivers, one for each stream. sender: mpsc::Sender>, // An end of a channel used to receive messages. - receiver: mpsc::Receiver>, + receiver: BroadcastTopicServer>, // A map from stream_id to a struct that contains all the information about the stream. // This includes both the message buffer and some metadata (like the latest message_id). @@ -67,7 +74,7 @@ impl> + TryFrom, Error = ProtobufConversionError /// Create a new StreamHandler. pub fn new( sender: mpsc::Sender>, - receiver: mpsc::Receiver>, + receiver: BroadcastTopicServer>, ) -> Self { StreamHandler { sender, receiver, stream_data: HashMap::new() } } @@ -89,9 +96,20 @@ impl> + TryFrom, Error = ProtobufConversionError data.next_message_id += 1; } - // Handle the message, return true if the channel is still open. #[instrument(skip_all, level = "warn")] - fn handle_message(&mut self, message: StreamMessage) { + fn handle_message( + &mut self, + message: (Result, ProtobufConversionError>, BroadcastedMessageManager), + ) { + let (message, metadata) = message; + let message = match message { + Ok(message) => message, + Err(e) => { + warn!("Error converting message: {:?}", e); + return; + } + }; + let _peer_id = get_metadata_peer_id(metadata); // TODO(guyn): use peer_id let stream_id = message.stream_id; let message_id = message.message_id; diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 26101c9647..4e0f64dd47 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -1,6 +1,14 @@ use futures::channel::mpsc; use futures::stream::StreamExt; +use futures::SinkExt; +use papyrus_network::network_manager::test_utils::{ + mock_register_broadcast_topic, + MockBroadcastedMessagesSender, + TestSubscriberChannels, +}; +use papyrus_network::network_manager::BroadcastTopicChannels; use papyrus_protobuf::consensus::{ConsensusMessage, Proposal, StreamMessage}; +use papyrus_test_utils::{get_rng, GetTestInstance}; use super::StreamHandler; @@ -8,6 +16,8 @@ use super::StreamHandler; mod tests { use std::time::Duration; + use papyrus_network_types::network_types::BroadcastedMessageManager; + use super::*; fn make_test_message( @@ -29,25 +39,43 @@ mod tests { matching == a.len() && matching == b.len() } + // TODO(guyn): should I make this a public function in `manager_test.rs` or just have a copy + // here? + async fn send( + sender: &mut MockBroadcastedMessagesSender>, + msg: StreamMessage, + ) { + let broadcasted_message_manager = + BroadcastedMessageManager::get_test_instance(&mut get_rng()); + sender.send((msg, broadcasted_message_manager)).await.unwrap(); + } + fn setup_test() -> ( StreamHandler, - mpsc::Sender>, + MockBroadcastedMessagesSender>, mpsc::Receiver>, ) { - let (tx_input, rx_input) = mpsc::channel::>(100); + let TestSubscriberChannels { mock_network, subscriber_channels } = + mock_register_broadcast_topic().unwrap(); + let network_sender = mock_network.broadcasted_messages_sender; + let BroadcastTopicChannels { broadcasted_messages_receiver, broadcast_topic_client: _ } = + subscriber_channels; + + // TODO(guyn): We should also give the broadcast_topic_client to the StreamHandler let (tx_output, rx_output) = mpsc::channel::>(100); - let handler = StreamHandler::new(tx_output, rx_input); - (handler, tx_input, rx_output) + let handler = StreamHandler::new(tx_output, broadcasted_messages_receiver); + (handler, network_sender, rx_output) } #[tokio::test] async fn stream_handler_in_order() { - let (mut stream_handler, mut tx_input, mut rx_output) = setup_test(); + let (mut stream_handler, mut network_sender, mut rx_output) = setup_test(); let stream_id = 127; for i in 0..10 { let message = make_test_message(stream_id, i, i == 9); - tx_input.try_send(message).expect("Send should succeed"); + // tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } let join_handle = tokio::spawn(async move { @@ -66,12 +94,12 @@ mod tests { #[tokio::test] async fn stream_handler_in_reverse() { - let (mut stream_handler, mut tx_input, mut rx_output) = setup_test(); + let (mut stream_handler, mut network_sender, mut rx_output) = setup_test(); let stream_id = 127; for i in 0..5 { let message = make_test_message(stream_id, 5 - i, i == 0); - tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } let join_handle = tokio::spawn(async move { @@ -93,8 +121,7 @@ mod tests { assert!(do_vecs_match(&keys, &range)); // Now send the last message: - tx_input.try_send(make_test_message(stream_id, 0, false)).expect("Send should succeed"); - + send(&mut network_sender, make_test_message(stream_id, 0, false)).await; let join_handle = tokio::spawn(async move { let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await; stream_handler @@ -112,29 +139,29 @@ mod tests { #[tokio::test] async fn stream_handler_multiple_streams() { - let (mut stream_handler, mut tx_input, mut rx_output) = setup_test(); + let (mut stream_handler, mut network_sender, mut rx_output) = setup_test(); let stream_id1 = 127; // Send all messages in order (except the first one). let stream_id2 = 10; // Send in reverse order (except the first one). - let stream_id3 = 1; // Send in two batches of 5 messages, without the first one, don't send fin. + let stream_id3 = 1; // Send in two batches, without the first one, don't send fin. for i in 1..10 { let message = make_test_message(stream_id1, i, i == 9); - tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } for i in 0..5 { let message = make_test_message(stream_id2, 5 - i, i == 0); - tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } for i in 5..10 { let message = make_test_message(stream_id3, i, false); - tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } for i in 1..5 { let message = make_test_message(stream_id3, i, false); - tx_input.try_send(message).expect("Send should succeed"); + send(&mut network_sender, message).await; } let join_handle = tokio::spawn(async move { @@ -195,7 +222,7 @@ mod tests { assert!(receiver3.try_next().is_err()); // Send the last message on stream_id1: - tx_input.try_send(make_test_message(stream_id1, 0, false)).expect("Send should succeed"); + send(&mut network_sender, make_test_message(stream_id1, 0, false)).await; let join_handle = tokio::spawn(async move { let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await; stream_handler @@ -216,7 +243,7 @@ mod tests { assert!(stream_handler.stream_data.clone().into_keys().all(|item| values.contains(&item))); // Send the last message on stream_id2: - tx_input.try_send(make_test_message(stream_id2, 0, false)).expect("Send should succeed"); + send(&mut network_sender, make_test_message(stream_id2, 0, false)).await; let join_handle = tokio::spawn(async move { let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await; stream_handler @@ -237,7 +264,7 @@ mod tests { assert!(stream_handler.stream_data.clone().into_keys().all(|item| values.contains(&item))); // Send the last message on stream_id3: - tx_input.try_send(make_test_message(stream_id3, 0, false)).expect("Send should succeed"); + send(&mut network_sender, make_test_message(stream_id3, 0, false)).await; let join_handle = tokio::spawn(async move { let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await;