diff --git a/bin/migrate-codes/iris_db.rs b/bin/migrate-codes/iris_db.rs deleted file mode 100644 index cb2b55d..0000000 --- a/bin/migrate-codes/iris_db.rs +++ /dev/null @@ -1,78 +0,0 @@ -use futures::TryStreamExt; -use mongodb::bson::doc; -use mpc::bits::Bits; -use serde::{Deserialize, Serialize}; - -pub const IRIS_CODE_BATCH_SIZE: u32 = 30_000; -pub const DATABASE_NAME: &str = "iris"; -pub const COLLECTION_NAME: &str = "codes.v2"; - -#[derive(Serialize, Deserialize)] -pub struct IrisCodeEntry { - pub signup_id: String, - pub mpc_serial_id: u64, - pub iris_code_left: Bits, - pub mask_code_left: Bits, - pub iris_code_right: Bits, - pub mask_code_right: Bits, -} - -pub struct IrisDb { - pub db: mongodb::Database, -} - -impl IrisDb { - pub async fn new(url: String) -> eyre::Result<Self> { - let client_options = - mongodb::options::ClientOptions::parse(url).await?; - - let client: mongodb::Client = - mongodb::Client::with_options(client_options)?; - - let db = client.database(DATABASE_NAME); - - Ok(Self { db }) - } - - #[tracing::instrument(skip(self))] - pub async fn get_iris_code_snapshot( - &self, - ) -> eyre::Result<Vec<IrisCodeEntry>> { - let mut items = vec![]; - - let mut last_serial_id = 0_i64; - - let collection = self.db.collection(COLLECTION_NAME); - - loop { - let find_options = mongodb::options::FindOptions::builder() - .batch_size(IRIS_CODE_BATCH_SIZE) - .sort(doc! {"serial_id": 1}) - .build(); - - let mut cursor = collection - .find(doc! {"serial_id": {"$gt": last_serial_id}}, find_options) - .await?; - - let mut items_added = 0; - while let Some(document) = cursor.try_next().await? 
{ - let iris_code_element = - mongodb::bson::from_document::<IrisCodeEntry>(document)?; - - last_serial_id += iris_code_element.mpc_serial_id as i64; - - items.push(iris_code_element); - - items_added += 1; - } - - if items_added == 0 { - break; - } - } - - items.sort_by(|a, b| a.mpc_serial_id.cmp(&b.mpc_serial_id)); - - Ok(items) - } -} diff --git a/bin/migrate-codes/migrate.rs b/bin/migrate-codes/migrate.rs index f4bf7dc..5dc9c7f 100644 --- a/bin/migrate-codes/migrate.rs +++ b/bin/migrate-codes/migrate.rs @@ -1,17 +1,22 @@ +#![allow(clippy::type_complexity)] + use clap::Parser; -use iris_db::IrisCodeEntry; +use eyre::ContextCompat; +use futures::{pin_mut, Stream, StreamExt}; +use indicatif::{ProgressBar, ProgressStyle}; use mpc::bits::Bits; use mpc::db::Db; use mpc::distance::EncodedBits; +use mpc::iris_db::{ + FinalResult, IrisCodeEntry, IrisDb, SideResult, FINAL_RESULT_STATUS, +}; use mpc::template::Template; use rand::thread_rng; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use telemetry_batteries::tracing::stdout::StdoutBattery; -use crate::iris_db::IrisDb; use crate::mpc_db::MPCDb; -mod iris_db; mod mpc_db; #[derive(Parser)] @@ -32,12 +37,19 @@ pub struct Args { #[clap(alias = "rp", long, env)] pub right_participant_db: Vec<String>, - #[clap(long, env, default_value = "10000")] - pub batch_size: usize, + /// Batch size for encoding shares + #[clap(short, long, default_value = "100")] + batch_size: usize, + + /// If set to true, no migration or creation of the database will occur on the Postgres side + #[clap(long)] + no_migrate_or_create: bool, } #[tokio::main] async fn main() -> eyre::Result<()> { + dotenv::dotenv().ok(); + let _shutdown_tracing_provider = StdoutBattery::init(); let args = Args::parse(); @@ -56,34 +68,63 @@ async fn main() -> eyre::Result<()> { args.left_participant_db, args.right_coordinator_db, args.right_participant_db, + args.no_migrate_or_create, ) .await?; let iris_db = IrisDb::new(args.iris_code_db).await?; - let 
iris_code_entries = iris_db.get_iris_code_snapshot().await?; - let (left_templates, right_templates) = - extract_templates(iris_code_entries); + let latest_serial_id = mpc_db.fetch_latest_serial_id().await?; + tracing::info!("Latest serial id {latest_serial_id}"); - let mut next_serial_id = mpc_db.fetch_latest_serial_id().await? + 1; + // Cleanup items with larger ids + // as they might be assigned new values in the future + mpc_db.prune_items(latest_serial_id).await?; + iris_db.prune_final_results(latest_serial_id).await?; - let left_data = encode_shares(left_templates, num_participants as usize)?; - let right_data = encode_shares(right_templates, num_participants as usize)?; + let first_unsynced_iris_serial_id = if let Some(final_result) = iris_db + .get_final_result_by_serial_id(latest_serial_id) + .await? + { + iris_db + .get_entry_by_signup_id(&final_result.signup_id) + .await? + .context("Could not find iris code entry")? + .serial_id + } else { + 0 + }; - insert_masks_and_shares( - &left_data, - &mpc_db.left_coordinator_db, - &mpc_db.left_participant_dbs, - ) - .await?; + let num_iris_codes = iris_db + .count_whitelisted_iris_codes(first_unsynced_iris_serial_id) + .await?; + tracing::info!("Processing {} iris codes", num_iris_codes); - insert_masks_and_shares( - &right_data, - &mpc_db.right_coordinator_db, - &mpc_db.right_participant_dbs, + let pb = + ProgressBar::new(num_iris_codes).with_message("Migrating iris codes"); + let pb_style = ProgressStyle::default_bar() + .template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})") + .expect("Could not create progress bar"); + pb.set_style(pb_style); + + let iris_code_entries = iris_db + .stream_whitelisted_iris_codes(first_unsynced_iris_serial_id) + .await? 
+ .chunks(args.batch_size) + .map(|chunk| chunk.into_iter().collect::<Result<Vec<_>, _>>()); + + handle_templates_stream( + iris_code_entries, + &iris_db, + &mpc_db, + num_participants as usize, + latest_serial_id, + &pb, ) .await?; + pb.finish(); + Ok(()) } @@ -95,7 +136,7 @@ pub struct MPCIrisData { pub fn encode_shares( template_data: Vec<(usize, Template)>, num_participants: usize, -) -> eyre::Result<(Vec<(usize, Bits, Box<[EncodedBits]>)>)> { +) -> eyre::Result<Vec<(usize, Bits, Box<[EncodedBits]>)>> { let iris_data = template_data .into_par_iter() .map(|(serial_id, template)| { @@ -111,32 +152,100 @@ pub fn encode_shares( Ok(iris_data) } -pub fn extract_templates( - iris_code_snapshot: Vec<IrisCodeEntry>, -) -> (Vec<(usize, Template)>, Vec<(usize, Template)>) { - let (left_templates, right_templates) = iris_code_snapshot - .into_iter() - .map(|entry| { - ( - ( - entry.mpc_serial_id as usize, - Template { - code: entry.iris_code_left, - mask: entry.mask_code_left, - }, - ), - ( - entry.mpc_serial_id as usize, - Template { - code: entry.iris_code_right, - mask: entry.mask_code_right, - }, - ), - ) - }) - .unzip(); +async fn handle_templates_stream( + iris_code_entries: impl Stream< + Item = mongodb::error::Result<Vec<IrisCodeEntry>>, + >, + iris_db: &IrisDb, + mpc_db: &MPCDb, + num_participants: usize, + mut latest_serial_id: u64, + pb: &ProgressBar, +) -> eyre::Result<()> { + pin_mut!(iris_code_entries); + + // Consume the stream + while let Some(entries) = iris_code_entries.next().await { + let entries = entries?; + + let count = entries.len() as u64; + + let left_data: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| { + let template = Template { + code: entry.iris_code_left, + mask: entry.mask_code_left, + }; + + (latest_serial_id as usize + 1 + idx, template) + }) + .collect(); + + let right_data: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| { + let template = Template { + code: entry.iris_code_right, + mask: 
entry.mask_code_right, + }; + + (latest_serial_id as usize + 1 + idx, template) + }) + .collect(); + + let left = handle_side_data_chunk( + left_data, + num_participants, + &mpc_db.left_coordinator_db, + &mpc_db.left_participant_dbs, + ); + + let right = handle_side_data_chunk( + right_data, + num_participants, + &mpc_db.right_coordinator_db, + &mpc_db.right_participant_dbs, + ); + + let final_results: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| FinalResult { + status: FINAL_RESULT_STATUS.to_string(), + serial_id: latest_serial_id + 1 + idx as u64, + signup_id: entry.signup_id.clone(), + unique: true, + right_result: SideResult {}, + left_result: SideResult {}, + }) + .collect(); + + let results = iris_db.save_final_results(&final_results); + + futures::try_join!(left, right, results)?; + + latest_serial_id += count; + pb.inc(count); + } - (left_templates, right_templates) + Ok(()) +} + +async fn handle_side_data_chunk( + templates: Vec<(usize, Template)>, + num_participants: usize, + coordinator_db: &Db, + participant_dbs: &[Db], +) -> eyre::Result<()> { + let left_data = encode_shares(templates, num_participants)?; + + insert_masks_and_shares(&left_data, coordinator_db, participant_dbs) + .await?; + + Ok(()) } async fn insert_masks_and_shares( @@ -147,7 +256,7 @@ async fn insert_masks_and_shares( // Insert masks let left_masks: Vec<_> = data .iter() - .map(|(serial_id, mask, _)| (*serial_id as u64, mask.clone())) + .map(|(serial_id, mask, _)| (*serial_id as u64, *mask)) .collect(); coordinator_db.insert_masks(&left_masks).await?; diff --git a/bin/migrate-codes/mpc_db.rs b/bin/migrate-codes/mpc_db.rs index 97aa200..dd9b606 100644 --- a/bin/migrate-codes/mpc_db.rs +++ b/bin/migrate-codes/mpc_db.rs @@ -1,10 +1,6 @@ -use std::collections::HashSet; - -use mpc::bits::Bits; +use eyre::ContextCompat; use mpc::config::DbConfig; use mpc::db::Db; -use mpc::distance::EncodedBits; -use mpc::template::Template; pub struct MPCDb { pub left_coordinator_db: 
Db, @@ -12,18 +8,23 @@ pub struct MPCDb { pub right_coordinator_db: Db, pub right_participant_dbs: Vec<Db>, } + impl MPCDb { pub async fn new( left_coordinator_db_url: String, left_participant_db_urls: Vec<String>, right_coordinator_db_url: String, right_participant_db_urls: Vec<String>, + no_migrate_or_create: bool, ) -> eyre::Result<Self> { + let migrate = !no_migrate_or_create; + let create = !no_migrate_or_create; + tracing::info!("Connecting to left coordinator db"); let left_coordinator_db = Db::new(&DbConfig { url: left_coordinator_db_url, - migrate: false, - create: false, + migrate, + create, }) .await?; @@ -33,8 +34,8 @@ impl MPCDb { tracing::info!(participant=?i, "Connecting to left participant db"); let db = Db::new(&DbConfig { url, - migrate: false, - create: false, + migrate, + create, }) .await?; left_participant_dbs.push(db); @@ -43,8 +44,8 @@ impl MPCDb { tracing::info!("Connecting to right coordinator db"); let right_coordinator_db = Db::new(&DbConfig { url: right_coordinator_db_url, - migrate: false, - create: false, + migrate, + create, }) .await?; @@ -53,10 +54,11 @@ impl MPCDb { tracing::info!(participant=?i, "Connecting to right participant db"); let db = Db::new(&DbConfig { url, - migrate: false, - create: false, + migrate, + create, }) .await?; + right_participant_dbs.push(db); } @@ -68,92 +70,37 @@ impl MPCDb { }) } - #[tracing::instrument(skip(self))] - pub async fn insert_shares_and_masks( - &self, - left_data: Vec<(u64, Bits, Box<[EncodedBits]>)>, - right_data: Vec<(u64, Bits, Box<[EncodedBits]>)>, - ) -> eyre::Result<()> { - //TODO: logging for progress - - let (left_masks, left_shares): ( - Vec<(u64, Bits)>, - Vec<Vec<(u64, EncodedBits)>>, - ) = left_data - .into_iter() - .map(|(id, mask, shares)| { - let shares: Vec<(u64, EncodedBits)> = - shares.into_iter().map(|share| (id, *share)).collect(); - - ((id, mask), shares) - }) - .unzip(); - - let (right_masks, right_shares): ( - Vec<(u64, Bits)>, - Vec<Vec<(u64, EncodedBits)>>, - ) = 
right_data - .into_iter() - .map(|(id, mask, shares)| { - let shares: Vec<(u64, EncodedBits)> = - shares.into_iter().map(|share| (id, *share)).collect(); - - ((id, mask), shares) - }) - .unzip(); - - let coordinator_tasks = vec![ - self.left_coordinator_db.insert_masks(&left_masks), - self.right_coordinator_db.insert_masks(&right_masks), - ]; - - let participant_tasks = self - .left_participant_dbs - .iter() - .zip(left_shares.iter()) - .chain(self.right_participant_dbs.iter().zip(right_shares.iter())) - .map(|(db, shares)| db.insert_shares(shares)); - - for task in coordinator_tasks { - task.await?; + #[tracing::instrument(skip(self,))] + pub async fn prune_items(&self, serial_id: u64) -> eyre::Result<()> { + self.left_coordinator_db.prune_items(serial_id).await?; + self.right_coordinator_db.prune_items(serial_id).await?; + + for db in self.left_participant_dbs.iter() { + db.prune_items(serial_id).await?; } - for task in participant_tasks { - task.await?; + for db in self.right_participant_dbs.iter() { + db.prune_items(serial_id).await?; } + Ok(()) } #[tracing::instrument(skip(self,))] pub async fn fetch_latest_serial_id(&self) -> eyre::Result<u64> { - let mut ids = HashSet::new(); - - let left_coordinator_id = - self.left_coordinator_db.fetch_latest_mask_id().await?; - - tracing::info!(?left_coordinator_id, "Latest left mask Id"); + let mut ids = Vec::new(); - let right_coordinator_id = - self.right_coordinator_db.fetch_latest_share_id().await?; - - tracing::info!(?right_coordinator_id, "Latest right mask Id"); - - for (i, db) in self.left_participant_dbs.iter().enumerate() { - let id = db.fetch_latest_share_id().await?; - tracing::info!(?id, participant=?i, "Latest left share Id"); - ids.insert(id); - } + ids.push(self.left_coordinator_db.fetch_latest_mask_id().await?); + ids.push(self.right_coordinator_db.fetch_latest_mask_id().await?); - for (i, db) in self.right_participant_dbs.iter().enumerate() { - let id = db.fetch_latest_share_id().await?; - 
tracing::info!(?id, participant=?i, "Latest right share Id"); - ids.insert(id); + for db in self.left_participant_dbs.iter() { + ids.push(db.fetch_latest_share_id().await?); } - if ids.len() != 1 { - return Err(eyre::eyre!("Mismatched serial ids")); + for db in self.right_participant_dbs.iter() { + ids.push(db.fetch_latest_share_id().await?); } - Ok(left_coordinator_id) + ids.into_iter().min().context("No serial ids found") } } diff --git a/bin/utils/common.rs b/bin/utils/common.rs index 8254ec3..74eacfb 100644 --- a/bin/utils/common.rs +++ b/bin/utils/common.rs @@ -30,9 +30,8 @@ pub fn generate_templates(num_templates: usize) -> Vec<Template> { templates } -pub fn generate_random_string(len: usize) -> String { - rand::thread_rng() - .sample_iter(&Alphanumeric) +pub fn generate_random_string(rng: &mut impl Rng, len: usize) -> String { + rng.sample_iter(&Alphanumeric) .take(len) .map(char::from) .collect() diff --git a/bin/utils/seed_iris_db.rs b/bin/utils/seed_iris_db.rs index 941492a..dd1dcca 100644 --- a/bin/utils/seed_iris_db.rs +++ b/bin/utils/seed_iris_db.rs @@ -1,7 +1,8 @@ use clap::Args; use indicatif::{ProgressBar, ProgressStyle}; -use mpc::bits::Bits; -use serde::{Deserialize, Serialize}; +use mpc::iris_db::IrisCodeEntry; +use mpc::rng_source::RngSource; +use rand::Rng; use crate::common::{generate_random_string, generate_templates}; @@ -18,6 +19,9 @@ pub struct SeedIrisDb { #[clap(short, long, default_value = "100")] pub batch_size: usize, + + #[clap(long, default_value = "thread")] + pub rng: RngSource, } pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> { @@ -41,18 +45,20 @@ pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> { tracing::info!(?next_serial_id); + let mut rng = args.rng.to_rng(); + let documents = left_templates .iter() .zip(right_templates.iter()) .enumerate() .map(|(serial_id, (left, right))| IrisCodeEntry { - signup_id: generate_random_string(10), - mpc_serial_id: next_serial_id + serial_id as u64, + 
signup_id: generate_random_string(&mut rng, 10), + serial_id: next_serial_id + serial_id as u64, iris_code_left: left.code, mask_code_left: left.mask, iris_code_right: right.code, mask_code_right: right.mask, - whitelisted: true, + whitelisted: rng.gen_bool(0.8), }) .collect::<Vec<IrisCodeEntry>>(); @@ -73,14 +79,3 @@ pub async fn seed_iris_db(args: &SeedIrisDb) -> eyre::Result<()> { Ok(()) } - -#[derive(Serialize, Deserialize)] -pub struct IrisCodeEntry { - pub signup_id: String, - pub mpc_serial_id: u64, - pub iris_code_left: Bits, - pub mask_code_left: Bits, - pub iris_code_right: Bits, - pub mask_code_right: Bits, - pub whitelisted: bool, -} diff --git a/bin/utils/sqs_query.rs b/bin/utils/sqs_query.rs index 9519553..e6b7ce3 100644 --- a/bin/utils/sqs_query.rs +++ b/bin/utils/sqs_query.rs @@ -1,9 +1,10 @@ use clap::Args; use mpc::config::AwsConfig; use mpc::coordinator::UniquenessCheckRequest; +use mpc::rng_source::RngSource; use mpc::template::Template; use mpc::utils::aws::sqs_client_from_config; -use rand::{thread_rng, Rng}; +use rand::Rng; use crate::common::generate_random_string; @@ -22,6 +23,9 @@ pub struct SQSQuery { /// The URL of the SQS queue #[clap(short, long)] pub queue_url: String, + + #[clap(long, default_value = "thread")] + pub rng: RngSource, } pub async fn sqs_query(args: &SQSQuery) -> eyre::Result<()> { @@ -31,11 +35,11 @@ pub async fn sqs_query(args: &SQSQuery) -> eyre::Result<()> { }) .await?; - let mut rng = thread_rng(); + let mut rng = args.rng.to_rng(); let plain_code: Template = rng.gen(); - let signup_id = generate_random_string(10); - let group_id = generate_random_string(10); + let signup_id = generate_random_string(&mut rng, 10); + let group_id = generate_random_string(&mut rng, 10); tracing::info!(?signup_id, ?group_id, "Sending message"); diff --git a/bin/utils/utils.rs b/bin/utils/utils.rs index d06fb73..17c6629 100644 --- a/bin/utils/utils.rs +++ b/bin/utils/utils.rs @@ -1,7 +1,5 @@ use clap::Parser; use 
generate_mock_templates::{generate_mock_templates, GenerateMockTemplates}; -use rand::distributions::Alphanumeric; -use rand::Rng; use seed_iris_db::{seed_iris_db, SeedIrisDb}; use seed_mpc_db::{seed_mpc_db, SeedMPCDb}; use sqs_query::{sqs_query, SQSQuery}; diff --git a/compose_all_sides.yml b/compose_all_sides.yml new file mode 100644 index 0000000..67d9c85 --- /dev/null +++ b/compose_all_sides.yml @@ -0,0 +1,46 @@ +version: "3" +services: + coordinator_left_db: + image: postgres + ports: + - 8432:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + participant_left_0_db: + image: postgres + ports: + - 8433:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + participant_left_1_db: + image: postgres + ports: + - 8534:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + coordinator_right_db: + image: postgres + ports: + - 8532:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + participant_right_0_db: + image: postgres + ports: + - 8533:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + participant_right_1_db: + image: postgres + ports: + - 8434:5432 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + iris_db: + hostname: iris_db + image: mongo + ports: + - 27017:27017 + environment: + - MONGO_INITDB_ROOT_USERNAME=admin + - MONGO_INITDB_ROOT_PASSWORD=password diff --git a/src/db.rs b/src/db.rs index d7a98f5..df5dfbf 100644 --- a/src/db.rs +++ b/src/db.rs @@ -167,6 +167,32 @@ impl Db { Ok(()) } + + /// Removes masks and shares with serial IDs larger than `serial_id`. 
+ #[tracing::instrument(skip(self))] + pub async fn prune_items(&self, serial_id: u64) -> eyre::Result<()> { + sqlx::query( + r#" + DELETE FROM masks + WHERE id > $1 + "#, + ) + .bind(serial_id as i64) + .execute(&self.pool) + .await?; + + sqlx::query( + r#" + DELETE FROM shares + WHERE id > $1 + "#, + ) + .bind(serial_id as i64) + .execute(&self.pool) + .await?; + + Ok(()) + } } fn filter_sequential_items<T>( @@ -317,14 +343,46 @@ mod tests { #[tokio::test] async fn fetch_latest_mask_id() -> eyre::Result<()> { - todo!(); + let (db, _pg) = setup().await?; + + let mut rng = thread_rng(); + + let masks = vec![ + (1, rng.gen::<Bits>()), + (2, rng.gen::<Bits>()), + (5, rng.gen::<Bits>()), + (6, rng.gen::<Bits>()), + (8, rng.gen::<Bits>()), + ]; + + db.insert_masks(&masks).await?; + + let latest_mask_id = db.fetch_latest_mask_id().await?; + + assert_eq!(latest_mask_id, 8); Ok(()) } #[tokio::test] async fn fetch_latest_share_id() -> eyre::Result<()> { - todo!(); + let (db, _pg) = setup().await?; + + let mut rng = thread_rng(); + + let shares = vec![ + (1, rng.gen::<EncodedBits>()), + (2, rng.gen::<EncodedBits>()), + (5, rng.gen::<EncodedBits>()), + (6, rng.gen::<EncodedBits>()), + (8, rng.gen::<EncodedBits>()), + ]; + + db.insert_shares(&shares).await?; + + let latest_share_id = db.fetch_latest_share_id().await?; + + assert_eq!(latest_share_id, 8); Ok(()) } @@ -427,4 +485,41 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn prune_items() -> eyre::Result<()> { + let (db, _pg) = setup().await?; + + let mut rng = thread_rng(); + + let masks = vec![ + (1, rng.gen::<Bits>()), + (2, rng.gen::<Bits>()), + (3, rng.gen::<Bits>()), + (4, rng.gen::<Bits>()), + (6, rng.gen::<Bits>()), + ]; + + db.insert_masks(&masks).await?; + + let shares = vec![ + (1, rng.gen::<EncodedBits>()), + (2, rng.gen::<EncodedBits>()), + (3, rng.gen::<EncodedBits>()), + (4, rng.gen::<EncodedBits>()), + (6, rng.gen::<EncodedBits>()), + ]; + + db.insert_shares(&shares).await?; + + 
db.prune_items(2).await?; + + let fetched_masks = db.fetch_masks(0).await?; + let fetched_shares = db.fetch_shares(0).await?; + + assert_eq!(fetched_masks.len(), 2); + assert_eq!(fetched_shares.len(), 2); + + Ok(()) + } } diff --git a/src/iris_db.rs b/src/iris_db.rs new file mode 100644 index 0000000..0e0c177 --- /dev/null +++ b/src/iris_db.rs @@ -0,0 +1,182 @@ +use futures::{Stream, TryStreamExt}; +use mongodb::bson::doc; +use mongodb::Collection; +use serde::{Deserialize, Serialize}; + +use crate::bits::Bits; + +pub const IRIS_CODE_BATCH_SIZE: u32 = 30_000; +pub const DATABASE_NAME: &str = "iris"; +pub const COLLECTION_NAME: &str = "codes.v2"; +pub const FINAL_RESULT_COLLECTION_NAME: &str = "mpc.results"; +pub const FINAL_RESULT_STATUS: &str = "COMPLETED"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IrisCodeEntry { + pub signup_id: String, + /// Internal serial id of the iris code db + pub serial_id: u64, + pub iris_code_left: Bits, + pub mask_code_left: Bits, + pub iris_code_right: Bits, + pub mask_code_right: Bits, + pub whitelisted: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct FinalResult { + /// Should always be "COMPLETED" + pub status: String, + + /// The MPC serial id associated with this signup + pub serial_id: u64, + + /// A unique signup id string + pub signup_id: String, + + pub unique: bool, + + pub right_result: SideResult, + + pub left_result: SideResult, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SideResult { + // This struct is intentionally left empty to represent an empty document. 
+} + +pub struct IrisDb { + pub db: mongodb::Database, +} + +impl IrisDb { + pub async fn new(url: String) -> eyre::Result<Self> { + let client_options = + mongodb::options::ClientOptions::parse(url).await?; + + let client: mongodb::Client = + mongodb::Client::with_options(client_options)?; + + let db = client.database(DATABASE_NAME); + + Ok(Self { db }) + } + + #[tracing::instrument(skip(self))] + pub async fn save_final_results( + &self, + final_results: &[FinalResult], + ) -> eyre::Result<()> { + let collection: Collection<FinalResult> = + self.db.collection(FINAL_RESULT_COLLECTION_NAME); + + collection.insert_many(final_results, None).await?; + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub async fn get_final_result_by_serial_id( + &self, + serial_id: u64, + ) -> eyre::Result<Option<FinalResult>> { + let collection: Collection<FinalResult> = + self.db.collection(FINAL_RESULT_COLLECTION_NAME); + + let final_result = collection + .find_one(doc! { "serial_id": serial_id as i64 }, None) + .await?; + + Ok(final_result) + } + + /// Removes all final result entries with serial id larger than the given one + #[tracing::instrument(skip(self))] + pub async fn prune_final_results( + &self, + serial_id: u64, + ) -> eyre::Result<()> { + let collection: Collection<FinalResult> = + self.db.collection(FINAL_RESULT_COLLECTION_NAME); + + collection + .delete_many( + doc! { "serial_id": { "$gt": serial_id as i64 } }, + None, + ) + .await?; + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub async fn count_whitelisted_iris_codes( + &self, + last_serial_id: u64, + ) -> eyre::Result<u64> { + let collection: Collection<IrisCodeEntry> = + self.db.collection(COLLECTION_NAME); + + let count = collection + .count_documents( + doc! 
{ + "serial_id": {"$gt": last_serial_id as i64}, + "whitelisted": true, + }, + None, + ) + .await?; + + Ok(count) + } + + #[tracing::instrument(skip(self))] + pub async fn get_entry_by_signup_id( + &self, + signup_id: &str, + ) -> eyre::Result<Option<IrisCodeEntry>> { + let collection: Collection<IrisCodeEntry> = + self.db.collection(COLLECTION_NAME); + + let iris_code_entry = collection + .find_one(doc! {"signup_id": signup_id}, None) + .await?; + + Ok(iris_code_entry) + } + + #[tracing::instrument(skip(self))] + pub async fn stream_whitelisted_iris_codes( + &self, + last_serial_id: u64, + ) -> eyre::Result< + impl Stream<Item = Result<IrisCodeEntry, mongodb::error::Error>>, + > { + let find_options = mongodb::options::FindOptions::builder() + .batch_size(IRIS_CODE_BATCH_SIZE) + .sort(doc! { "serial_id": 1 }) + .build(); + + let collection = self.db.collection(COLLECTION_NAME); + + let cursor = collection + .find( + doc! { + "serial_id": {"$gt": last_serial_id as i64}, + "whitelisted": true + }, + find_options, + ) + .await?; + + let codes_stream = cursor.and_then(|document| async move { + let iris_code_element = + mongodb::bson::from_document::<IrisCodeEntry>(document)?; + + eyre::Result::Ok(iris_code_element) + }); + + Ok(codes_stream) + } +} diff --git a/src/lib.rs b/src/lib.rs index 9e8ad3a..a9351eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod db; pub mod distance; pub mod encoded_bits; pub mod health_check; +pub mod iris_db; pub mod participant; pub mod rng_source; pub mod template;