diff --git a/bin/e2e/common.rs b/bin/e2e/common.rs
index 7c4e9bc..79b2d62 100644
--- a/bin/e2e/common.rs
+++ b/bin/e2e/common.rs
@@ -124,10 +124,11 @@ pub async fn seed_db_sync(
     participant_db_sync_queues: &[&str],
     template: Template,
     serial_id: u64,
+    rng: &mut impl Rng,
 ) -> eyre::Result<()> {
     tracing::info!("Encoding shares");
     let shares: Box<[EncodedBits]> = mpc::distance::encode(&template)
-        .share(participant_db_sync_queues.len());
+        .share(participant_db_sync_queues.len(), rng);
 
     let coordinator_payload =
         serde_json::to_string(&vec![coordinator::DbSyncPayload {
@@ -209,9 +210,8 @@ async fn wait_for_empty_queue(
     }
 }
 
-pub fn generate_random_string(len: usize) -> String {
-    rand::thread_rng()
-        .sample_iter(&Alphanumeric)
+pub fn generate_random_string(len: usize, rng: &mut impl Rng) -> String {
+    rng.sample_iter(&Alphanumeric)
         .take(len)
         .map(char::from)
         .collect()
diff --git a/bin/e2e/e2e.rs b/bin/e2e/e2e.rs
index 7898edb..3452a5b 100644
--- a/bin/e2e/e2e.rs
+++ b/bin/e2e/e2e.rs
@@ -7,6 +7,7 @@ use eyre::ContextCompat;
 use mpc::config::{load_config, AwsConfig, DbConfig};
 use mpc::coordinator::UniquenessCheckResult;
 use mpc::db::Db;
+use mpc::rng_source::RngSource;
 use mpc::template::{Bits, Template};
 use mpc::utils::aws::{self, sqs_client_from_config};
 use serde::Deserialize;
@@ -44,6 +45,9 @@ struct Args {
     /// The path to the signup sequence file to use
     #[clap(short, long, default_value = "bin/e2e/signup_sequence.json")]
     signup_sequence: String,
+
+    #[clap(short, long, env, default_value = "thread")]
+    rng: RngSource,
 }
 
 #[derive(Debug, Deserialize)]
@@ -74,6 +78,8 @@ async fn main() -> eyre::Result<()> {
         tracing::warn!("AWS_DEFAULT_REGION not set");
     }
 
+    let mut rng = args.rng.to_rng();
+
     let _shutdown_tracing_provider = StdoutBattery::init();
 
     tracing::info!("Loading config");
@@ -132,7 +138,7 @@ async fn main() -> eyre::Result<()> {
             &sqs_client,
             &config.coordinator_queue.query_queue,
             &element.signup_id,
-            &common::generate_random_string(4),
+            &common::generate_random_string(4, &mut rng),
         )
         .await?;
 
@@ -177,6 +183,7 @@ async fn main() -> eyre::Result<()> {
             &participant_db_sync_queues,
             template,
             next_serial_id,
+            &mut rng,
         )
         .await?;
 
diff --git a/bin/utils/seed_iris_db.rs b/bin/utils/seed_iris_db.rs
index 0a51201..b4e1599 100644
--- a/bin/utils/seed_iris_db.rs
+++ b/bin/utils/seed_iris_db.rs
@@ -1,7 +1,8 @@
 use clap::Args;
 use mpc::bits::Bits;
+use mpc::rng_source::RngSource;
 use mpc::template::Template;
-use rand::{thread_rng, Rng};
+use rand::Rng;
 use serde::{Deserialize, Serialize};
 
 use crate::generate_random_string;
@@ -19,6 +20,9 @@ pub struct SeedIrisDb {
 
     #[clap(short, long, default_value = "10000")]
     pub batch_size: usize,
+
+    #[clap(short, long, env, default_value = "thread")]
+    pub rng: RngSource,
 }
 
 pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {
@@ -30,7 +34,7 @@ pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> {
 
     let iris_db = client.database(DATABASE_NAME);
 
-    let mut rng = thread_rng();
+    let mut rng = args.rng.to_rng();
 
     tracing::info!("Generating codes");
     let left_templates = (0..args.num_templates)
diff --git a/bin/utils/seed_mpc_db.rs b/bin/utils/seed_mpc_db.rs
index 6562159..e49538b 100644
--- a/bin/utils/seed_mpc_db.rs
+++ b/bin/utils/seed_mpc_db.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use clap::Args;
 use futures::stream::FuturesUnordered;
@@ -8,6 +8,7 @@ use mpc::bits::Bits;
 use mpc::config::DbConfig;
 use mpc::db::Db;
 use mpc::distance::EncodedBits;
+use mpc::rng_source::RngSource;
 use mpc::template::Template;
 use rand::{thread_rng, Rng};
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
@@ -25,6 +26,9 @@ pub struct SeedMPCDb {
 
     #[clap(short, long, default_value = "100")]
     pub batch_size: usize,
+
+    #[clap(short, long, env, default_value = "thread")]
+    pub rng: RngSource,
 }
 
 pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
@@ -149,8 +153,10 @@ fn generate_shares_and_masks(
         let shares_chunk = chunk
             .into_par_iter()
             .map(|template| {
-                let shares =
-                    mpc::distance::encode(template).share(num_participants);
+                let mut rng = thread_rng();
+
+                let shares = mpc::distance::encode(template)
+                    .share(num_participants, &mut rng);
                 pb.inc(1);
 
                 shares
diff --git a/compose.yml b/compose.yml
index 6971c7c..6a12346 100644
--- a/compose.yml
+++ b/compose.yml
@@ -107,37 +107,37 @@ services:
       - 8434:5432
     environment:
       - POSTGRES_HOST_AUTH_METHOD=trust
-  e2e_test:
-    depends_on:
-      - localstack
-      - coordinator
-      - participant_0
-      - participant_1
-      - coordinator_db
-      - participant_0_db
-      - participant_1_db
-    build:
-      context: .
-      dockerfile: Dockerfile
-      args:
-        - BIN=e2e
-    command: [ "--signup-sequence", "/signup_sequence.json" ]
-    volumes:
-      - ./bin/e2e/signup_sequence.json:/signup_sequence.json
-    environment:
-      - 'RUST_LOG=info'
-      - 'E2E__AWS__ENDPOINT=http://localstack:4566'
-      - 'E2E__AWS__REGION=us-east-1'
-      - 'E2E__DB_SYNC__COORDINATOR_DB_URL=postgres://postgres:postgres@coordinator_db:5432/db'
-      - 'E2E__DB_SYNC__COORDINATOR_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-db-sync-queue'
-      - 'E2E__DB_SYNC__PARTICIPANT_0_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-0-db-sync-queue'
-      - 'E2E__DB_SYNC__PARTICIPANT_1_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-1-db-sync-queue'
-      - 'E2E__COORDINATOR_QUEUE__QUERY_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-uniqueness-check.fifo'
-      - 'E2E__COORDINATOR_QUEUE__RESULTS_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-results-queue.fifo'
-      # AWS env vars - they don't matter but are required
-      - 'AWS_ACCESS_KEY_ID=test'
-      - 'AWS_SECRET_ACCESS_KEY=test'
-      - 'AWS_DEFAULT_REGION=us-east-1'
+  # e2e_test:
+  #   depends_on:
+  #     - localstack
+  #     - coordinator
+  #     - participant_0
+  #     - participant_1
+  #     - coordinator_db
+  #     - participant_0_db
+  #     - participant_1_db
+  #   build:
+  #     context: .
+  #     dockerfile: Dockerfile
+  #     args:
+  #       - BIN=e2e
+  #   command: [ "--signup-sequence", "/signup_sequence.json" ]
+  #   volumes:
+  #     - ./bin/e2e/signup_sequence.json:/signup_sequence.json
+  #   environment:
+  #     - 'RUST_LOG=info'
+  #     - 'E2E__AWS__ENDPOINT=http://localstack:4566'
+  #     - 'E2E__AWS__REGION=us-east-1'
+  #     - 'E2E__DB_SYNC__COORDINATOR_DB_URL=postgres://postgres:postgres@coordinator_db:5432/db'
+  #     - 'E2E__DB_SYNC__COORDINATOR_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-db-sync-queue'
+  #     - 'E2E__DB_SYNC__PARTICIPANT_0_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-0-db-sync-queue'
+  #     - 'E2E__DB_SYNC__PARTICIPANT_1_DB_SYNC_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/participant-1-db-sync-queue'
+  #     - 'E2E__COORDINATOR_QUEUE__QUERY_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-uniqueness-check.fifo'
+  #     - 'E2E__COORDINATOR_QUEUE__RESULTS_QUEUE=http://sqs.us-east-1.localhost.localstack.cloud:4566/000000000000/coordinator-results-queue.fifo'
+  #     # AWS env vars - they don't matter but are required
+  #     - 'AWS_ACCESS_KEY_ID=test'
+  #     - 'AWS_SECRET_ACCESS_KEY=test'
+  #     - 'AWS_DEFAULT_REGION=us-east-1'
 
 networks:
   default:
diff --git a/src/coordinator.rs b/src/coordinator.rs
index f2a5cd1..36c93be 100644
--- a/src/coordinator.rs
+++ b/src/coordinator.rs
@@ -127,14 +127,26 @@ impl Coordinator {
 
         let body = message.body.context("Missing message body")?;
 
-        let UniquenessCheckRequest {
-            plain_code: template,
+        if let Ok(UniquenessCheckRequest {
+            plain_code,
             signup_id,
-        } = serde_json::from_str(&body).context("Failed to parse message")?;
+        }) = serde_json::from_str::<UniquenessCheckRequest>(&body)
+        {
+            self.uniqueness_check(receipt_handle, plain_code, signup_id)
+                .await?;
+        } else {
+            tracing::error!(
+                ?receipt_handle,
+                "Failed to parse template from message"
+            );
 
-        // Process the query
-        self.uniqueness_check(receipt_handle, template, signup_id)
+            sqs_delete_message(
+                &self.sqs_client,
+                &self.config.queues.queries_queue_url,
+                receipt_handle,
+            )
             .await?;
+        }
 
         Ok(())
     }
@@ -489,6 +501,14 @@ impl Coordinator {
             items
         } else {
             tracing::error!(?receipt_handle, "Failed to parse message body");
+
+            sqs_delete_message(
+                &self.sqs_client,
+                &self.config.queues.db_sync_queue_url,
+                receipt_handle,
+            )
+            .await?;
+
             return Ok(());
         };
 
diff --git a/src/encoded_bits.rs b/src/encoded_bits.rs
index 051222c..5d462ce 100644
--- a/src/encoded_bits.rs
+++ b/src/encoded_bits.rs
@@ -7,7 +7,7 @@ use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
 use bytemuck::{cast_slice_mut, Pod, Zeroable};
 use rand::distributions::{Distribution, Standard};
-use rand::{thread_rng, Rng};
+use rand::Rng;
 use serde::de::Error as _;
 use serde::{Deserialize, Deserializer, Serialize};
 
@@ -23,11 +23,10 @@ unsafe impl Pod for EncodedBits {}
 
 impl EncodedBits {
     /// Generate secret shares from this bitvector.
-    pub fn share(&self, n: usize) -> Box<[EncodedBits]> {
+    pub fn share(&self, n: usize, rng: &mut impl Rng) -> Box<[EncodedBits]> {
         assert!(n > 0);
 
         // Create `n - 1` random shares.
-        let mut rng = thread_rng();
         let mut result: Box<[EncodedBits]> =
             iter::repeat_with(|| rng.gen::<EncodedBits>())
                 .take(n - 1)
                 .collect();
@@ -210,6 +209,8 @@ impl Serialize for EncodedBits {
 
 #[cfg(test)]
 mod tests {
+    use rand::thread_rng;
+
     use super::*;
 
     #[test]
diff --git a/src/lib.rs b/src/lib.rs
index e1dcdab..9e8ad3a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,5 +10,6 @@ pub mod distance;
 pub mod encoded_bits;
 pub mod health_check;
 pub mod participant;
+pub mod rng_source;
 pub mod template;
 pub mod utils;
diff --git a/src/participant.rs b/src/participant.rs
index 9f7af8c..ad36fa6 100644
--- a/src/participant.rs
+++ b/src/participant.rs
@@ -210,6 +210,14 @@ impl Participant {
             items
         } else {
             tracing::error!(?receipt_handle, "Failed to parse message body");
+
+            sqs_delete_message(
+                &self.sqs_client,
+                &self.config.queues.db_sync_queue_url,
+                receipt_handle,
+            )
+            .await?;
+
             return Ok(());
         };
 
diff --git a/src/rng_source.rs b/src/rng_source.rs
new file mode 100644
index 0000000..03c27d5
--- /dev/null
+++ b/src/rng_source.rs
@@ -0,0 +1,73 @@
+use std::fmt;
+use std::str::FromStr;
+
+use rand::{thread_rng, RngCore, SeedableRng};
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+#[serde(tag = "kind")]
+pub enum RngSource {
+    Thread,
+    Small(u64),
+    Std(u64),
+}
+
+impl RngSource {
+    pub fn to_rng(&self) -> Box<dyn RngCore> {
+        match self {
+            RngSource::Thread => Box::new(thread_rng()),
+            RngSource::Small(seed) => {
+                let rng: rand::rngs::SmallRng =
+                    SeedableRng::seed_from_u64(*seed);
+                Box::new(rng)
+            }
+            RngSource::Std(seed) => {
+                let rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(*seed);
+                Box::new(rng)
+            }
+        }
+    }
+}
+
+impl fmt::Display for RngSource {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            RngSource::Thread => write!(f, "thread"),
+            RngSource::Small(seed) => write!(f, "small:{}", seed),
+            RngSource::Std(seed) => write!(f, "std:{}", seed),
+        }
+    }
+}
+
+impl FromStr for RngSource {
+    type Err = eyre::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if s == "thread" {
+            Ok(RngSource::Thread)
+        } else if s.starts_with("small:") {
+            let seed = s.trim_start_matches("small:").parse()?;
+            Ok(RngSource::Small(seed))
+        } else if s.starts_with("std:") {
+            let seed = s.trim_start_matches("std:").parse()?;
+            Ok(RngSource::Std(seed))
+        } else {
+            Err(eyre::eyre!("Invalid RngSource: {}", s))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use test_case::test_case;
+
+    use super::*;
+
+    #[test_case("thread" => RngSource::Thread)]
+    #[test_case("std:42" => RngSource::Std(42))]
+    #[test_case("small:42" => RngSource::Small(42))]
+    fn serialization_round_trip(s: &str) -> RngSource {
+        s.parse().unwrap()
+    }
+}