diff --git a/src/coordinator.rs b/src/coordinator.rs index a8d5aad..38c602c 100644 --- a/src/coordinator.rs +++ b/src/coordinator.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::Duration; +use aws_sdk_sqs::types::Message; use eyre::{Context, ContextCompat}; use futures::stream::FuturesUnordered; use futures::{future, StreamExt}; @@ -11,7 +12,6 @@ use tokio::net::TcpStream; use tokio::sync::mpsc::Receiver; use tokio::sync::{mpsc, Mutex}; use tokio::task::JoinHandle; -use tracing::instrument; use crate::bits::Bits; use crate::config::CoordinatorConfig; @@ -90,47 +90,52 @@ impl Coordinator { self: Arc, ) -> Result<(), eyre::Error> { loop { - self.process_uniqueness_check_queue().await?; + let messages = sqs_dequeue( + &self.sqs_client, + &self.config.queues.queries_queue_url, + ) + .await?; + + for message in messages { + self.handle_uniqueness_check(message).await?; + } } } - #[tracing::instrument(skip(self))] - pub async fn process_uniqueness_check_queue(&self) -> eyre::Result<()> { - let messages = sqs_dequeue( - &self.sqs_client, - &self.config.queues.queries_queue_url, - ) - .await?; + #[tracing::instrument(skip(self, message))] + pub async fn handle_uniqueness_check( + &self, + message: Message, + ) -> eyre::Result<()> { + tracing::debug!(?message, "Handling message"); + + let receipt_handle = message + .receipt_handle + .context("Missing receipt handle in message")?; + + if let Some(message_attributes) = &message.message_attributes { + utils::aws::trace_from_message_attributes( + message_attributes, + &receipt_handle, + )?; + } else { + tracing::warn!( + ?receipt_handle, + "SQS message missing message attributes" + ); + } - for message in messages { - let receipt_handle = message - .receipt_handle - .context("Missing receipt handle in message")?; - - if let Some(message_attributes) = &message.message_attributes { - utils::aws::trace_from_message_attributes( - message_attributes, - &receipt_handle, - )?; - } else { - tracing::warn!( - ?receipt_handle, - "SQS message missing message attributes" - ); - } + let body = message.body.context("Missing message body")?; - let body = message.body.context("Missing message body")?; + let UniquenessCheckRequest { + plain_code: template, + signup_id, + } = serde_json::from_str(&body).context("Failed to parse message")?; - let UniquenessCheckRequest { - plain_code: template, - signup_id, - } = serde_json::from_str(&body) - .context("Failed to parse message")?; + // Process the query + self.uniqueness_check(receipt_handle, template, signup_id) + .await?; - // Process the query - self.uniqueness_check(receipt_handle, template, signup_id) - .await?; - } Ok(()) } @@ -445,58 +450,68 @@ impl Coordinator { async fn handle_db_sync(self: Arc) -> eyre::Result<()> { loop { - self.db_sync().await?; + let messages = sqs_dequeue( + &self.sqs_client, + &self.config.queues.db_sync_queue_url, + ) + .await?; + + if messages.is_empty() { + tokio::time::sleep(IDLE_SLEEP_TIME).await; + return Ok(()); + } + + for message in messages { + self.db_sync(message).await?; + } } } - #[instrument(skip(self))] - async fn db_sync(&self) -> eyre::Result<()> { - let messages = sqs_dequeue( - &self.sqs_client, - &self.config.queues.db_sync_queue_url, - ) - .await?; - - if messages.is_empty() { - tokio::time::sleep(IDLE_SLEEP_TIME).await; - return Ok(()); + #[tracing::instrument(skip(self, message))] + async fn db_sync(&self, message: Message) -> eyre::Result<()> { + let receipt_handle = message + .receipt_handle + .context("Missing receipt handle in message")?; + + if let Some(message_attributes) = &message.message_attributes { + utils::aws::trace_from_message_attributes( + message_attributes, + &receipt_handle, + )?; + } else { + tracing::warn!( + ?receipt_handle, + "SQS message missing message attributes" + ); } - for message in messages { - let body = message.body.context("Missing message body")?; - let receipt_handle = message - .receipt_handle - .context("Missing receipt handle in message")?; + let body = message.body.context("Missing message body")?; - let items = if let Ok(items) = - serde_json::from_str::>(&body) - { - items - } else { - tracing::error!( - ?receipt_handle, - "Failed to parse message body" - ); - continue; - }; + let items = if let Ok(items) = + serde_json::from_str::>(&body) + { + items + } else { + tracing::error!(?receipt_handle, "Failed to parse message body"); + return Ok(()); + }; - let masks: Vec<_> = - items.into_iter().map(|item| (item.id, item.mask)).collect(); + let masks: Vec<_> = + items.into_iter().map(|item| (item.id, item.mask)).collect(); - tracing::info!( - num_new_masks = masks.len(), - "Inserting masks into database" - ); + tracing::info!( + num_new_masks = masks.len(), + "Inserting masks into database" + ); - self.database.insert_masks(&masks).await?; + self.database.insert_masks(&masks).await?; - sqs_delete_message( - &self.sqs_client, - &self.config.queues.db_sync_queue_url, - receipt_handle, - ) - .await?; - } + sqs_delete_message( + &self.sqs_client, + &self.config.queues.db_sync_queue_url, + receipt_handle, + ) + .await?; Ok(()) } diff --git a/src/participant.rs b/src/participant.rs index d6aa776..34bb718 100644 --- a/src/participant.rs +++ b/src/participant.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::Duration; +use aws_sdk_sqs::types::Message; use distance::Template; use eyre::ContextCompat; use futures::stream::FuturesUnordered; @@ -17,6 +18,7 @@ use tracing::instrument; use crate::config::ParticipantConfig; use crate::db::Db; use crate::distance::{self, encode, DistanceEngine, EncodedBits}; +use crate::utils; use crate::utils::aws::{ sqs_client_from_config, sqs_delete_message, sqs_dequeue, }; @@ -61,7 +63,7 @@ impl Participant { let mut tasks = FuturesUnordered::new(); - tasks.push(tokio::spawn(self.clone().handle_uniqueness_check())); + tasks.push(tokio::spawn(self.clone().handle_uniqueness_checks())); tasks.push(tokio::spawn(self.clone().handle_db_sync())); while let Some(result) = tasks.next().await { @@ -71,16 +73,19 @@ impl Participant { Ok(()) } - async fn handle_uniqueness_check(self: Arc) -> eyre::Result<()> { + async fn handle_uniqueness_checks(self: Arc) -> eyre::Result<()> { loop { - self.process_uniqueness_check_stream().await?; + let stream = self.listener.accept().await?.0; + self.handle_uniqueness_check(stream).await?; } } #[tracing::instrument(skip(self))] - async fn process_uniqueness_check_stream(&self) -> eyre::Result<()> { - let mut stream = - tokio::io::BufWriter::new(self.listener.accept().await?.0); + async fn handle_uniqueness_check( + &self, + stream: TcpStream, + ) -> eyre::Result<()> { + let mut stream = tokio::io::BufWriter::new(stream); // Process the trace and span ids to correlate traces between services self.handle_traces_payload(&mut stream).await?; @@ -163,59 +168,69 @@ impl Participant { async fn handle_db_sync(self: Arc) -> eyre::Result<()> { loop { - self.db_sync().await?; + let messages = sqs_dequeue( + &self.sqs_client, + &self.config.queues.db_sync_queue_url, + ) + .await?; + + if messages.is_empty() { + tokio::time::sleep(IDLE_SLEEP_TIME).await; + return Ok(()); + } + + for message in messages { + self.db_sync(message).await?; + } } } - #[tracing::instrument(skip(self))] - async fn db_sync(&self) -> eyre::Result<()> { - let messages = sqs_dequeue( - &self.sqs_client, - &self.config.queues.db_sync_queue_url, - ) - .await?; + #[tracing::instrument(skip(self, message))] + async fn db_sync(&self, message: Message) -> eyre::Result<()> { + let receipt_handle = message + .receipt_handle + .context("Missing receipt handle in message")?; + + if let Some(message_attributes) = &message.message_attributes { + utils::aws::trace_from_message_attributes( + message_attributes, + &receipt_handle, + )?; + } else { + tracing::warn!( + ?receipt_handle, + "SQS message missing message attributes" + ); + } - if messages.is_empty() { - tokio::time::sleep(IDLE_SLEEP_TIME).await; + let body = message.body.context("Missing message body")?; + + let items = if let Ok(items) = + serde_json::from_str::>(&body) + { + items + } else { + tracing::error!(?receipt_handle, "Failed to parse message body"); return Ok(()); - } + }; - for message in messages { - let body = message.body.context("Missing message body")?; - let receipt_handle = message - .receipt_handle - .context("Missing receipt handle in message")?; - - let items = if let Ok(items) = - serde_json::from_str::>(&body) - { - items - } else { - tracing::error!( - ?receipt_handle, - "Failed to parse message body" - ); - continue; - }; - - let shares: Vec<_> = items - .into_iter() - .map(|item| (item.id, item.share)) - .collect(); - - tracing::info!( - num_new_shares = shares.len(), - "Inserting shares into database" - ); - self.database.insert_shares(&shares).await?; + let shares: Vec<_> = items + .into_iter() + .map(|item| (item.id, item.share)) + .collect(); - sqs_delete_message( - &self.sqs_client, - &self.config.queues.db_sync_queue_url, - receipt_handle, - ) - .await?; - } + tracing::info!( + num_new_shares = shares.len(), + "Inserting shares into database" + ); + self.database.insert_shares(&shares).await?; + + sqs_delete_message( + &self.sqs_client, + &self.config.queues.db_sync_queue_url, + receipt_handle, + ) + .await?; Ok(()) } diff --git a/src/utils/aws.rs b/src/utils/aws.rs index b388dde..75d9870 100644 --- a/src/utils/aws.rs +++ b/src/utils/aws.rs @@ -63,17 +63,17 @@ pub async fn sqs_dequeue( Ok(messages) } -#[tracing::instrument(skip(client, message))] +#[tracing::instrument(skip(client, payload))] pub async fn sqs_enqueue( client: &aws_sdk_sqs::Client, queue_url: &str, message_group_id: &str, - message: T, + payload: T, ) -> eyre::Result<()> where T: Serialize + Debug, { - let body = serde_json::to_string(&message) + let body = serde_json::to_string(&payload) .wrap_err("Failed to serialize message")?; let message_attributes = construct_message_attributes()?; @@ -87,7 +87,7 @@ where .send() .await?; - tracing::info!(?send_message_output, ?message, "Enqueued message"); + tracing::info!(?send_message_output, ?payload, "Enqueued message"); Ok(()) }