Skip to content

Commit

Permalink
Dzejkop/migrate-codes-streaming-impl (#78)
Browse files Browse the repository at this point in the history
* WIP: Streaming

* Test compose file

* Progress bar + more logging

* Clippy + fmt + cleanup

* Simplify

* Update name

* Cleanup

* Minor refactor + add pruning functionality

* Iris DB methods

* Handle final results + checkpoints
  • Loading branch information
Dzejkop authored Mar 7, 2024
1 parent 833f25a commit f5d483a
Show file tree
Hide file tree
Showing 11 changed files with 537 additions and 239 deletions.
78 changes: 0 additions & 78 deletions bin/migrate-codes/iris_db.rs

This file was deleted.

205 changes: 157 additions & 48 deletions bin/migrate-codes/migrate.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
#![allow(clippy::type_complexity)]

use clap::Parser;
use iris_db::IrisCodeEntry;
use eyre::ContextCompat;
use futures::{pin_mut, Stream, StreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use mpc::bits::Bits;
use mpc::db::Db;
use mpc::distance::EncodedBits;
use mpc::iris_db::{
FinalResult, IrisCodeEntry, IrisDb, SideResult, FINAL_RESULT_STATUS,
};
use mpc::template::Template;
use rand::thread_rng;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use telemetry_batteries::tracing::stdout::StdoutBattery;

use crate::iris_db::IrisDb;
use crate::mpc_db::MPCDb;

mod iris_db;
mod mpc_db;

#[derive(Parser)]
Expand All @@ -32,12 +37,19 @@ pub struct Args {
#[clap(alias = "rp", long, env)]
pub right_participant_db: Vec<String>,

#[clap(long, env, default_value = "10000")]
pub batch_size: usize,
/// Batch size for encoding shares
#[clap(short, long, default_value = "100")]
batch_size: usize,

/// If set to true, no migration or creation of the database will occur on the Postgres side
#[clap(long)]
no_migrate_or_create: bool,
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
dotenv::dotenv().ok();

let _shutdown_tracing_provider = StdoutBattery::init();

let args = Args::parse();
Expand All @@ -56,34 +68,63 @@ async fn main() -> eyre::Result<()> {
args.left_participant_db,
args.right_coordinator_db,
args.right_participant_db,
args.no_migrate_or_create,
)
.await?;

let iris_db = IrisDb::new(args.iris_code_db).await?;

let iris_code_entries = iris_db.get_iris_code_snapshot().await?;
let (left_templates, right_templates) =
extract_templates(iris_code_entries);
let latest_serial_id = mpc_db.fetch_latest_serial_id().await?;
tracing::info!("Latest serial id {latest_serial_id}");

let mut next_serial_id = mpc_db.fetch_latest_serial_id().await? + 1;
// Cleanup items with larger ids
// as they might be assigned new values in the future
mpc_db.prune_items(latest_serial_id).await?;
iris_db.prune_final_results(latest_serial_id).await?;

let left_data = encode_shares(left_templates, num_participants as usize)?;
let right_data = encode_shares(right_templates, num_participants as usize)?;
let first_unsynced_iris_serial_id = if let Some(final_result) = iris_db
.get_final_result_by_serial_id(latest_serial_id)
.await?
{
iris_db
.get_entry_by_signup_id(&final_result.signup_id)
.await?
.context("Could not find iris code entry")?
.serial_id
} else {
0
};

insert_masks_and_shares(
&left_data,
&mpc_db.left_coordinator_db,
&mpc_db.left_participant_dbs,
)
.await?;
let num_iris_codes = iris_db
.count_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?;
tracing::info!("Processing {} iris codes", num_iris_codes);

insert_masks_and_shares(
&right_data,
&mpc_db.right_coordinator_db,
&mpc_db.right_participant_dbs,
let pb =
ProgressBar::new(num_iris_codes).with_message("Migrating iris codes");
let pb_style = ProgressStyle::default_bar()
.template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})")
.expect("Could not create progress bar");
pb.set_style(pb_style);

let iris_code_entries = iris_db
.stream_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?
.chunks(args.batch_size)
.map(|chunk| chunk.into_iter().collect::<Result<Vec<_>, _>>());

handle_templates_stream(
iris_code_entries,
&iris_db,
&mpc_db,
num_participants as usize,
latest_serial_id,
&pb,
)
.await?;

pb.finish();

Ok(())
}

Expand All @@ -95,7 +136,7 @@ pub struct MPCIrisData {
pub fn encode_shares(
template_data: Vec<(usize, Template)>,
num_participants: usize,
) -> eyre::Result<(Vec<(usize, Bits, Box<[EncodedBits]>)>)> {
) -> eyre::Result<Vec<(usize, Bits, Box<[EncodedBits]>)>> {
let iris_data = template_data
.into_par_iter()
.map(|(serial_id, template)| {
Expand All @@ -111,32 +152,100 @@ pub fn encode_shares(
Ok(iris_data)
}

pub fn extract_templates(
iris_code_snapshot: Vec<IrisCodeEntry>,
) -> (Vec<(usize, Template)>, Vec<(usize, Template)>) {
let (left_templates, right_templates) = iris_code_snapshot
.into_iter()
.map(|entry| {
(
(
entry.mpc_serial_id as usize,
Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
},
),
(
entry.mpc_serial_id as usize,
Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
},
),
)
})
.unzip();
/// Drains a stream of iris code batches and writes each batch to the MPC
/// databases and to the iris database's final results.
///
/// For every batch yielded by `iris_code_entries`:
/// - a left and a right `Template` is built per entry, each paired with the
///   serial id `latest_serial_id + 1 + idx` (idx = position in the batch);
/// - both sides are passed to `handle_side_data_chunk` together with the
///   matching coordinator / participant databases from `mpc_db`;
/// - one `FinalResult` per entry is built with the same serial id and the
///   entry's `signup_id`, and handed to `iris_db.save_final_results`;
/// - the three resulting futures are awaited together via `try_join!`.
///
/// After a batch succeeds, `latest_serial_id` is advanced by the batch length
/// and the progress bar `pb` is incremented by the same amount.
///
/// # Errors
///
/// Returns early with the error if the stream yields an `Err` batch, or if
/// any of the three per-batch futures fails.
async fn handle_templates_stream(
iris_code_entries: impl Stream<
Item = mongodb::error::Result<Vec<IrisCodeEntry>>,
>,
iris_db: &IrisDb,
mpc_db: &MPCDb,
num_participants: usize,
mut latest_serial_id: u64,
pb: &ProgressBar,
) -> eyre::Result<()> {
// Pin the stream in place so that `.next()` can be called on it below.
pin_mut!(iris_code_entries);

// Consume the stream
while let Some(entries) = iris_code_entries.next().await {
// A failed batch aborts the whole run.
let entries = entries?;

// Number of entries in this batch; used to advance the serial id and the
// progress bar once the batch has been written.
let count = entries.len() as u64;

// Left-eye templates, each tagged with the serial id it will be stored under.
let left_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

// Right-eye templates, using the same serial ids as the left side.
let right_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

// The futures below are only created here; they are awaited together in
// the `try_join!` further down.
let left = handle_side_data_chunk(
left_data,
num_participants,
&mpc_db.left_coordinator_db,
&mpc_db.left_participant_dbs,
);

let right = handle_side_data_chunk(
right_data,
num_participants,
&mpc_db.right_coordinator_db,
&mpc_db.right_participant_dbs,
);

// One final result per entry, with the same serial id as its templates so
// that the signup id can be mapped to the MPC serial id.
let final_results: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| FinalResult {
status: FINAL_RESULT_STATUS.to_string(),
serial_id: latest_serial_id + 1 + idx as u64,
signup_id: entry.signup_id.clone(),
unique: true,
right_result: SideResult {},
left_result: SideResult {},
})
.collect();

let results = iris_db.save_final_results(&final_results);

// Await all three writes for this batch; the first error is propagated.
futures::try_join!(left, right, results)?;

// Only advance once the whole batch has been written successfully.
latest_serial_id += count;
pb.inc(count);
}

// NOTE(review): the next line looks like a leftover from the removed
// `extract_templates` function in this diff view — confirm it is not part
// of the current file.
(left_templates, right_templates)
Ok(())
}

/// Encodes one side's batch of templates into shares and stores the result.
///
/// The templates are passed through `encode_shares` for `num_participants`
/// participants, and the encoded data is then handed to
/// `insert_masks_and_shares` together with the given coordinator and
/// participant databases. The function is side-agnostic: the caller decides
/// whether the batch and the databases belong to the left or the right side.
///
/// # Errors
///
/// Propagates any error returned by `encode_shares` or
/// `insert_masks_and_shares`.
async fn handle_side_data_chunk(
    templates: Vec<(usize, Template)>,
    num_participants: usize,
    coordinator_db: &Db,
    participant_dbs: &[Db],
) -> eyre::Result<()> {
    // Encoding happens first; nothing is written if it fails.
    let encoded_side_data = encode_shares(templates, num_participants)?;

    insert_masks_and_shares(
        &encoded_side_data,
        coordinator_db,
        participant_dbs,
    )
    .await?;

    Ok(())
}

async fn insert_masks_and_shares(
Expand All @@ -147,7 +256,7 @@ async fn insert_masks_and_shares(
// Insert masks
let left_masks: Vec<_> = data
.iter()
.map(|(serial_id, mask, _)| (*serial_id as u64, mask.clone()))
.map(|(serial_id, mask, _)| (*serial_id as u64, *mask))
.collect();

coordinator_db.insert_masks(&left_masks).await?;
Expand Down
Loading

0 comments on commit f5d483a

Please sign in to comment.