diff --git a/bin/migrate-codes/migrate.rs b/bin/migrate-codes/migrate.rs index 6d0eba3..f4bf7dc 100644 --- a/bin/migrate-codes/migrate.rs +++ b/bin/migrate-codes/migrate.rs @@ -1,9 +1,7 @@ -use std::time::Duration; - use clap::Parser; use iris_db::IrisCodeEntry; -use itertools::Itertools; use mpc::bits::Bits; +use mpc::db::Db; use mpc::distance::EncodedBits; use mpc::template::Template; use rand::thread_rng; @@ -18,16 +16,22 @@ mod mpc_db; #[derive(Parser)] pub struct Args { - #[clap(long, env)] + /// Connection string for the iris MongoDB + #[clap(alias = "ic", long, env)] pub iris_code_db: String, - #[clap(long, env)] + /// Connection string for the left coordinator Postgres DB + #[clap(alias = "lc", long, env)] pub left_coordinator_db: String, - #[clap(long, env)] + /// Connection strings for the left participant Postgres DBs + #[clap(alias = "lp", long, env)] pub left_participant_db: Vec, - #[clap(long, env)] + /// Connection strings for the right coordinator Postgres DBs + #[clap(alias = "rc", long, env)] pub right_coordinator_db: String, - #[clap(long, env)] + /// Connection string for the right participant Postgres DB + #[clap(alias = "rp", long, env)] pub right_participant_db: Vec, + #[clap(long, env, default_value = "10000")] pub batch_size: usize, } @@ -38,7 +42,9 @@ async fn main() -> eyre::Result<()> { let args = Args::parse(); - assert_eq!( + eyre::ensure!( + args.left_participant_db.len() == args.right_participant_db.len(), + "Number of participants on left & right must match (left: {}, right: {})", args.left_participant_db.len(), args.right_participant_db.len() ); @@ -64,7 +70,19 @@ async fn main() -> eyre::Result<()> { let left_data = encode_shares(left_templates, num_participants as usize)?; let right_data = encode_shares(right_templates, num_participants as usize)?; - //TODO: insert in chunks + insert_masks_and_shares( + &left_data, + &mpc_db.left_coordinator_db, + &mpc_db.left_participant_dbs, + ) + .await?; + + insert_masks_and_shares( + &right_data, + &mpc_db.right_coordinator_db, + &mpc_db.right_participant_dbs, + ) + .await?; Ok(()) } @@ -120,3 +138,29 @@ pub fn extract_templates( (left_templates, right_templates) } + +async fn insert_masks_and_shares( + data: &[(usize, Bits, Box<[EncodedBits]>)], + coordinator_db: &Db, + participant_dbs: &[Db], +) -> eyre::Result<()> { + // Insert masks + let left_masks: Vec<_> = data + .iter() + .map(|(serial_id, mask, _)| (*serial_id as u64, mask.clone())) + .collect(); + + coordinator_db.insert_masks(&left_masks).await?; + + // Insert shares to each participant + for (i, participant_db) in participant_dbs.iter().enumerate() { + let shares: Vec<_> = data + .iter() + .map(|(serial_id, _, shares)| (*serial_id as u64, shares[i])) + .collect(); + + participant_db.insert_shares(&shares).await?; + } + + Ok(()) +}