Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
0xKitsune committed Mar 1, 2024
2 parents f790819 + 70fc04b commit 4f23f20
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 49 deletions.
8 changes: 4 additions & 4 deletions bin/e2e/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,11 @@ pub async fn seed_db_sync(
participant_db_sync_queues: &[&str],
template: Template,
serial_id: u64,
rng: &mut impl Rng,
) -> eyre::Result<()> {
tracing::info!("Encoding shares");
let shares: Box<[EncodedBits]> = mpc::distance::encode(&template)
.share(participant_db_sync_queues.len());
.share(participant_db_sync_queues.len(), rng);

let coordinator_payload =
serde_json::to_string(&vec![coordinator::DbSyncPayload {
Expand Down Expand Up @@ -209,9 +210,8 @@ async fn wait_for_empty_queue(
}
}

pub fn generate_random_string(len: usize) -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
pub fn generate_random_string(len: usize, rng: &mut impl Rng) -> String {
rng.sample_iter(&Alphanumeric)
.take(len)
.map(char::from)
.collect()
Expand Down
9 changes: 8 additions & 1 deletion bin/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use eyre::ContextCompat;
use mpc::config::{load_config, AwsConfig, DbConfig};
use mpc::coordinator::UniquenessCheckResult;
use mpc::db::Db;
use mpc::rng_source::RngSource;
use mpc::template::{Bits, Template};
use mpc::utils::aws::{self, sqs_client_from_config};
use serde::Deserialize;
Expand Down Expand Up @@ -44,6 +45,9 @@ struct Args {
/// The path to the signup sequence file to use
#[clap(short, long, default_value = "bin/e2e/signup_sequence.json")]
signup_sequence: String,

#[clap(short, long, env, default_value = "thread")]
rng: RngSource,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -74,6 +78,8 @@ async fn main() -> eyre::Result<()> {
tracing::warn!("AWS_DEFAULT_REGION not set");
}

let mut rng = args.rng.to_rng();

let _shutdown_tracing_provider = StdoutBattery::init();

tracing::info!("Loading config");
Expand Down Expand Up @@ -132,7 +138,7 @@ async fn main() -> eyre::Result<()> {
&sqs_client,
&config.coordinator_queue.query_queue,
&element.signup_id,
&common::generate_random_string(4),
&common::generate_random_string(4, &mut rng),
)
.await?;

Expand Down Expand Up @@ -177,6 +183,7 @@ async fn main() -> eyre::Result<()> {
&participant_db_sync_queues,
template,
next_serial_id,
&mut rng,
)
.await?;

Expand Down
8 changes: 6 additions & 2 deletions bin/utils/seed_iris_db.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use clap::Args;
use mpc::bits::Bits;
use mpc::rng_source::RngSource;
use mpc::template::Template;
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize};

use crate::generate_random_string;
Expand All @@ -19,6 +20,9 @@ pub struct SeedIrisDb {

#[clap(short, long, default_value = "10000")]
pub batch_size: usize,

#[clap(short, long, env, default_value = "thread")]
pub rng: RngSource,
}

pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {
Expand All @@ -30,7 +34,7 @@ pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {

let iris_db = client.database(DATABASE_NAME);

let mut rng = thread_rng();
let mut rng = args.rng.to_rng();

tracing::info!("Generating codes");
let left_templates = (0..args.num_templates)
Expand Down
12 changes: 9 additions & 3 deletions bin/utils/seed_mpc_db.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use clap::Args;
use futures::stream::FuturesUnordered;
Expand All @@ -8,6 +8,7 @@ use mpc::bits::Bits;
use mpc::config::DbConfig;
use mpc::db::Db;
use mpc::distance::EncodedBits;
use mpc::rng_source::RngSource;
use mpc::template::Template;
use rand::{thread_rng, Rng};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
Expand All @@ -25,6 +26,9 @@ pub struct SeedMPCDb {

#[clap(short, long, default_value = "100")]
pub batch_size: usize,

#[clap(short, long, env, default_value = "thread")]
pub rng: RngSource,
}

pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
Expand Down Expand Up @@ -149,8 +153,10 @@ fn generate_shares_and_masks(
let shares_chunk = chunk
.into_par_iter()
.map(|template| {
let shares =
mpc::distance::encode(template).share(num_participants);
let mut rng = thread_rng();

let shares = mpc::distance::encode(template)
.share(num_participants, &mut rng);

pb.inc(1);
shares
Expand Down
62 changes: 31 additions & 31 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,37 @@ services:
- 8434:5432
environment:
- POSTGRES_HOST_AUTH_METHOD=trust
e2e_test:
depends_on:
- localstack
- coordinator
- participant_0
- participant_1
- coordinator_db
- participant_0_db
- participant_1_db
build:
context: .
dockerfile: Dockerfile
args:
- BIN=e2e
command: [ "--signup-sequence", "/signup_sequence.json" ]
volumes:
- ./bin/e2e/signup_sequence.json:/signup_sequence.json
environment:
- 'RUST_LOG=info'
- 'E2E__AWS__ENDPOINT=http://localstack:4566'
- 'E2E__AWS__REGION=us-east-1'
- 'E2E__DB_SYNC__COORDINATOR_DB_URL=postgres://postgres:postgres@coordinator_db:5432/db'
- 'E2E__DB_SYNC__COORDINATOR_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-db-sync-queue'
- 'E2E__DB_SYNC__PARTICIPANT_0_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-0-db-sync-queue'
- 'E2E__DB_SYNC__PARTICIPANT_1_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-1-db-sync-queue'
- 'E2E__COORDINATOR_QUEUE__QUERY_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-uniqueness-check.fifo'
- 'E2E__COORDINATOR_QUEUE__RESULTS_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-results-queue.fifo'
# AWS env vars - they don't matter but are required
- 'AWS_ACCESS_KEY_ID=test'
- 'AWS_SECRET_ACCESS_KEY=test'
- 'AWS_DEFAULT_REGION=us-east-1'
# e2e_test:
# depends_on:
# - localstack
# - coordinator
# - participant_0
# - participant_1
# - coordinator_db
# - participant_0_db
# - participant_1_db
# build:
# context: .
# dockerfile: Dockerfile
# args:
# - BIN=e2e
# command: [ "--signup-sequence", "/signup_sequence.json" ]
# volumes:
# - ./bin/e2e/signup_sequence.json:/signup_sequence.json
# environment:
# - 'RUST_LOG=info'
# - 'E2E__AWS__ENDPOINT=http://localstack:4566'
# - 'E2E__AWS__REGION=us-east-1'
# - 'E2E__DB_SYNC__COORDINATOR_DB_URL=postgres://postgres:postgres@coordinator_db:5432/db'
# - 'E2E__DB_SYNC__COORDINATOR_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-db-sync-queue'
# - 'E2E__DB_SYNC__PARTICIPANT_0_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-0-db-sync-queue'
# - 'E2E__DB_SYNC__PARTICIPANT_1_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-1-db-sync-queue'
# - 'E2E__COORDINATOR_QUEUE__QUERY_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-uniqueness-check.fifo'
# - 'E2E__COORDINATOR_QUEUE__RESULTS_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-results-queue.fifo'
# # AWS env vars - they don't matter but are required
# - 'AWS_ACCESS_KEY_ID=test'
# - 'AWS_SECRET_ACCESS_KEY=test'
# - 'AWS_DEFAULT_REGION=us-east-1'

networks:
default:
Expand Down
30 changes: 25 additions & 5 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,26 @@ impl Coordinator {

let body = message.body.context("Missing message body")?;

let UniquenessCheckRequest {
plain_code: template,
if let Ok(UniquenessCheckRequest {
plain_code,
signup_id,
} = serde_json::from_str(&body).context("Failed to parse message")?;
}) = serde_json::from_str::<UniquenessCheckRequest>(&body)
{
self.uniqueness_check(receipt_handle, plain_code, signup_id)
.await?;
} else {
tracing::error!(
?receipt_handle,
"Failed to parse template from message"
);

// Process the query
self.uniqueness_check(receipt_handle, template, signup_id)
sqs_delete_message(
&self.sqs_client,
&self.config.queues.queries_queue_url,
receipt_handle,
)
.await?;
}

Ok(())
}
Expand Down Expand Up @@ -489,6 +501,14 @@ impl Coordinator {
items
} else {
tracing::error!(?receipt_handle, "Failed to parse message body");

sqs_delete_message(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
receipt_handle,
)
.await?;

return Ok(());
};

Expand Down
7 changes: 4 additions & 3 deletions src/encoded_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytemuck::{cast_slice_mut, Pod, Zeroable};
use rand::distributions::{Distribution, Standard};
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::de::Error as _;
use serde::{Deserialize, Deserializer, Serialize};

Expand All @@ -23,11 +23,10 @@ unsafe impl Pod for EncodedBits {}

impl EncodedBits {
/// Generate secret shares from this bitvector.
pub fn share(&self, n: usize) -> Box<[EncodedBits]> {
pub fn share(&self, n: usize, rng: &mut impl Rng) -> Box<[EncodedBits]> {
assert!(n > 0);

// Create `n - 1` random shares.
let mut rng = thread_rng();
let mut result: Box<[EncodedBits]> =
iter::repeat_with(|| rng.gen::<EncodedBits>())
.take(n - 1)
Expand Down Expand Up @@ -210,6 +209,8 @@ impl Serialize for EncodedBits {

#[cfg(test)]
mod tests {
use rand::thread_rng;

use super::*;

#[test]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ pub mod distance;
pub mod encoded_bits;
pub mod health_check;
pub mod participant;
pub mod rng_source;
pub mod template;
pub mod utils;
8 changes: 8 additions & 0 deletions src/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ impl Participant {
items
} else {
tracing::error!(?receipt_handle, "Failed to parse message body");

sqs_delete_message(
&self.sqs_client,
&self.config.queues.db_sync_queue_url,
receipt_handle,
)
.await?;

return Ok(());
};

Expand Down
73 changes: 73 additions & 0 deletions src/rng_source.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::fmt;
use std::str::FromStr;

use rand::{thread_rng, RngCore, SeedableRng};
use serde::{Deserialize, Serialize};

/// Selects which random-number generator the binaries use.
///
/// `Thread` is the OS-seeded thread-local RNG (non-deterministic); the
/// `Small`/`Std` variants carry an explicit seed for reproducible runs.
// NOTE(review): serde's internally-tagged representation (`tag = "kind"`)
// does not support newtype variants wrapping primitives such as `Small(u64)`
// — (de)serialization of those variants errors at runtime. Confirm this
// derive is actually exercised; CLI parsing goes through `FromStr` instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "kind")]
pub enum RngSource {
    // Thread-local, OS-seeded RNG.
    Thread,
    // `SmallRng` seeded with the given value (fast, not crypto-safe).
    Small(u64),
    // `StdRng` seeded with the given value.
    Std(u64),
}

impl RngSource {
    /// Materializes this source as a boxed RNG.
    ///
    /// `Thread` hands back the thread-local generator; the seeded variants
    /// build a deterministic generator from their stored seed.
    pub fn to_rng(&self) -> Box<dyn RngCore> {
        match *self {
            RngSource::Thread => Box::new(thread_rng()),
            RngSource::Small(seed) => {
                Box::new(rand::rngs::SmallRng::seed_from_u64(seed))
            }
            RngSource::Std(seed) => {
                Box::new(rand::rngs::StdRng::seed_from_u64(seed))
            }
        }
    }
}

impl fmt::Display for RngSource {
    /// Renders the source in the same `name[:seed]` form that `FromStr`
    /// accepts.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            RngSource::Thread => f.write_str("thread"),
            RngSource::Small(seed) => write!(f, "small:{}", seed),
            RngSource::Std(seed) => write!(f, "std:{}", seed),
        }
    }
}

impl FromStr for RngSource {
    type Err = eyre::Error;

    /// Parses `"thread"`, `"small:<seed>"` or `"std:<seed>"`.
    ///
    /// # Errors
    /// Returns an error for any other shape, or when `<seed>` is not a
    /// valid `u64`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "thread" {
            Ok(RngSource::Thread)
        } else if let Some(seed) = s.strip_prefix("small:") {
            // `strip_prefix` removes the prefix exactly once; the previous
            // `trim_start_matches` stripped it repeatedly, silently accepting
            // malformed inputs like "small:small:42".
            Ok(RngSource::Small(seed.parse()?))
        } else if let Some(seed) = s.strip_prefix("std:") {
            Ok(RngSource::Std(seed.parse()?))
        } else {
            Err(eyre::eyre!("Invalid RngSource: {}", s))
        }
    }
}

#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::*;

    /// Parsing the canonical string yields the expected variant, and
    /// `Display` reproduces the exact input — i.e. parse/format actually
    /// round-trips (the previous version only checked the parse direction).
    #[test_case("thread" => RngSource::Thread)]
    #[test_case("std:42" => RngSource::Std(42))]
    #[test_case("small:42" => RngSource::Small(42))]
    fn serialization_round_trip(s: &str) -> RngSource {
        let parsed: RngSource = s.parse().unwrap();
        // Formatting must emit the same canonical string we parsed.
        assert_eq!(parsed.to_string(), s);
        parsed
    }
}

0 comments on commit 4f23f20

Please sign in to comment.