From f5d483ab543fa0905eefa53caf8dcdc369d59a74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Tr=C4=85d?= Date: Thu, 7 Mar 2024 13:09:46 +0100 Subject: [PATCH] Dzejkop/migrate-codes-streaming-impl (#78) * WIP: Streaming * Test compose file * Progress bar + more logging * Clippy + fmt + cleanup * Simplify * Update name * Cleanup * Minor refactor + add pruning functionality * Iris DB methods * Handle final results + checkpoints --- bin/migrate-codes/iris_db.rs | 78 ------------- bin/migrate-codes/migrate.rs | 205 +++++++++++++++++++++++++++-------- bin/migrate-codes/mpc_db.rs | 119 ++++++-------------- bin/utils/common.rs | 5 +- bin/utils/seed_iris_db.rs | 27 ++--- bin/utils/sqs_query.rs | 12 +- bin/utils/utils.rs | 2 - compose_all_sides.yml | 46 ++++++++ src/db.rs | 99 ++++++++++++++++- src/iris_db.rs | 182 +++++++++++++++++++++++++++++++ src/lib.rs | 1 + 11 files changed, 537 insertions(+), 239 deletions(-) delete mode 100644 bin/migrate-codes/iris_db.rs create mode 100644 compose_all_sides.yml create mode 100644 src/iris_db.rs diff --git a/bin/migrate-codes/iris_db.rs b/bin/migrate-codes/iris_db.rs deleted file mode 100644 index cb2b55d..0000000 --- a/bin/migrate-codes/iris_db.rs +++ /dev/null @@ -1,78 +0,0 @@ -use futures::TryStreamExt; -use mongodb::bson::doc; -use mpc::bits::Bits; -use serde::{Deserialize, Serialize}; - -pub const IRIS_CODE_BATCH_SIZE: u32 = 30_000; -pub const DATABASE_NAME: &str = "iris"; -pub const COLLECTION_NAME: &str = "codes.v2"; - -#[derive(Serialize, Deserialize)] -pub struct IrisCodeEntry { - pub signup_id: String, - pub mpc_serial_id: u64, - pub iris_code_left: Bits, - pub mask_code_left: Bits, - pub iris_code_right: Bits, - pub mask_code_right: Bits, -} - -pub struct IrisDb { - pub db: mongodb::Database, -} - -impl IrisDb { - pub async fn new(url: String) -> eyre::Result { - let client_options = - mongodb::options::ClientOptions::parse(url).await?; - - let client: mongodb::Client = - mongodb::Client::with_options(client_options)?; - - let db = client.database(DATABASE_NAME); - - Ok(Self { db }) - } - - #[tracing::instrument(skip(self))] - pub async fn get_iris_code_snapshot( - &self, - ) -> eyre::Result> { - let mut items = vec![]; - - let mut last_serial_id = 0_i64; - - let collection = self.db.collection(COLLECTION_NAME); - - loop { - let find_options = mongodb::options::FindOptions::builder() - .batch_size(IRIS_CODE_BATCH_SIZE) - .sort(doc! {"serial_id": 1}) - .build(); - - let mut cursor = collection - .find(doc! {"serial_id": {"$gt": last_serial_id}}, find_options) - .await?; - - let mut items_added = 0; - while let Some(document) = cursor.try_next().await? { - let iris_code_element = - mongodb::bson::from_document::(document)?; - - last_serial_id += iris_code_element.mpc_serial_id as i64; - - items.push(iris_code_element); - - items_added += 1; - } - - if items_added == 0 { - break; - } - } - - items.sort_by(|a, b| a.mpc_serial_id.cmp(&b.mpc_serial_id)); - - Ok(items) - } -} diff --git a/bin/migrate-codes/migrate.rs b/bin/migrate-codes/migrate.rs index f4bf7dc..5dc9c7f 100644 --- a/bin/migrate-codes/migrate.rs +++ b/bin/migrate-codes/migrate.rs @@ -1,17 +1,22 @@ +#![allow(clippy::type_complexity)] + use clap::Parser; -use iris_db::IrisCodeEntry; +use eyre::ContextCompat; +use futures::{pin_mut, Stream, StreamExt}; +use indicatif::{ProgressBar, ProgressStyle}; use mpc::bits::Bits; use mpc::db::Db; use mpc::distance::EncodedBits; +use mpc::iris_db::{ + FinalResult, IrisCodeEntry, IrisDb, SideResult, FINAL_RESULT_STATUS, +}; use mpc::template::Template; use rand::thread_rng; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use telemetry_batteries::tracing::stdout::StdoutBattery; -use crate::iris_db::IrisDb; use crate::mpc_db::MPCDb; -mod iris_db; mod mpc_db; #[derive(Parser)] @@ -32,12 +37,19 @@ pub struct Args { #[clap(alias = "rp", long, env)] pub right_participant_db: Vec, - #[clap(long, env, default_value = "10000")] - pub batch_size: usize, + /// Batch size for encoding shares + #[clap(short, long, default_value = "100")] + batch_size: usize, + + /// If set to true, no migration or creation of the database will occur on the Postgres side + #[clap(long)] + no_migrate_or_create: bool, } #[tokio::main] async fn main() -> eyre::Result<()> { + dotenv::dotenv().ok(); + let _shutdown_tracing_provider = StdoutBattery::init(); let args = Args::parse(); @@ -56,34 +68,63 @@ async fn main() -> eyre::Result<()> { args.left_participant_db, args.right_coordinator_db, args.right_participant_db, + args.no_migrate_or_create, ) .await?; let iris_db = IrisDb::new(args.iris_code_db).await?; - let iris_code_entries = iris_db.get_iris_code_snapshot().await?; - let (left_templates, right_templates) = - extract_templates(iris_code_entries); + let latest_serial_id = mpc_db.fetch_latest_serial_id().await?; + tracing::info!("Latest serial id {latest_serial_id}"); - let mut next_serial_id = mpc_db.fetch_latest_serial_id().await? + 1; + // Cleanup items with larger ids + // as they might be assigned new values in the future + mpc_db.prune_items(latest_serial_id).await?; + iris_db.prune_final_results(latest_serial_id).await?; - let left_data = encode_shares(left_templates, num_participants as usize)?; - let right_data = encode_shares(right_templates, num_participants as usize)?; + let first_unsynced_iris_serial_id = if let Some(final_result) = iris_db + .get_final_result_by_serial_id(latest_serial_id) + .await? + { + iris_db + .get_entry_by_signup_id(&final_result.signup_id) + .await? + .context("Could not find iris code entry")? + .serial_id + } else { + 0 + }; - insert_masks_and_shares( - &left_data, - &mpc_db.left_coordinator_db, - &mpc_db.left_participant_dbs, - ) - .await?; + let num_iris_codes = iris_db + .count_whitelisted_iris_codes(first_unsynced_iris_serial_id) + .await?; + tracing::info!("Processing {} iris codes", num_iris_codes); - insert_masks_and_shares( - &right_data, - &mpc_db.right_coordinator_db, - &mpc_db.right_participant_dbs, + let pb = + ProgressBar::new(num_iris_codes).with_message("Migrating iris codes"); + let pb_style = ProgressStyle::default_bar() + .template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})") + .expect("Could not create progress bar"); + pb.set_style(pb_style); + + let iris_code_entries = iris_db + .stream_whitelisted_iris_codes(first_unsynced_iris_serial_id) + .await? + .chunks(args.batch_size) + .map(|chunk| chunk.into_iter().collect::, _>>()); + + handle_templates_stream( + iris_code_entries, + &iris_db, + &mpc_db, + num_participants as usize, + latest_serial_id, + &pb, ) .await?; + pb.finish(); + Ok(()) } @@ -95,7 +136,7 @@ pub struct MPCIrisData { pub fn encode_shares( template_data: Vec<(usize, Template)>, num_participants: usize, -) -> eyre::Result<(Vec<(usize, Bits, Box<[EncodedBits]>)>)> { +) -> eyre::Result)>> { let iris_data = template_data .into_par_iter() .map(|(serial_id, template)| { @@ -111,32 +152,100 @@ pub fn encode_shares( Ok(iris_data) } -pub fn extract_templates( - iris_code_snapshot: Vec, -) -> (Vec<(usize, Template)>, Vec<(usize, Template)>) { - let (left_templates, right_templates) = iris_code_snapshot - .into_iter() - .map(|entry| { - ( - ( - entry.mpc_serial_id as usize, - Template { - code: entry.iris_code_left, - mask: entry.mask_code_left, - }, - ), - ( - entry.mpc_serial_id as usize, - Template { - code: entry.iris_code_right, - mask: entry.mask_code_right, - }, - ), - ) - }) - .unzip(); +async fn handle_templates_stream( + iris_code_entries: impl Stream< + Item = mongodb::error::Result>, + >, + iris_db: &IrisDb, + mpc_db: &MPCDb, + num_participants: usize, + mut latest_serial_id: u64, + pb: &ProgressBar, +) -> eyre::Result<()> { + pin_mut!(iris_code_entries); + + // Consume the stream + while let Some(entries) = iris_code_entries.next().await { + let entries = entries?; + + let count = entries.len() as u64; + + let left_data: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| { + let template = Template { + code: entry.iris_code_left, + mask: entry.mask_code_left, + }; + + (latest_serial_id as usize + 1 + idx, template) + }) + .collect(); + + let right_data: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| { + let template = Template { + code: entry.iris_code_right, + mask: entry.mask_code_right, + }; + + (latest_serial_id as usize + 1 + idx, template) + }) + .collect(); + + let left = handle_side_data_chunk( + left_data, + num_participants, + &mpc_db.left_coordinator_db, + &mpc_db.left_participant_dbs, + ); + + let right = handle_side_data_chunk( + right_data, + num_participants, + &mpc_db.right_coordinator_db, + &mpc_db.right_participant_dbs, + ); + + let final_results: Vec<_> = entries + .iter() + .enumerate() + .map(|(idx, entry)| FinalResult { + status: FINAL_RESULT_STATUS.to_string(), + serial_id: latest_serial_id + 1 + idx as u64, + signup_id: entry.signup_id.clone(), + unique: true, + right_result: SideResult {}, + left_result: SideResult {}, + }) + .collect(); + + let results = iris_db.save_final_results(&final_results); + + futures::try_join!(left, right, results)?; + + latest_serial_id += count; + pb.inc(count); + } - (left_templates, right_templates) + Ok(()) +} + +async fn handle_side_data_chunk( + templates: Vec<(usize, Template)>, + num_participants: usize, + coordinator_db: &Db, + participant_dbs: &[Db], +) -> eyre::Result<()> { + let left_data = encode_shares(templates, num_participants)?; + + insert_masks_and_shares(&left_data, coordinator_db, participant_dbs) + .await?; + + Ok(()) } async fn insert_masks_and_shares( @@ -147,7 +256,7 @@ async fn insert_masks_and_shares( // Insert masks let left_masks: Vec<_> = data .iter() - .map(|(serial_id, mask, _)| (*serial_id as u64, mask.clone())) + .map(|(serial_id, mask, _)| (*serial_id as u64, *mask)) .collect(); coordinator_db.insert_masks(&left_masks).await?; diff --git a/bin/migrate-codes/mpc_db.rs b/bin/migrate-codes/mpc_db.rs index 97aa200..dd9b606 100644 --- a/bin/migrate-codes/mpc_db.rs +++ b/bin/migrate-codes/mpc_db.rs @@ -1,10 +1,6 @@ -use std::collections::HashSet; - -use mpc::bits::Bits; +use eyre::ContextCompat; use mpc::config::DbConfig; use mpc::db::Db; -use mpc::distance::EncodedBits; -use mpc::template::Template; pub struct MPCDb { pub left_coordinator_db: Db, @@ -12,18 +8,23 @@ pub struct MPCDb { pub right_coordinator_db: Db, pub right_participant_dbs: Vec, } + impl MPCDb { pub async fn new( left_coordinator_db_url: String, left_participant_db_urls: Vec, right_coordinator_db_url: String, right_participant_db_urls: Vec, + no_migrate_or_create: bool, ) -> eyre::Result { + let migrate = !no_migrate_or_create; + let create = !no_migrate_or_create; + tracing::info!("Connecting to left coordinator db"); let left_coordinator_db = Db::new(&DbConfig { url: left_coordinator_db_url, - migrate: false, - create: false, + migrate, + create, }) .await?; @@ -33,8 +34,8 @@ impl MPCDb { tracing::info!(participant=?i, "Connecting to left participant db"); let db = Db::new(&DbConfig { url, - migrate: false, - create: false, + migrate, + create, }) .await?; left_participant_dbs.push(db); @@ -43,8 +44,8 @@ impl MPCDb { tracing::info!("Connecting to right coordinator db"); let right_coordinator_db = Db::new(&DbConfig { url: right_coordinator_db_url, - migrate: false, - create: false, + migrate, + create, }) .await?; @@ -53,10 +54,11 @@ impl MPCDb { tracing::info!(participant=?i, "Connecting to right participant db"); let db = Db::new(&DbConfig { url, - migrate: false, - create: false, + migrate, + create, }) .await?; + right_participant_dbs.push(db); } @@ -68,92 +70,37 @@ impl MPCDb { }) } - #[tracing::instrument(skip(self))] - pub async fn insert_shares_and_masks( - &self, - left_data: Vec<(u64, Bits, Box<[EncodedBits]>)>, - right_data: Vec<(u64, Bits, Box<[EncodedBits]>)>, - ) -> eyre::Result<()> { - //TODO: logging for progress - - let (left_masks, left_shares): ( - Vec<(u64, Bits)>, - Vec>, - ) = left_data - .into_iter() - .map(|(id, mask, shares)| { - let shares: Vec<(u64, EncodedBits)> = - shares.into_iter().map(|share| (id, *share)).collect(); - - ((id, mask), shares) - }) - .unzip(); - - let (right_masks, right_shares): ( - Vec<(u64, Bits)>, - Vec>, - ) = right_data - .into_iter() - .map(|(id, mask, shares)| { - let shares: Vec<(u64, EncodedBits)> = - shares.into_iter().map(|share| (id, *share)).collect(); - - ((id, mask), shares) - }) - .unzip(); - - let coordinator_tasks = vec![ - self.left_coordinator_db.insert_masks(&left_masks), - self.right_coordinator_db.insert_masks(&right_masks), - ]; - - let participant_tasks = self - .left_participant_dbs - .iter() - .zip(left_shares.iter()) - .chain(self.right_participant_dbs.iter().zip(right_shares.iter())) - .map(|(db, shares)| db.insert_shares(shares)); - - for task in coordinator_tasks { - task.await?; + #[tracing::instrument(skip(self,))] + pub async fn prune_items(&self, serial_id: u64) -> eyre::Result<()> { + self.left_coordinator_db.prune_items(serial_id).await?; + self.right_coordinator_db.prune_items(serial_id).await?; + + for db in self.left_participant_dbs.iter() { + db.prune_items(serial_id).await?; } - for task in participant_tasks { - task.await?; + for db in self.right_participant_dbs.iter() { + db.prune_items(serial_id).await?; } + Ok(()) } #[tracing::instrument(skip(self,))] pub async fn fetch_latest_serial_id(&self) -> eyre::Result { - let mut ids = HashSet::new(); - - let left_coordinator_id = - self.left_coordinator_db.fetch_latest_mask_id().await?; - - tracing::info!(?left_coordinator_id, "Latest left mask Id"); + let mut ids = Vec::new(); - let right_coordinator_id = - self.right_coordinator_db.fetch_latest_share_id().await?; - - tracing::info!(?right_coordinator_id, "Latest right mask Id"); - - for (i, db) in self.left_participant_dbs.iter().enumerate() { - let id = db.fetch_latest_share_id().await?; - tracing::info!(?id, participant=?i, "Latest left share Id"); - ids.insert(id); - } + ids.push(self.left_coordinator_db.fetch_latest_mask_id().await?); + ids.push(self.right_coordinator_db.fetch_latest_mask_id().await?); - for (i, db) in self.right_participant_dbs.iter().enumerate() { - let id = db.fetch_latest_share_id().await?; - tracing::info!(?id, participant=?i, "Latest right share Id"); - ids.insert(id); + for db in self.left_participant_dbs.iter() { + ids.push(db.fetch_latest_share_id().await?); } - if ids.len() != 1 { - return Err(eyre::eyre!("Mismatched serial ids")); + for db in self.right_participant_dbs.iter() { + ids.push(db.fetch_latest_share_id().await?); } - Ok(left_coordinator_id) + ids.into_iter().min().context("No serial ids found") } } diff --git a/bin/utils/common.rs b/bin/utils/common.rs index 8254ec3..74eacfb 100644 --- a/bin/utils/common.rs +++ b/bin/utils/common.rs @@ -30,9 +30,8 @@ pub fn generate_templates(num_templates: usize) -> Vec