Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
0xKitsune committed Feb 28, 2024
2 parents 5499fe5 + 36b3edc commit 24ab7b4
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 216 deletions.
10 changes: 3 additions & 7 deletions bin/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async fn main() -> eyre::Result<()> {

let masks = db.fetch_masks(0).await?;

(masks.len() as u64).checked_sub(1)
masks.len() as u64
};

let participant_db_sync_queues = vec![
Expand Down Expand Up @@ -169,18 +169,14 @@ async fn main() -> eyre::Result<()> {
);
}
} else {
if let Some(id) = next_serial_id.as_mut() {
*id += 1;
} else {
next_serial_id = Some(0);
}
next_serial_id += 1;

common::seed_db_sync(
&sqs_client,
&config.db_sync.coordinator_db_sync_queue,
&participant_db_sync_queues,
template,
next_serial_id.context("Could not get next serial id")?,
next_serial_id,
)
.await?;

Expand Down
244 changes: 109 additions & 135 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;
use std::time::Duration;

use aws_sdk_sqs::types::Message;
use eyre::{Context, ContextCompat};
use futures::stream::FuturesUnordered;
use futures::{future, StreamExt};
Expand All @@ -11,7 +12,6 @@ use tokio::net::TcpStream;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tracing::instrument;

use crate::bits::Bits;
use crate::config::CoordinatorConfig;
Expand Down Expand Up @@ -90,58 +90,64 @@ impl Coordinator {
self: Arc<Self>,
) -> Result<(), eyre::Error> {
loop {
self.process_uniqueness_check_queue().await?;
}
}

#[tracing::instrument(skip(self))]
pub async fn process_uniqueness_check_queue(&self) -> eyre::Result<()> {
let messages = sqs_dequeue(
&self.sqs_client,
&self.config.queues.queries_queue_url,
)
.await?;
let messages = sqs_dequeue(
&self.sqs_client,
&self.config.queues.queries_queue_url,
)
.await?;

for message in messages {
let receipt_handle = message
.receipt_handle
.context("Missing receipt handle in message")?;

if let Some(message_attributes) = &message.message_attributes {
utils::aws::trace_from_message_attributes(
message_attributes,
&receipt_handle,
)?;
} else {
tracing::warn!(
?receipt_handle,
"SQS message missing message attributes"
);
for message in messages {
self.handle_uniqueness_check(message).await?;
}
}
}

let body = message.body.context("Missing message body")?;
#[tracing::instrument(skip(self, message))]
pub async fn handle_uniqueness_check(
&self,
message: Message,
) -> eyre::Result<()> {
tracing::debug!(?message, "Handling message");

let receipt_handle = message
.receipt_handle
.context("Missing receipt handle in message")?;

if let Some(message_attributes) = &message.message_attributes {
utils::aws::trace_from_message_attributes(
message_attributes,
&receipt_handle,
)?;
} else {
tracing::warn!(
?receipt_handle,
"SQS message missing message attributes"
);
}

if let Ok(UniquenessCheckRequest {
plain_code,
signup_id,
}) = serde_json::from_str::<UniquenessCheckRequest>(&body)
{
self.uniqueness_check(receipt_handle, plain_code, signup_id)
.await?;
} else {
tracing::error!(
?receipt_handle,
"Failed to parse template from message"
);
let body = message.body.context("Missing message body")?;

sqs_delete_message(
&self.sqs_client,
&self.config.queues.queries_queue_url,
receipt_handle,
)
if let Ok(UniquenessCheckRequest {
plain_code,
signup_id,
}) = serde_json::from_str::<UniquenessCheckRequest>(&body)
{
self.uniqueness_check(receipt_handle, plain_code, signup_id)
.await?;
}
} else {
tracing::error!(
?receipt_handle,
"Failed to parse template from message"
);

sqs_delete_message(
&self.sqs_client,
&self.config.queues.queries_queue_url,
receipt_handle,
)
.await?;
}

Ok(())
}

Expand Down Expand Up @@ -430,10 +436,7 @@ impl Coordinator {
tracing::info!(?matches, "Matches found");
}

// Latest serial id is the last id shared across all nodes
// so we need to subtract 1 from the counter
let latest_serial_id: Option<u64> = (i as u64).checked_sub(1);
let distance_results = DistanceResults::new(latest_serial_id, matches);
let distance_results = DistanceResults::new(i as u64, matches);

Ok(distance_results)
}
Expand All @@ -456,58 +459,67 @@ impl Coordinator {

async fn handle_db_sync(self: Arc<Self>) -> eyre::Result<()> {
loop {
self.db_sync().await?;
let messages = sqs_dequeue(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
)
.await?;

if messages.is_empty() {
tokio::time::sleep(IDLE_SLEEP_TIME).await;
}

for message in messages {
self.db_sync(message).await?;
}
}
}

#[instrument(skip(self))]
async fn db_sync(&self) -> eyre::Result<()> {
let messages = sqs_dequeue(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
)
.await?;

if messages.is_empty() {
tokio::time::sleep(IDLE_SLEEP_TIME).await;
return Ok(());
#[tracing::instrument(skip(self, message))]
async fn db_sync(&self, message: Message) -> eyre::Result<()> {
let receipt_handle = message
.receipt_handle
.context("Missing receipt handle in message")?;

if let Some(message_attributes) = &message.message_attributes {
utils::aws::trace_from_message_attributes(
message_attributes,
&receipt_handle,
)?;
} else {
tracing::warn!(
?receipt_handle,
"SQS message missing message attributes"
);
}

for message in messages {
let body = message.body.context("Missing message body")?;
let receipt_handle = message
.receipt_handle
.context("Missing receipt handle in message")?;
let body = message.body.context("Missing message body")?;

let items = if let Ok(items) =
serde_json::from_str::<Vec<DbSyncPayload>>(&body)
{
items
} else {
tracing::error!(
?receipt_handle,
"Failed to parse message body"
);
continue;
};
let items = if let Ok(items) =
serde_json::from_str::<Vec<DbSyncPayload>>(&body)
{
items
} else {
tracing::error!(?receipt_handle, "Failed to parse message body");
return Ok(());
};

let masks: Vec<_> =
items.into_iter().map(|item| (item.id, item.mask)).collect();
let masks: Vec<_> =
items.into_iter().map(|item| (item.id, item.mask)).collect();

tracing::info!(
num_new_masks = masks.len(),
"Inserting masks into database"
);
tracing::info!(
num_new_masks = masks.len(),
"Inserting masks into database"
);

self.database.insert_masks(&masks).await?;
self.database.insert_masks(&masks).await?;

sqs_delete_message(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
receipt_handle,
)
.await?;
}
sqs_delete_message(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
receipt_handle,
)
.await?;

Ok(())
}
Expand All @@ -527,49 +539,11 @@ pub struct UniquenessCheckRequest {

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UniquenessCheckResult {
#[serde(with = "some_or_minus_one")]
pub serial_id: Option<u64>,
pub serial_id: u64,
pub matches: Vec<Distance>,
pub signup_id: String,
}

mod some_or_minus_one {
use serde::Deserialize;

pub fn serialize<S>(
value: &Option<u64>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if let Some(value) = value {
serializer.serialize_u64(*value)
} else {
serializer.serialize_i64(-1)
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = i64::deserialize(deserializer)?;

if value < -1 {
return Err(serde::de::Error::custom(
"value must be -1 or greater",
));
}

if value == -1 {
Ok(None)
} else {
Ok(Some(value as u64))
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -599,7 +573,7 @@ mod tests {
#[test]
fn result_serialization() {
let output = UniquenessCheckResult {
serial_id: Some(1),
serial_id: 1,
matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)],
signup_id: "signup_id".to_string(),
};
Expand Down Expand Up @@ -632,16 +606,16 @@ mod tests {
}

#[test]
fn result_serialization_no_serial_id() {
fn result_serialization_zero_serial_id() {
let output = UniquenessCheckResult {
serial_id: None,
serial_id: 0,
matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)],
signup_id: "signup_id".to_string(),
};

const EXPECTED: &str = indoc::indoc! {r#"
{
"serial_id": -1,
"serial_id": 0,
"matches": [
{
"distance": 0.5,
Expand Down
Loading

0 comments on commit 24ab7b4

Please sign in to comment.