diff --git a/bin/e2e/e2e.rs b/bin/e2e/e2e.rs index f177076..7898edb 100644 --- a/bin/e2e/e2e.rs +++ b/bin/e2e/e2e.rs @@ -112,7 +112,7 @@ async fn main() -> eyre::Result<()> { let masks = db.fetch_masks(0).await?; - (masks.len() as u64).checked_sub(1) + masks.len() as u64 }; let participant_db_sync_queues = vec![ @@ -169,18 +169,14 @@ async fn main() -> eyre::Result<()> { ); } } else { - if let Some(id) = next_serial_id.as_mut() { - *id += 1; - } else { - next_serial_id = Some(0); - } + next_serial_id += 1; common::seed_db_sync( &sqs_client, &config.db_sync.coordinator_db_sync_queue, &participant_db_sync_queues, template, - next_serial_id.context("Could not get next serial id")?, + next_serial_id, ) .await?; diff --git a/src/coordinator.rs b/src/coordinator.rs index 38c602c..f2a5cd1 100644 --- a/src/coordinator.rs +++ b/src/coordinator.rs @@ -424,10 +424,7 @@ impl Coordinator { tracing::info!(?matches, "Matches found"); } - // Latest serial id is the last id shared across all nodes - // so we need to subtract 1 from the counter - let latest_serial_id: Option = (i as u64).checked_sub(1); - let distance_results = DistanceResults::new(latest_serial_id, matches); + let distance_results = DistanceResults::new(i as u64, matches); Ok(distance_results) } @@ -458,7 +455,6 @@ impl Coordinator { if messages.is_empty() { tokio::time::sleep(IDLE_SLEEP_TIME).await; - return Ok(()); } for message in messages { @@ -531,49 +527,11 @@ pub struct UniquenessCheckRequest { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct UniquenessCheckResult { - #[serde(with = "some_or_minus_one")] - pub serial_id: Option, + pub serial_id: u64, pub matches: Vec, pub signup_id: String, } -mod some_or_minus_one { - use serde::Deserialize; - - pub fn serialize( - value: &Option, - serializer: S, - ) -> Result - where - S: serde::Serializer, - { - if let Some(value) = value { - serializer.serialize_u64(*value) - } else { - serializer.serialize_i64(-1) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - let value = i64::deserialize(deserializer)?; - - if value < -1 { - return Err(serde::de::Error::custom( - "value must be -1 or greater", - )); - } - - if value == -1 { - Ok(None) - } else { - Ok(Some(value as u64)) - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -603,7 +561,7 @@ mod tests { #[test] fn result_serialization() { let output = UniquenessCheckResult { - serial_id: Some(1), + serial_id: 1, matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)], signup_id: "signup_id".to_string(), }; @@ -636,16 +594,16 @@ mod tests { } #[test] - fn result_serialization_no_serial_id() { + fn result_serialization_zero_serial_id() { let output = UniquenessCheckResult { - serial_id: None, + serial_id: 0, matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)], signup_id: "signup_id".to_string(), }; const EXPECTED: &str = indoc::indoc! {r#" { - "serial_id": -1, + "serial_id": 0, "matches": [ { "distance": 0.5, diff --git a/src/db.rs b/src/db.rs index 1901df6..5c262c9 100644 --- a/src/db.rs +++ b/src/db.rs @@ -42,7 +42,7 @@ impl Db { r#" SELECT id, mask FROM masks - WHERE id >= $1 + WHERE id > $1 ORDER BY id ASC "#, ) @@ -50,7 +50,7 @@ impl Db { .fetch_all(&self.pool) .await?; - Ok(filter_sequential_items(masks, id as i64)) + Ok(filter_sequential_items(masks, 1 + id as i64)) } #[tracing::instrument(skip(self))] @@ -90,7 +90,7 @@ impl Db { r#" SELECT id, share FROM shares - WHERE id >= $1 + WHERE id > $1 ORDER BY id ASC "#, ) @@ -98,7 +98,7 @@ impl Db { .fetch_all(&self.pool) .await?; - Ok(filter_sequential_items(shares, id as i64)) + Ok(filter_sequential_items(shares, 1 + id as i64)) } #[tracing::instrument(skip(self))] @@ -195,7 +195,7 @@ mod tests { let mut rng = thread_rng(); - let masks = vec![(0, rng.gen::()), (1, rng.gen::())]; + let masks = vec![(1, rng.gen::()), (2, rng.gen::())]; db.insert_masks(&masks).await?; @@ -215,7 +215,7 @@ mod tests { let mut rng = thread_rng(); - let masks = vec![(0, rng.gen::()), (1, rng.gen::())]; + let masks = vec![(1, rng.gen::()), (2, rng.gen::())]; db.insert_masks(&masks).await?; @@ -244,7 +244,7 @@ mod tests { let mut rng = thread_rng(); let shares = - vec![(0, rng.gen::()), (1, rng.gen::())]; + vec![(1, rng.gen::()), (2, rng.gen::())]; db.insert_shares(&shares).await?; @@ -265,7 +265,7 @@ mod tests { let mut rng = thread_rng(); let shares = - vec![(0, rng.gen::()), (1, rng.gen::())]; + vec![(1, rng.gen::()), (2, rng.gen::())]; db.insert_shares(&shares).await?; @@ -283,11 +283,11 @@ mod tests { let mut rng = thread_rng(); let shares = vec![ - (0, rng.gen::()), (1, rng.gen::()), - (4, rng.gen::()), + (2, rng.gen::()), (5, rng.gen::()), - (7, rng.gen::()), + (6, rng.gen::()), + (8, rng.gen::()), ]; db.insert_shares(&shares).await?; @@ -308,11 +308,11 @@ mod tests { let mut rng = thread_rng(); let masks = vec![ - (0, rng.gen::()), (1, rng.gen::()), (2, rng.gen::()), (3, rng.gen::()), - (5, rng.gen::()), + (4, rng.gen::()), + (6, rng.gen::()), ]; db.insert_masks(&masks).await?; @@ -335,11 +335,11 @@ mod tests { let mut rng = thread_rng(); let masks = vec![ - (0, rng.gen::()), (1, rng.gen::()), (2, rng.gen::()), (3, rng.gen::()), - (5, rng.gen::()), + (4, rng.gen::()), + (6, rng.gen::()), ]; db.insert_masks(&masks).await?; @@ -359,11 +359,11 @@ mod tests { let mut rng = thread_rng(); let masks = vec![ - (1, rng.gen::()), (2, rng.gen::()), (3, rng.gen::()), (4, rng.gen::()), (5, rng.gen::()), + (6, rng.gen::()), ]; db.insert_masks(&masks).await?; diff --git a/src/distance.rs b/src/distance.rs index b70cbb2..17c85ef 100644 --- a/src/distance.rs +++ b/src/distance.rs @@ -82,13 +82,13 @@ impl PartialEq for Distance { #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct DistanceResults { /// The lowest serial id known across all nodes - pub serial_id: Option, + pub serial_id: u64, /// The distances to the query pub matches: Vec, } impl DistanceResults { - pub fn new(serial_id: Option, matches: Vec) -> Self { + pub fn new(serial_id: u64, matches: Vec) -> Self { Self { serial_id, matches } } } diff --git a/src/participant.rs b/src/participant.rs index 34bb718..9f7af8c 100644 --- a/src/participant.rs +++ b/src/participant.rs @@ -176,7 +176,6 @@ impl Participant { if messages.is_empty() { tokio::time::sleep(IDLE_SLEEP_TIME).await; - return Ok(()); } for message in messages {