Skip to content

Commit

Permalink
feat(consensus): use network channel instead of regular channel as in…
Browse files Browse the repository at this point in the history
…put for StreamHandler (#1084)
  • Loading branch information
guy-starkware authored Oct 7, 2024
1 parent 89243a1 commit 8ed234e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
26 changes: 22 additions & 4 deletions crates/sequencing/papyrus_consensus/src/stream_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::collections::{BTreeMap, HashMap};

use futures::channel::mpsc;
use futures::StreamExt;
use papyrus_network::network_manager::BroadcastTopicServer;
use papyrus_network_types::network_types::{BroadcastedMessageManager, OpaquePeerId};
use papyrus_protobuf::consensus::StreamMessage;
use papyrus_protobuf::converters::ProtobufConversionError;
use tracing::{instrument, warn};
Expand All @@ -12,11 +14,16 @@ use tracing::{instrument, warn};
#[path = "stream_handler_test.rs"]
mod stream_handler_test;

type PeerId = OpaquePeerId;
type StreamId = u64;
type MessageId = u64;

const CHANNEL_BUFFER_LENGTH: usize = 100;

fn get_metadata_peer_id(metadata: BroadcastedMessageManager) -> PeerId {
metadata.originator_id
}

#[derive(Debug, Clone)]
struct StreamData<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError>> {
// The next message_id that is expected.
Expand Down Expand Up @@ -53,7 +60,7 @@ pub struct StreamHandler<
// An end of a channel used to send out receivers, one for each stream.
sender: mpsc::Sender<mpsc::Receiver<T>>,
// An end of a channel used to receive messages.
receiver: mpsc::Receiver<StreamMessage<T>>,
receiver: BroadcastTopicServer<StreamMessage<T>>,

// A map from stream_id to a struct that contains all the information about the stream.
// This includes both the message buffer and some metadata (like the latest message_id).
Expand All @@ -67,7 +74,7 @@ impl<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError
/// Create a new StreamHandler.
pub fn new(
sender: mpsc::Sender<mpsc::Receiver<T>>,
receiver: mpsc::Receiver<StreamMessage<T>>,
receiver: BroadcastTopicServer<StreamMessage<T>>,
) -> Self {
StreamHandler { sender, receiver, stream_data: HashMap::new() }
}
Expand All @@ -89,9 +96,20 @@ impl<T: Clone + Into<Vec<u8>> + TryFrom<Vec<u8>, Error = ProtobufConversionError
data.next_message_id += 1;
}

// Handle the message, return true if the channel is still open.
#[instrument(skip_all, level = "warn")]
fn handle_message(&mut self, message: StreamMessage<T>) {
fn handle_message(
&mut self,
message: (Result<StreamMessage<T>, ProtobufConversionError>, BroadcastedMessageManager),
) {
let (message, metadata) = message;
let message = match message {
Ok(message) => message,
Err(e) => {
warn!("Error converting message: {:?}", e);
return;
}
};
let _peer_id = get_metadata_peer_id(metadata); // TODO(guyn): use peer_id
let stream_id = message.stream_id;
let message_id = message.message_id;

Expand Down
65 changes: 46 additions & 19 deletions crates/sequencing/papyrus_consensus/src/stream_handler_test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
use futures::channel::mpsc;
use futures::stream::StreamExt;
use futures::SinkExt;
use papyrus_network::network_manager::test_utils::{
mock_register_broadcast_topic,
MockBroadcastedMessagesSender,
TestSubscriberChannels,
};
use papyrus_network::network_manager::BroadcastTopicChannels;
use papyrus_protobuf::consensus::{ConsensusMessage, Proposal, StreamMessage};
use papyrus_test_utils::{get_rng, GetTestInstance};

use super::StreamHandler;

#[cfg(test)]
mod tests {
use std::time::Duration;

use papyrus_network_types::network_types::BroadcastedMessageManager;

use super::*;

fn make_test_message(
Expand All @@ -29,25 +39,43 @@ mod tests {
matching == a.len() && matching == b.len()
}

// TODO(guyn): should I make this a public function in `manager_test.rs` or just have a copy
// here?
async fn send(
sender: &mut MockBroadcastedMessagesSender<StreamMessage<ConsensusMessage>>,
msg: StreamMessage<ConsensusMessage>,
) {
let broadcasted_message_manager =
BroadcastedMessageManager::get_test_instance(&mut get_rng());
sender.send((msg, broadcasted_message_manager)).await.unwrap();
}

fn setup_test() -> (
StreamHandler<ConsensusMessage>,
mpsc::Sender<StreamMessage<ConsensusMessage>>,
MockBroadcastedMessagesSender<StreamMessage<ConsensusMessage>>,
mpsc::Receiver<mpsc::Receiver<ConsensusMessage>>,
) {
let (tx_input, rx_input) = mpsc::channel::<StreamMessage<ConsensusMessage>>(100);
let TestSubscriberChannels { mock_network, subscriber_channels } =
mock_register_broadcast_topic().unwrap();
let network_sender = mock_network.broadcasted_messages_sender;
let BroadcastTopicChannels { broadcasted_messages_receiver, broadcast_topic_client: _ } =
subscriber_channels;

// TODO(guyn): We should also give the broadcast_topic_client to the StreamHandler
let (tx_output, rx_output) = mpsc::channel::<mpsc::Receiver<ConsensusMessage>>(100);
let handler = StreamHandler::new(tx_output, rx_input);
(handler, tx_input, rx_output)
let handler = StreamHandler::new(tx_output, broadcasted_messages_receiver);
(handler, network_sender, rx_output)
}

#[tokio::test]
async fn stream_handler_in_order() {
let (mut stream_handler, mut tx_input, mut rx_output) = setup_test();
let (mut stream_handler, mut network_sender, mut rx_output) = setup_test();

let stream_id = 127;
for i in 0..10 {
let message = make_test_message(stream_id, i, i == 9);
tx_input.try_send(message).expect("Send should succeed");
// tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}

let join_handle = tokio::spawn(async move {
Expand All @@ -66,12 +94,12 @@ mod tests {

#[tokio::test]
async fn stream_handler_in_reverse() {
let (mut stream_handler, mut tx_input, mut rx_output) = setup_test();
let (mut stream_handler, mut network_sender, mut rx_output) = setup_test();

let stream_id = 127;
for i in 0..5 {
let message = make_test_message(stream_id, 5 - i, i == 0);
tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}

let join_handle = tokio::spawn(async move {
Expand All @@ -93,8 +121,7 @@ mod tests {
assert!(do_vecs_match(&keys, &range));

// Now send the last message:
tx_input.try_send(make_test_message(stream_id, 0, false)).expect("Send should succeed");

send(&mut network_sender, make_test_message(stream_id, 0, false)).await;
let join_handle = tokio::spawn(async move {
let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await;
stream_handler
Expand All @@ -112,29 +139,29 @@ mod tests {

#[tokio::test]
async fn stream_handler_multiple_streams() {
let (mut stream_handler, mut tx_input, mut rx_output) = setup_test();
let (mut stream_handler, mut network_sender, mut rx_output) = setup_test();

let stream_id1 = 127; // Send all messages in order (except the first one).
let stream_id2 = 10; // Send in reverse order (except the first one).
let stream_id3 = 1; // Send in two batches of 5 messages, without the first one, don't send fin.
let stream_id3 = 1; // Send in two batches, without the first one, don't send fin.

for i in 1..10 {
let message = make_test_message(stream_id1, i, i == 9);
tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}

for i in 0..5 {
let message = make_test_message(stream_id2, 5 - i, i == 0);
tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}

for i in 5..10 {
let message = make_test_message(stream_id3, i, false);
tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}
for i in 1..5 {
let message = make_test_message(stream_id3, i, false);
tx_input.try_send(message).expect("Send should succeed");
send(&mut network_sender, message).await;
}

let join_handle = tokio::spawn(async move {
Expand Down Expand Up @@ -195,7 +222,7 @@ mod tests {
assert!(receiver3.try_next().is_err());

// Send the last message on stream_id1:
tx_input.try_send(make_test_message(stream_id1, 0, false)).expect("Send should succeed");
send(&mut network_sender, make_test_message(stream_id1, 0, false)).await;
let join_handle = tokio::spawn(async move {
let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await;
stream_handler
Expand All @@ -216,7 +243,7 @@ mod tests {
assert!(stream_handler.stream_data.clone().into_keys().all(|item| values.contains(&item)));

// Send the last message on stream_id2:
tx_input.try_send(make_test_message(stream_id2, 0, false)).expect("Send should succeed");
send(&mut network_sender, make_test_message(stream_id2, 0, false)).await;
let join_handle = tokio::spawn(async move {
let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await;
stream_handler
Expand All @@ -237,7 +264,7 @@ mod tests {
assert!(stream_handler.stream_data.clone().into_keys().all(|item| values.contains(&item)));

// Send the last message on stream_id3:
tx_input.try_send(make_test_message(stream_id3, 0, false)).expect("Send should succeed");
send(&mut network_sender, make_test_message(stream_id3, 0, false)).await;

let join_handle = tokio::spawn(async move {
let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.listen()).await;
Expand Down

0 comments on commit 8ed234e

Please sign in to comment.