Skip to content

Commit

Permalink
Merge pull request #109 from worldcoin/0xkitsune/deletion
Browse files Browse the repository at this point in the history
feat(deletion): mask/share deletion logic
  • Loading branch information
0xKitsune authored May 6, 2024
2 parents 1f47d4a + b40dee4 commit 43443c4
Show file tree
Hide file tree
Showing 9 changed files with 472 additions and 217 deletions.
281 changes: 144 additions & 137 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-04-01"
channel = "nightly-2024-04-27"
components = ["rustc-dev", "rustc", "cargo", "rustfmt", "clippy"]
1 change: 1 addition & 0 deletions src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct Bits(pub [u64; LIMBS]);

impl Bits {
pub const ZERO: Self = Self([0; LIMBS]);
pub const MAX: Self = Self([u64::MAX; LIMBS]);

/// Returns an unordered iterator over the 31 possible rotations
pub fn rotations(&self) -> impl Iterator<Item = Self> + '_ {
Expand Down
45 changes: 35 additions & 10 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ impl Coordinator {
{
self.enqueue_latest_serial_id(participant_streams).await?;
last_serial_id_check = Instant::now();
} else {
// Send ack and wait for response from all participants
self.send_ack(participant_streams).await?;
}

// Send ack and wait for response from all participants
self.send_ack(participant_streams).await?;

// Dequeue messages, limiting the max number of messages to 1
let messages = match sqs_dequeue(
&self.sqs_client,
Expand Down Expand Up @@ -715,15 +715,21 @@ impl Coordinator {
return Ok(());
};

let masks: Vec<_> =
items.into_iter().map(|item| (item.id, item.mask)).collect();
// Partition deletions and overwrite masks in memory
let deletions = items
.iter()
.filter(|item| item.mask == Bits::MAX)
.collect::<Vec<_>>();

tracing::info!(
num_new_masks = masks.len(),
"Inserting masks into database"
);
// Remove the mask from masks
let mut masks = self.masks.lock().await;
for DbSyncPayload { id, .. } in deletions {
masks[(id - 1) as usize] = Bits::MAX;
}
drop(masks);

self.database.insert_masks(&masks).await?;
// Insert masks into the db
self.insert_masks(items).await?;

sqs_delete_message(
&self.sqs_client,
Expand All @@ -734,6 +740,25 @@ impl Coordinator {

Ok(())
}

async fn insert_masks(
&self,
insertions: Vec<DbSyncPayload>,
) -> eyre::Result<()> {
tracing::info!(
num_masks = insertions.len(),
"Inserting masks into database"
);

let insertions = insertions
.into_iter()
.map(|item| (item.id, item.mask))
.collect::<Vec<(u64, Bits)>>();

self.database.insert_masks(&insertions).await?;

Ok(())
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
82 changes: 79 additions & 3 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use sqlx::{Postgres, QueryBuilder};
use crate::bits::Bits;
use crate::config::DbConfig;
use crate::distance::EncodedBits;

static MIGRATOR: Migrator = sqlx::migrate!("./migrations/");

pub struct Db {
Expand Down Expand Up @@ -92,7 +91,7 @@ impl Db {
builder.push(")");
}

builder.push(" ON CONFLICT (id) DO NOTHING");
builder.push(" ON CONFLICT (id) DO UPDATE SET mask = EXCLUDED.mask");

let query = builder.build();

Expand Down Expand Up @@ -159,7 +158,7 @@ impl Db {
builder.push(")");
}

builder.push(" ON CONFLICT (id) DO NOTHING");
builder.push(" ON CONFLICT (id) DO UPDATE SET share = EXCLUDED.share");

let query = builder.build();

Expand Down Expand Up @@ -544,4 +543,81 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_delete_masks() -> eyre::Result<()> {
let docker = clients::Cli::default();
let (db, _pg) = setup(&docker).await?;

let mut rng = thread_rng();

let masks = vec![
(1, rng.gen::<Bits>()),
(2, rng.gen::<Bits>()),
(3, rng.gen::<Bits>()),
(4, rng.gen::<Bits>()),
(5, rng.gen::<Bits>()),
];

db.insert_masks(&masks).await?;

let fetched_masks = db.fetch_masks(0).await?;
let expected_masks =
masks.iter().map(|(_, mask)| *mask).collect::<Vec<_>>();
assert_eq!(fetched_masks, expected_masks);

// Overwrite masks, simulating deletion
db.insert_masks(&[(1, Bits::MAX), (3, Bits::MAX), (5, Bits::MAX)])
.await?;

let fetched_masks = db.fetch_masks(0).await?;
let expected_masks =
vec![Bits::MAX, masks[1].1, Bits::MAX, masks[3].1, Bits::MAX];
assert_eq!(fetched_masks, expected_masks);

Ok(())
}

#[tokio::test]
async fn test_delete_shares() -> eyre::Result<()> {
let docker = clients::Cli::default();
let (db, _pg) = setup(&docker).await?;

let mut rng = thread_rng();

let shares = vec![
(1, rng.gen::<EncodedBits>()),
(2, rng.gen::<EncodedBits>()),
(3, rng.gen::<EncodedBits>()),
(4, rng.gen::<EncodedBits>()),
(5, rng.gen::<EncodedBits>()),
];

db.insert_shares(&shares).await?;

let fetched_shares = db.fetch_shares(0).await?;
let expected_shares =
shares.iter().map(|(_, share)| *share).collect::<Vec<_>>();
assert_eq!(fetched_shares, expected_shares);

// Overwrite shares, simulating deletion
db.insert_shares(&[
(1, EncodedBits::ZERO),
(3, EncodedBits::MAX),
(5, EncodedBits::ZERO),
])
.await?;

let fetched_shares = db.fetch_shares(0).await?;
let expected_shares = vec![
EncodedBits::ZERO,
shares[1].1,
EncodedBits::MAX,
shares[3].1,
EncodedBits::ZERO,
];
assert_eq!(fetched_shares, expected_shares);

Ok(())
}
}
3 changes: 3 additions & 0 deletions src/encoded_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ unsafe impl Zeroable for EncodedBits {}
unsafe impl Pod for EncodedBits {}

impl EncodedBits {
pub const ZERO: Self = Self([0; BITS]);
pub const MAX: Self = Self([u16::MAX; BITS]);

/// Generate secret shares from this bitvector.
pub fn share(&self, n: usize, rng: &mut impl Rng) -> Box<[EncodedBits]> {
assert!(n > 0);
Expand Down
45 changes: 35 additions & 10 deletions src/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ impl Participant {
}

let body = message.body.context("Missing message body")?;

let items = if let Ok(items) =
serde_json::from_str::<Vec<DbSyncPayload>>(&body)
{
Expand All @@ -279,16 +278,23 @@ impl Participant {
return Ok(());
};

let shares: Vec<_> = items
.into_iter()
.map(|item| (item.id, item.share))
.collect();
// Partition deletions and overwrite shares in memory
let deletions = items
.iter()
.filter(|item| {
matches!(item.share, EncodedBits::ZERO)
|| matches!(item.share, EncodedBits::MAX)
})
.collect::<Vec<_>>();

tracing::info!(
num_new_shares = shares.len(),
"Inserting shares into database"
);
self.database.insert_shares(&shares).await?;
let mut shares = self.shares.lock().await;
for DbSyncPayload { id, share } in deletions {
shares[(id - 1) as usize] = *share;
}
drop(shares);

// Insert the shares into the db
self.insert_shares(items).await?;

sqs_delete_message(
&self.sqs_client,
Expand All @@ -300,6 +306,25 @@ impl Participant {
Ok(())
}

async fn insert_shares(
&self,
shares: Vec<DbSyncPayload>,
) -> eyre::Result<()> {
tracing::info!(
num_shares = shares.len(),
"Inserting shares into database"
);

let shares = shares
.into_iter()
.map(|item| (item.id, item.share))
.collect::<Vec<(u64, EncodedBits)>>();

self.database.insert_shares(&shares).await?;

Ok(())
}

#[tracing::instrument(skip(self))]
pub async fn sync_shares(&self) -> eyre::Result<usize> {
let mut shares = self.shares.lock().await;
Expand Down
Loading

0 comments on commit 43443c4

Please sign in to comment.