Skip to content

Commit

Permalink
Offset sequence by 1 (#66)
Browse files Browse the repository at this point in the history
* Offset sequence by 1

* Fix db sync
  • Loading branch information
Dzejkop authored Feb 28, 2024
1 parent be36a42 commit 36b3edc
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 74 deletions.
10 changes: 3 additions & 7 deletions bin/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async fn main() -> eyre::Result<()> {

let masks = db.fetch_masks(0).await?;

(masks.len() as u64).checked_sub(1)
masks.len() as u64
};

let participant_db_sync_queues = vec![
Expand Down Expand Up @@ -169,18 +169,14 @@ async fn main() -> eyre::Result<()> {
);
}
} else {
if let Some(id) = next_serial_id.as_mut() {
*id += 1;
} else {
next_serial_id = Some(0);
}
next_serial_id += 1;

common::seed_db_sync(
&sqs_client,
&config.db_sync.coordinator_db_sync_queue,
&participant_db_sync_queues,
template,
next_serial_id.context("Could not get next serial id")?,
next_serial_id,
)
.await?;

Expand Down
54 changes: 6 additions & 48 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,7 @@ impl Coordinator {
tracing::info!(?matches, "Matches found");
}

// Latest serial id is the last id shared across all nodes
// so we need to subtract 1 from the counter
let latest_serial_id: Option<u64> = (i as u64).checked_sub(1);
let distance_results = DistanceResults::new(latest_serial_id, matches);
let distance_results = DistanceResults::new(i as u64, matches);

Ok(distance_results)
}
Expand Down Expand Up @@ -458,7 +455,6 @@ impl Coordinator {

if messages.is_empty() {
tokio::time::sleep(IDLE_SLEEP_TIME).await;
return Ok(());
}

for message in messages {
Expand Down Expand Up @@ -531,49 +527,11 @@ pub struct UniquenessCheckRequest {

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UniquenessCheckResult {
#[serde(with = "some_or_minus_one")]
pub serial_id: Option<u64>,
pub serial_id: u64,
pub matches: Vec<Distance>,
pub signup_id: String,
}

mod some_or_minus_one {
use serde::Deserialize;

pub fn serialize<S>(
value: &Option<u64>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if let Some(value) = value {
serializer.serialize_u64(*value)
} else {
serializer.serialize_i64(-1)
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = i64::deserialize(deserializer)?;

if value < -1 {
return Err(serde::de::Error::custom(
"value must be -1 or greater",
));
}

if value == -1 {
Ok(None)
} else {
Ok(Some(value as u64))
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -603,7 +561,7 @@ mod tests {
#[test]
fn result_serialization() {
let output = UniquenessCheckResult {
serial_id: Some(1),
serial_id: 1,
matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)],
signup_id: "signup_id".to_string(),
};
Expand Down Expand Up @@ -636,16 +594,16 @@ mod tests {
}

#[test]
fn result_serialization_no_serial_id() {
fn result_serialization_zero_serial_id() {
let output = UniquenessCheckResult {
serial_id: None,
serial_id: 0,
matches: vec![Distance::new(0, 0.5), Distance::new(1, 0.2)],
signup_id: "signup_id".to_string(),
};

const EXPECTED: &str = indoc::indoc! {r#"
{
"serial_id": -1,
"serial_id": 0,
"matches": [
{
"distance": 0.5,
Expand Down
32 changes: 16 additions & 16 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ impl Db {
r#"
SELECT id, mask
FROM masks
WHERE id >= $1
WHERE id > $1
ORDER BY id ASC
"#,
)
.bind(id as i64)
.fetch_all(&self.pool)
.await?;

Ok(filter_sequential_items(masks, id as i64))
Ok(filter_sequential_items(masks, 1 + id as i64))
}

#[tracing::instrument(skip(self))]
Expand Down Expand Up @@ -90,15 +90,15 @@ impl Db {
r#"
SELECT id, share
FROM shares
WHERE id >= $1
WHERE id > $1
ORDER BY id ASC
"#,
)
.bind(id as i64)
.fetch_all(&self.pool)
.await?;

Ok(filter_sequential_items(shares, id as i64))
Ok(filter_sequential_items(shares, 1 + id as i64))
}

#[tracing::instrument(skip(self))]
Expand Down Expand Up @@ -195,7 +195,7 @@ mod tests {

let mut rng = thread_rng();

let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];
let masks = vec![(1, rng.gen::<Bits>()), (2, rng.gen::<Bits>())];

db.insert_masks(&masks).await?;

Expand All @@ -215,7 +215,7 @@ mod tests {

let mut rng = thread_rng();

let masks = vec![(0, rng.gen::<Bits>()), (1, rng.gen::<Bits>())];
let masks = vec![(1, rng.gen::<Bits>()), (2, rng.gen::<Bits>())];

db.insert_masks(&masks).await?;

Expand Down Expand Up @@ -244,7 +244,7 @@ mod tests {
let mut rng = thread_rng();

let shares =
vec![(0, rng.gen::<EncodedBits>()), (1, rng.gen::<EncodedBits>())];
vec![(1, rng.gen::<EncodedBits>()), (2, rng.gen::<EncodedBits>())];

db.insert_shares(&shares).await?;

Expand All @@ -265,7 +265,7 @@ mod tests {
let mut rng = thread_rng();

let shares =
vec![(0, rng.gen::<EncodedBits>()), (1, rng.gen::<EncodedBits>())];
vec![(1, rng.gen::<EncodedBits>()), (2, rng.gen::<EncodedBits>())];

db.insert_shares(&shares).await?;

Expand All @@ -283,11 +283,11 @@ mod tests {
let mut rng = thread_rng();

let shares = vec![
(0, rng.gen::<EncodedBits>()),
(1, rng.gen::<EncodedBits>()),
(4, rng.gen::<EncodedBits>()),
(2, rng.gen::<EncodedBits>()),
(5, rng.gen::<EncodedBits>()),
(7, rng.gen::<EncodedBits>()),
(6, rng.gen::<EncodedBits>()),
(8, rng.gen::<EncodedBits>()),
];

db.insert_shares(&shares).await?;
Expand All @@ -308,11 +308,11 @@ mod tests {
let mut rng = thread_rng();

let masks = vec![
(0, rng.gen::<Bits>()),
(1, rng.gen::<Bits>()),
(2, rng.gen::<Bits>()),
(3, rng.gen::<Bits>()),
(5, rng.gen::<Bits>()),
(4, rng.gen::<Bits>()),
(6, rng.gen::<Bits>()),
];

db.insert_masks(&masks).await?;
Expand All @@ -335,11 +335,11 @@ mod tests {
let mut rng = thread_rng();

let masks = vec![
(0, rng.gen::<Bits>()),
(1, rng.gen::<Bits>()),
(2, rng.gen::<Bits>()),
(3, rng.gen::<Bits>()),
(5, rng.gen::<Bits>()),
(4, rng.gen::<Bits>()),
(6, rng.gen::<Bits>()),
];

db.insert_masks(&masks).await?;
Expand All @@ -359,11 +359,11 @@ mod tests {
let mut rng = thread_rng();

let masks = vec![
(1, rng.gen::<Bits>()),
(2, rng.gen::<Bits>()),
(3, rng.gen::<Bits>()),
(4, rng.gen::<Bits>()),
(5, rng.gen::<Bits>()),
(6, rng.gen::<Bits>()),
];

db.insert_masks(&masks).await?;
Expand Down
4 changes: 2 additions & 2 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ impl PartialEq for Distance {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DistanceResults {
/// The lowest serial id known across all nodes
pub serial_id: Option<u64>,
pub serial_id: u64,
/// The distances to the query
pub matches: Vec<Distance>,
}

impl DistanceResults {
pub fn new(serial_id: Option<u64>, matches: Vec<Distance>) -> Self {
pub fn new(serial_id: u64, matches: Vec<Distance>) -> Self {
Self { serial_id, matches }
}
}
Expand Down
1 change: 0 additions & 1 deletion src/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ impl Participant {

if messages.is_empty() {
tokio::time::sleep(IDLE_SLEEP_TIME).await;
return Ok(());
}

for message in messages {
Expand Down

0 comments on commit 36b3edc

Please sign in to comment.