Handle final results + checkpoints
Dzejkop committed Mar 7, 2024
1 parent b7e04a4 commit 054c916
Showing 3 changed files with 106 additions and 57 deletions.
140 changes: 87 additions & 53 deletions bin/migrate-codes/migrate.rs
@@ -1,12 +1,15 @@
#![allow(clippy::type_complexity)]

use clap::Parser;
use eyre::ContextCompat;
use futures::{pin_mut, Stream, StreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use mpc::bits::Bits;
use mpc::db::Db;
use mpc::distance::EncodedBits;
use mpc::iris_db::{IrisCodeEntry, IrisDb};
use mpc::iris_db::{
FinalResult, IrisCodeEntry, IrisDb, SideResult, FINAL_RESULT_STATUS,
};
use mpc::template::Template;
use rand::thread_rng;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
@@ -69,14 +72,32 @@ async fn main() -> eyre::Result<()> {
)
.await?;

// TODO: This is wrong, we need to run this value against the final result mapping
// in mongo
let iris_db = IrisDb::new(args.iris_code_db).await?;

let latest_serial_id = mpc_db.fetch_latest_serial_id().await?;
tracing::info!("Latest serial id {latest_serial_id}");

let iris_db = IrisDb::new(args.iris_code_db).await?;

let num_iris_codes = iris_db.count_whitelisted_iris_codes(latest_serial_id).await?;
// Cleanup items with larger ids
// as they might be assigned new values in the future
mpc_db.prune_items(latest_serial_id).await?;
iris_db.prune_final_results(latest_serial_id).await?;

let first_unsynced_iris_serial_id = if let Some(final_result) = iris_db
.get_final_result_by_serial_id(latest_serial_id)
.await?
{
iris_db
.get_entry_by_signup_id(&final_result.signup_id)
.await?
.context("Could not find iris code entry")?
.serial_id
} else {
0
};

let num_iris_codes = iris_db
.count_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?;
tracing::info!("Processing {} iris codes", num_iris_codes);

let pb =
@@ -87,17 +108,17 @@
pb.set_style(pb_style);

let iris_code_entries = iris_db
.stream_whitelisted_iris_codes(latest_serial_id)
.await?;
let iris_code_chunks = iris_code_entries.chunks(args.batch_size);
let iris_code_template_chunks = iris_code_chunks
.map(|chunk| chunk.into_iter().collect::<Result<Vec<_>, _>>())
.map(|chunk| Ok(extract_templates(chunk?)));
.stream_whitelisted_iris_codes(first_unsynced_iris_serial_id)
.await?
.chunks(args.batch_size)
.map(|chunk| chunk.into_iter().collect::<Result<Vec<_>, _>>());

handle_templates_stream(
iris_code_template_chunks,
iris_code_entries,
&iris_db,
&mpc_db,
num_participants as usize,
latest_serial_id,
&pb,
)
.await?;
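
The hunk above adds the checkpointing half of this change: both MPC stores and the Mongo results collection are pruned back to the latest MPC serial id, and the final result at that id is mapped back (via its signup id) to the last iris code entry that was synced, whose serial id becomes the resume point for the stream. The Mongo-side helpers it relies on (`prune_final_results`, `get_final_result_by_serial_id`) are not shown in this diff; the sketch below is one plausible shape for them using the `mongodb` crate, with free functions over a `mongodb::Database` and untyped `Document` results standing in for the real `IrisDb` methods — an assumption for illustration, not the crate's actual implementation.

```rust
use mongodb::bson::{doc, Document};
use mongodb::{Collection, Database};

// Mirrors the constant in src/iris_db.rs after this commit.
const FINAL_RESULT_COLLECTION_NAME: &str = "mpc.results";

/// Fetch the final result recorded for a given MPC serial id, if any.
async fn final_result_by_serial_id(
    db: &Database,
    serial_id: u64,
) -> mongodb::error::Result<Option<Document>> {
    let coll: Collection<Document> = db.collection(FINAL_RESULT_COLLECTION_NAME);
    // BSON integers are signed, hence the cast.
    coll.find_one(doc! { "serial_id": serial_id as i64 }, None)
        .await
}

/// Drop final results above the checkpoint so their serial ids can be
/// reassigned on the next run.
async fn prune_final_results(
    db: &Database,
    serial_id: u64,
) -> mongodb::error::Result<u64> {
    let coll: Collection<Document> = db.collection(FINAL_RESULT_COLLECTION_NAME);
    let res = coll
        .delete_many(doc! { "serial_id": { "$gt": serial_id as i64 } }, None)
        .await?;
    Ok(res.deleted_count)
}
```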
@@ -131,69 +152,82 @@ pub fn encode_shares(
Ok(iris_data)
}

pub fn extract_templates(
iris_code_entries: Vec<IrisCodeEntry>,
) -> (Vec<(usize, Template)>, Vec<(usize, Template)>) {
let (left_templates, right_templates) = iris_code_entries
.into_iter()
.map(|entry| {
(
(
entry.serial_id as usize,
Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
},
),
(
entry.serial_id as usize,
Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
},
),
)
})
.unzip();

(left_templates, right_templates)
}

async fn handle_templates_stream(
template_chunks: impl Stream<
Item = mongodb::error::Result<(
Vec<(usize, Template)>,
Vec<(usize, Template)>,
)>,
iris_code_entries: impl Stream<
Item = mongodb::error::Result<Vec<IrisCodeEntry>>,
>,
iris_db: &IrisDb,
mpc_db: &MPCDb,
num_participants: usize,
mut latest_serial_id: u64,
pb: &ProgressBar,
) -> eyre::Result<()> {
pin_mut!(template_chunks);
pin_mut!(iris_code_entries);

// Consume the stream
while let Some(template_chunk) = template_chunks.next().await {
let (left_templates, right_templates) = template_chunk?;
while let Some(entries) = iris_code_entries.next().await {
let entries = entries?;

let count = left_templates.len() as u64;
let count = entries.len() as u64;

let left_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_left,
mask: entry.mask_code_left,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

let right_data: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| {
let template = Template {
code: entry.iris_code_right,
mask: entry.mask_code_right,
};

(latest_serial_id as usize + 1 + idx, template)
})
.collect();

let left = handle_side_data_chunk(
left_templates,
left_data,
num_participants,
&mpc_db.left_coordinator_db,
&mpc_db.left_participant_dbs,
);

let right = handle_side_data_chunk(
right_templates,
right_data,
num_participants,
&mpc_db.right_coordinator_db,
&mpc_db.right_participant_dbs,
);

futures::try_join!(left, right)?;
let final_results: Vec<_> = entries
.iter()
.enumerate()
.map(|(idx, entry)| FinalResult {
status: FINAL_RESULT_STATUS.to_string(),
serial_id: latest_serial_id + 1 + idx as u64,
signup_id: entry.signup_id.clone(),
unique: true,
right_result: SideResult {},
left_result: SideResult {},
})
.collect();

let results = iris_db.save_final_results(&final_results);

futures::try_join!(left, right, results)?;

latest_serial_id += count;
pb.inc(count);
}

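In the loop above, each batch of entries is assigned MPC serial ids immediately following the running `latest_serial_id`, the left and right side chunks plus the corresponding `FinalResult` documents are written concurrently via `try_join!`, and only after all three succeed does the counter (and the progress bar) advance. A tiny standalone illustration of that id bookkeeping, with hypothetical batch sizes:

```rust
/// For each chunk, return the inclusive serial id range it would be assigned,
/// advancing the checkpoint by the chunk length afterwards — the same
/// arithmetic as `latest_serial_id + 1 + idx` followed by
/// `latest_serial_id += count` in the migration loop.
fn assign_ids(latest_serial_id: u64, chunk_sizes: &[u64]) -> Vec<(u64, u64)> {
    let mut latest = latest_serial_id;
    let mut ranges = Vec::new();
    for &count in chunk_sizes {
        ranges.push((latest + 1, latest + count));
        latest += count;
    }
    ranges
}

fn main() {
    // Resuming from a checkpoint at serial id 100 with batches of 3, 3 and 2:
    assert_eq!(
        assign_ids(100, &[3, 3, 2]),
        vec![(101, 103), (104, 106), (107, 108)]
    );
    println!("id ranges check out");
}
```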
16 changes: 16 additions & 0 deletions bin/migrate-codes/mpc_db.rs
@@ -70,6 +70,22 @@ impl MPCDb {
})
}

#[tracing::instrument(skip(self,))]
pub async fn prune_items(&self, serial_id: u64) -> eyre::Result<()> {
self.left_coordinator_db.prune_items(serial_id).await?;
self.right_coordinator_db.prune_items(serial_id).await?;

for db in self.left_participant_dbs.iter() {
db.prune_items(serial_id).await?;
}

for db in self.right_participant_dbs.iter() {
db.prune_items(serial_id).await?;
}

Ok(())
}

#[tracing::instrument(skip(self,))]
pub async fn fetch_latest_serial_id(&self) -> eyre::Result<u64> {
let mut ids = Vec::new();
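`MPCDb::prune_items` above fans the same cutoff out to the coordinator and every participant database on both sides. The per-database `prune_items` it delegates to is not part of this diff; a minimal sketch of what such a call might look like, assuming a sqlx-backed Postgres store with a hypothetical `items` table keyed by serial id (the real schema lives in the mpc crate's `Db` and may differ):

```rust
use sqlx::PgPool;

// The table and column names below are assumptions for illustration only;
// the real schema is not shown in this commit.
async fn prune_items(pool: &PgPool, serial_id: u64) -> sqlx::Result<()> {
    // Delete everything above the checkpoint so those serial ids can be
    // safely reassigned on the next migration run.
    sqlx::query("DELETE FROM items WHERE id > $1")
        .bind(serial_id as i64)
        .execute(pool)
        .await?;

    Ok(())
}
```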
7 changes: 3 additions & 4 deletions src/iris_db.rs
@@ -8,7 +8,7 @@ use crate::bits::Bits;
pub const IRIS_CODE_BATCH_SIZE: u32 = 30_000;
pub const DATABASE_NAME: &str = "iris";
pub const COLLECTION_NAME: &str = "codes.v2";
pub const FINAL_RESULT_COLLECTION_NAME: &str = "iris.mpc.results";
pub const FINAL_RESULT_COLLECTION_NAME: &str = "mpc.results";
pub const FINAL_RESULT_STATUS: &str = "COMPLETED";

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -29,13 +29,12 @@ pub struct FinalResult {
pub status: String,

/// The MPC serial id associated with this signup
pub serial_id: Option<u64>,
pub serial_id: u64,

/// A unique signup id string
pub signup_id: String,

#[serde(skip_serializing_if = "Option::is_none")]
pub unique: Option<bool>,
pub unique: bool,

pub right_result: SideResult,

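After this commit a final result document always carries a concrete `serial_id` and `unique` flag (previously `Option<u64>` and an `Option<bool>` skipped when unset), and the collection constant drops the duplicated `iris.` prefix, since the collection already lives inside the `iris` database. A sketch of the resulting document shape, built with the `bson` `doc!` macro; all field values here are made up for illustration:

```rust
use mongodb::bson::{doc, Document};

fn main() {
    let example = doc! {
        "status": "COMPLETED",           // FINAL_RESULT_STATUS
        "serial_id": 42_i64,             // now required, previously Option<u64>
        "signup_id": "signup-abc123",    // hypothetical signup id
        "unique": true,                  // now required, previously Option<bool>
        "right_result": Document::new(), // SideResult is currently an empty struct
        "left_result": Document::new(),
    };
    println!("{example}");
}
```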
