ref(rust): Put consumer in an Arc so it can be shared (#4954)
Initial refactoring so that the consumer can be shared and referenced
within the assignment callbacks. This is necessary so we can do things
like call commit on the consumer during rebalance callbacks.

This PR is based on #4936, which had to be reverted because the
consumer stopped committing entirely once it was merged. Note that
this is only the initial refactoring; it does not actually commit
during join yet.

At the moment, my best guess as to why the previous attempt did not
work in prod is that the context was being mutated after the consumer
was initially created. This version avoids that mutation and keeps
the original approach of deferring consumer creation until
StreamProcessor.subscribe() is called.
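For orientation, the pattern this refactor is working toward can be sketched in a few lines: the rebalance callbacks hold a clone of the same Arc<Mutex<...>> as the stream processor, so they can reach the consumer and commit from inside a callback. This is a minimal, hypothetical sketch; FakeConsumer and RebalanceCallbacks below are stand-ins, not the arroyo API:

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for a consumer; not the arroyo API.
struct FakeConsumer {
    committed: Vec<u64>,
}

impl FakeConsumer {
    fn commit_offsets(&mut self, offset: u64) {
        self.committed.push(offset);
    }
}

// The callbacks hold a clone of the same Arc, so they can commit on revoke.
struct RebalanceCallbacks {
    consumer: Arc<Mutex<FakeConsumer>>,
}

impl RebalanceCallbacks {
    fn on_revoke(&mut self, last_offset: u64) {
        // Possible only because the consumer is shared behind Arc<Mutex<...>>
        // rather than owned exclusively by the stream processor.
        self.consumer.lock().unwrap().commit_offsets(last_offset);
    }
}

fn main() {
    let consumer = Arc::new(Mutex::new(FakeConsumer { committed: vec![] }));
    let mut callbacks = RebalanceCallbacks {
        consumer: Arc::clone(&consumer),
    };

    // Simulate a rebalance revoking a partition whose last offset is 42.
    callbacks.on_revoke(42);
    assert_eq!(consumer.lock().unwrap().committed, vec![42]);
}
```

In the actual wiring below, the processor remains the only caller of poll; the Arc<Mutex<...>> only changes who is allowed to reach the consumer.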
lynnagara authored Nov 2, 2023
1 parent 9738bca commit f1743d8
Showing 11 changed files with 87 additions and 57 deletions.
3 changes: 2 additions & 1 deletion rust_snuba/rust_arroyo/examples/base_processor.rs
@@ -7,6 +7,7 @@ use rust_arroyo::processing::strategies::commit_offsets::CommitOffsets;
use rust_arroyo::processing::strategies::{ProcessingStrategy, ProcessingStrategyFactory};
use rust_arroyo::processing::StreamProcessor;
use rust_arroyo::types::Topic;
+use std::sync::{Arc, Mutex};
use std::time::Duration;

struct TestFactory {}
@@ -24,7 +25,7 @@ fn main() {
false,
None,
);
-let consumer = Box::new(KafkaConsumer::new(config));
+let consumer = Arc::new(Mutex::new(KafkaConsumer::new(config)));
let topic = Topic {
name: "test_static".to_string(),
};
3 changes: 2 additions & 1 deletion rust_snuba/rust_arroyo/examples/transform_and_produce.rs
@@ -16,6 +16,7 @@ use rust_arroyo::processing::strategies::{
};
use rust_arroyo::processing::StreamProcessor;
use rust_arroyo::types::{Message, Topic, TopicOrPartition};
+use std::sync::{Arc, Mutex};
use std::time::Duration;

fn reverse_string(value: KafkaPayload) -> Result<KafkaPayload, InvalidMessage> {
@@ -73,7 +74,7 @@ async fn main() {
None,
);

-let consumer = Box::new(KafkaConsumer::new(config.clone()));
+let consumer = Arc::new(Mutex::new(KafkaConsumer::new(config.clone())));
let mut processor = StreamProcessor::new(
consumer,
Box::new(ReverseStringAndProduceStrategyFactory {
2 changes: 1 addition & 1 deletion rust_snuba/rust_arroyo/src/backends/kafka/mod.rs
@@ -160,7 +160,7 @@ impl KafkaConsumer {
}
}

-impl<'a> ArroyoConsumer<'a, KafkaPayload> for KafkaConsumer {
+impl ArroyoConsumer<KafkaPayload> for KafkaConsumer {
fn subscribe(
&mut self,
topics: &[Topic],
2 changes: 1 addition & 1 deletion rust_snuba/rust_arroyo/src/backends/local/broker.rs
@@ -32,7 +32,7 @@ impl From<TopicDoesNotExist> for BrokerError {
}
}

-impl<TPayload: Clone> LocalBroker<TPayload> {
+impl<TPayload: Clone + Send> LocalBroker<TPayload> {
pub fn new(storage: Box<dyn MessageStorage<TPayload>>, clock: Box<dyn Clock>) -> Self {
Self {
storage,
33 changes: 14 additions & 19 deletions rust_snuba/rust_arroyo/src/backends/local/mod.rs
@@ -24,10 +24,10 @@ struct SubscriptionState {
last_eof_at: HashMap<Partition, u64>,
}

-pub struct LocalConsumer<'a, TPayload: Clone> {
+pub struct LocalConsumer<TPayload: Clone> {
id: Uuid,
group: String,
-broker: &'a mut LocalBroker<TPayload>,
+broker: LocalBroker<TPayload>,
pending_callback: VecDeque<Callback>,
paused: HashSet<Partition>,
// The offset that a the last ``EndOfPartition`` exception that was
@@ -40,10 +40,10 @@ pub struct LocalConsumer<'a, TPayload: Clone> {
closed: bool,
}

-impl<'a, TPayload: Clone> LocalConsumer<'a, TPayload> {
+impl<TPayload: Clone> LocalConsumer<TPayload> {
pub fn new(
id: Uuid,
-broker: &'a mut LocalBroker<TPayload>,
+broker: LocalBroker<TPayload>,
group: String,
enable_end_of_partition: bool,
) -> Self {
@@ -68,7 +68,7 @@ impl<'a, TPayload: Clone> LocalConsumer<'a, TPayload> {
}
}

-impl<'a, TPayload: Clone> Consumer<'a, TPayload> for LocalConsumer<'a, TPayload> {
+impl<TPayload: Clone + Send> Consumer<TPayload> for LocalConsumer<TPayload> {
fn subscribe(
&mut self,
topics: &[Topic],
@@ -327,7 +327,7 @@ mod tests {

#[test]
fn test_consumer_subscription() {
-let mut broker = build_broker();
+let broker = build_broker();

let topic1 = Topic {
name: "test1".to_string(),
@@ -337,8 +337,7 @@
};

let my_callbacks: Box<dyn AssignmentCallbacks> = Box::new(EmptyCallbacks {});
-let mut consumer =
-LocalConsumer::new(Uuid::nil(), &mut broker, "test_group".to_string(), true);
+let mut consumer = LocalConsumer::new(Uuid::nil(), broker, "test_group".to_string(), true);
assert!(consumer.subscription_state.topics.is_empty());

let res = consumer.subscribe(&[topic1.clone(), topic2.clone()], my_callbacks);
@@ -381,7 +380,7 @@

#[test]
fn test_subscription_callback() {
-let mut broker = build_broker();
+let broker = build_broker();

let topic1 = Topic {
name: "test1".to_string(),
@@ -455,8 +454,7 @@

let my_callbacks: Box<dyn AssignmentCallbacks> = Box::new(TheseCallbacks {});

-let mut consumer =
-LocalConsumer::new(Uuid::nil(), &mut broker, "test_group".to_string(), true);
+let mut consumer = LocalConsumer::new(Uuid::nil(), broker, "test_group".to_string(), true);

let _ = consumer.subscribe(&[topic1, topic2], my_callbacks);
let _ = consumer.poll(Some(Duration::from_millis(100)));
Expand Down Expand Up @@ -500,8 +498,7 @@ mod tests {
}

let my_callbacks: Box<dyn AssignmentCallbacks> = Box::new(TheseCallbacks {});
-let mut consumer =
-LocalConsumer::new(Uuid::nil(), &mut broker, "test_group".to_string(), true);
+let mut consumer = LocalConsumer::new(Uuid::nil(), broker, "test_group".to_string(), true);

let _ = consumer.subscribe(&[topic2], my_callbacks);

@@ -523,7 +520,7 @@

#[test]
fn test_paused() {
-let mut broker = build_broker();
+let broker = build_broker();
let topic2 = Topic {
name: "test2".to_string(),
};
@@ -532,8 +529,7 @@
index: 0,
};
let my_callbacks: Box<dyn AssignmentCallbacks> = Box::new(EmptyCallbacks {});
-let mut consumer =
-LocalConsumer::new(Uuid::nil(), &mut broker, "test_group".to_string(), false);
+let mut consumer = LocalConsumer::new(Uuid::nil(), broker, "test_group".to_string(), false);
let _ = consumer.subscribe(&[topic2], my_callbacks);

assert_eq!(consumer.poll(None).unwrap(), None);
@@ -549,10 +545,9 @@

#[test]
fn test_commit() {
-let mut broker = build_broker();
+let broker = build_broker();
let my_callbacks: Box<dyn AssignmentCallbacks> = Box::new(EmptyCallbacks {});
-let mut consumer =
-LocalConsumer::new(Uuid::nil(), &mut broker, "test_group".to_string(), false);
+let mut consumer = LocalConsumer::new(Uuid::nil(), broker, "test_group".to_string(), false);
let topic2 = Topic {
name: "test2".to_string(),
};
4 changes: 2 additions & 2 deletions rust_snuba/rust_arroyo/src/backends/mod.rs
@@ -41,7 +41,7 @@ pub enum ProducerError {

/// This is basically an observer pattern to receive the callbacks from
/// the consumer when partitions are assigned/revoked.
-pub trait AssignmentCallbacks: Send + Sync {
+pub trait AssignmentCallbacks: Send {
fn on_assign(&mut self, partitions: HashMap<Partition, u64>);
fn on_revoke(&mut self, partitions: Vec<Partition>);
}
@@ -80,7 +80,7 @@ pub trait AssignmentCallbacks: Send + Sync {
/// occurs even if the consumer retains ownership of the partition across
/// assignments.) For this reason, it is generally good practice to ensure
/// offsets are committed as part of the revocation callback.
-pub trait Consumer<'a, TPayload: Clone> {
+pub trait Consumer<TPayload: Clone>: Send {
fn subscribe(
&mut self,
topic: &[Topic],
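A note on the Send bounds added throughout this diff (Consumer, MessageStorage, Clock, and the various TPayload: Send bounds): Mutex<T> is only Send/Sync when T: Send, so sharing the consumer as Arc<Mutex<dyn Consumer<TPayload>>> presumably requires the trait objects reachable through it to be Send. A minimal sketch using a hypothetical trait, not the arroyo one:

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical trait mirroring the shape of the change: `Send` as a
// supertrait lets Arc<Mutex<dyn Consumer>> cross thread boundaries,
// since Mutex<T> is only Sync (and Send) when T: Send.
trait Consumer: Send {
    fn poll(&mut self) -> Option<String>;
}

struct NoopConsumer;

impl Consumer for NoopConsumer {
    fn poll(&mut self) -> Option<String> {
        None
    }
}

fn main() {
    let consumer: Arc<Mutex<dyn Consumer>> = Arc::new(Mutex::new(NoopConsumer));
    let handle = thread::spawn(move || {
        // Without the `Send` supertrait on `Consumer`, moving the Arc
        // into this thread would not compile.
        consumer.lock().unwrap().poll()
    });
    handle.join().unwrap();
}
```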
2 changes: 1 addition & 1 deletion rust_snuba/rust_arroyo/src/backends/storages/memory.rs
@@ -68,7 +68,7 @@ impl<TPayload: Clone> Default for MemoryMessageStorage<TPayload> {
}
}

-impl<TPayload: Clone> MessageStorage<TPayload> for MemoryMessageStorage<TPayload> {
+impl<TPayload: Clone + Send> MessageStorage<TPayload> for MemoryMessageStorage<TPayload> {
fn create_topic(&mut self, topic: Topic, partitions: u16) -> Result<(), TopicExists> {
if self.topics.contains_key(&topic) {
return Err(TopicExists);
2 changes: 1 addition & 1 deletion rust_snuba/rust_arroyo/src/backends/storages/mod.rs
@@ -21,7 +21,7 @@ pub enum ConsumeError {
OffsetOutOfRange,
}

-pub trait MessageStorage<TPayload: Clone> {
+pub trait MessageStorage<TPayload: Clone + Send>: Send {
// Create a topic with the given number of partitions.
//
// If the topic already exists, a ``TopicExists`` exception will be
87 changes: 60 additions & 27 deletions rust_snuba/rust_arroyo/src/processing/mod.rs
@@ -94,8 +94,8 @@ impl<TPayload: Clone> Callbacks<TPayload> {
/// instance and a ``ProcessingStrategy``, ensuring that processing
/// strategies are instantiated on partition assignment and closed on
/// partition revocation.
-pub struct StreamProcessor<'a, TPayload: Clone> {
-consumer: Box<dyn Consumer<'a, TPayload> + 'a>,
+pub struct StreamProcessor<TPayload: Clone> {
+consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
strategies: Arc<Mutex<Strategies<TPayload>>>,
message: Option<Message<TPayload>>,
processor_handle: ProcessorHandle,
@@ -104,9 +104,9 @@ pub struct StreamProcessor<'a, TPayload: Clone> {
metrics_buffer: metrics_buffer::MetricsBuffer,
}

-impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
+impl<TPayload: 'static + Clone> StreamProcessor<TPayload> {
pub fn new(
-consumer: Box<dyn Consumer<'a, TPayload> + 'a>,
+consumer: Arc<Mutex<dyn Consumer<TPayload>>>,
processing_factory: Box<dyn ProcessingStrategyFactory<TPayload>>,
) -> Self {
let strategies = Arc::new(Mutex::new(Strategies {
@@ -130,14 +130,23 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
pub fn subscribe(&mut self, topic: Topic) {
let callbacks: Box<dyn AssignmentCallbacks> =
Box::new(Callbacks::new(self.strategies.clone()));
-self.consumer.subscribe(&[topic], callbacks).unwrap();
+self.consumer
+.lock()
+.unwrap()
+.subscribe(&[topic], callbacks)
+.unwrap();
}

pub fn run_once(&mut self) -> Result<(), RunError> {
if self.is_paused {
// If the consumer waas paused, it should not be returning any messages
// on ``poll``.
-let res = self.consumer.poll(Some(Duration::ZERO)).unwrap();
+let res = self
+.consumer
+.lock()
+.unwrap()
+.poll(Some(Duration::ZERO))
+.unwrap();

match res {
None => {}
@@ -148,7 +157,12 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
// even if there is no active assignment and/or processing strategy.
let poll_start = Instant::now();
//TODO: Support errors properly
-match self.consumer.poll(Some(Duration::from_secs(1))) {
+match self
+.consumer
+.lock()
+.unwrap()
+.poll(Some(Duration::from_secs(1)))
+{
Ok(msg) => {
self.message = msg.map(|inner| Message {
inner_message: InnerMessage::BrokerMessage(inner),
@@ -174,8 +188,12 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
match commit_request {
Ok(None) => {}
Ok(Some(request)) => {
-self.consumer.stage_offsets(request.positions).unwrap();
-self.consumer.commit_offsets().unwrap();
+self.consumer
+.lock()
+.unwrap()
+.stage_offsets(request.positions)
+.unwrap();
+self.consumer.lock().unwrap().commit_offsets().unwrap();
}
Err(e) => {
println!("TODOO: Handle invalid message {:?}", e);
@@ -194,10 +212,17 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
Ok(()) => {
// Resume if we are currently in a paused state
if self.is_paused {
-let partitions: std::collections::HashSet<Partition> =
-self.consumer.tell().unwrap().keys().cloned().collect();
-
-let res = self.consumer.resume(partitions);
+let partitions: std::collections::HashSet<Partition> = self
+.consumer
+.lock()
+.unwrap()
+.tell()
+.unwrap()
+.keys()
+.cloned()
+.collect();
+
+let res = self.consumer.lock().unwrap().resume(partitions);
match res {
Ok(()) => {
self.is_paused = false;
@@ -236,10 +261,17 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {

log::warn!("Consumer is in backpressure state for more than 1 second, pausing",);

-let partitions =
-self.consumer.tell().unwrap().keys().cloned().collect();
-
-let res = self.consumer.pause(partitions);
+let partitions = self
+.consumer
+.lock()
+.unwrap()
+.tell()
+.unwrap()
+.keys()
+.cloned()
+.collect();
+
+let res = self.consumer.lock().unwrap().pause(partitions);
match res {
Ok(()) => {
self.is_paused = true;
@@ -279,7 +311,7 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
}
}
drop(trait_callbacks); // unlock mutex so we can close consumer
-self.consumer.close();
+self.consumer.lock().unwrap().close();
return Err(e);
}
}
@@ -293,11 +325,11 @@ impl<'a, TPayload: 'static + Clone> StreamProcessor<'a, TPayload> {
}

pub fn shutdown(&mut self) {
-self.consumer.close();
+self.consumer.lock().unwrap().close();
}

pub fn tell(self) -> HashMap<Partition, u64> {
-self.consumer.tell().unwrap()
+self.consumer.lock().unwrap().tell().unwrap()
}
}

@@ -313,6 +345,7 @@ mod tests {
use crate::types::{Message, Partition, Topic};
use crate::utils::clock::SystemClock;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use uuid::Uuid;

@@ -366,13 +399,13 @@

#[test]
fn test_processor() {
-let mut broker = build_broker();
-let consumer = Box::new(LocalConsumer::new(
+let broker = build_broker();
+let consumer = Arc::new(Mutex::new(LocalConsumer::new(
Uuid::nil(),
-&mut broker,
+broker,
"test_group".to_string(),
false,
-));
+)));

let mut processor = StreamProcessor::new(consumer, Box::new(TestFactory {}));
processor.subscribe(Topic {
@@ -395,12 +428,12 @@
let _ = broker.produce(&partition, "message1".to_string());
let _ = broker.produce(&partition, "message2".to_string());

-let consumer = Box::new(LocalConsumer::new(
+let consumer = Arc::new(Mutex::new(LocalConsumer::new(
Uuid::nil(),
-&mut broker,
+broker,
"test_group".to_string(),
false,
-));
+)));

let mut processor = StreamProcessor::new(consumer, Box::new(TestFactory {}));
processor.subscribe(Topic {
2 changes: 1 addition & 1 deletion rust_snuba/rust_arroyo/src/utils/clock.rs
@@ -1,7 +1,7 @@
use std::thread::sleep;
use std::time::{Duration, SystemTime};

-pub trait Clock {
+pub trait Clock: Send {
fn time(&self) -> SystemTime;

fn sleep(self, duration: Duration);