Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat(seed-mpc-db): Upgrade seed db performance, improve progress visibility #69

Merged
merged 27 commits into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
21fc1f4
updated logging
0xKitsune Feb 28, 2024
32f4031
insert all shares/masks concurrently, update logs
0xKitsune Feb 29, 2024
deffeb1
encode shares in parallel
0xKitsune Feb 29, 2024
ebc265d
template creation parallel
0xKitsune Feb 29, 2024
2e8a6b6
refactored logic
0xKitsune Feb 29, 2024
7555424
separated coordinator and participant insertion
0xKitsune Feb 29, 2024
9ead0dc
added time elapsed
0xKitsune Feb 29, 2024
8cb87ff
added time elapsed
0xKitsune Feb 29, 2024
34f42c2
updated insertion logic
0xKitsune Feb 29, 2024
ea42212
added print log while waiting for sync
0xKitsune Feb 29, 2024
477fd44
adjusted logs
0xKitsune Feb 29, 2024
f14ddb6
updated progress bars
0xKitsune Feb 29, 2024
0b6f3ea
clippy
0xKitsune Feb 29, 2024
daf9212
removed unused imports
0xKitsune Feb 29, 2024
c727e93
Merge pull request #68 from worldcoin/0xkitsune/seed-mpc-db-parallel
0xKitsune Feb 29, 2024
d5137e1
updated progress bars
0xKitsune Feb 29, 2024
e1bb31b
removed unused var
0xKitsune Feb 29, 2024
6c014d7
fmt
0xKitsune Feb 29, 2024
23a1ffb
adjust for 1 based indexing
0xKitsune Feb 29, 2024
b439809
sequentially insert masks and shares, update progress bar
0xKitsune Feb 29, 2024
f62a8db
updated insert starting from latest serial id
0xKitsune Mar 1, 2024
f790819
updated default vals
0xKitsune Mar 1, 2024
4f23f20
merge main
0xKitsune Mar 1, 2024
bd9feab
Merge branch 'main' into 0xkitsune/seed-mpc-db
0xKitsune Mar 1, 2024
308b981
clippy
0xKitsune Mar 1, 2024
7ac4887
cargo fmt
0xKitsune Mar 1, 2024
8c8739c
uncomment e2e
0xKitsune Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 206 additions & 58 deletions bin/utils/seed_mpc_db.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use std::sync::Arc;

use clap::Args;
use indicatif::ProgressBar;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mpc::bits::Bits;
use mpc::config::DbConfig;
use mpc::db::Db;
use mpc::distance::EncodedBits;
use mpc::rng_source::RngSource;
use mpc::template::Template;
use rand::Rng;
use rand::{thread_rng, Rng};
use rayon::iter::{IntoParallelIterator, ParallelIterator};

#[derive(Debug, Clone, Args)]
pub struct SeedMPCDb {
Expand All @@ -14,10 +21,10 @@ pub struct SeedMPCDb {
#[clap(short, long)]
pub participant_db_url: Vec<String>,

#[clap(short, long, default_value = "3000000")]
#[clap(short, long, default_value = "100000")]
pub num_templates: usize,

#[clap(short, long, default_value = "10000")]
#[clap(short, long, default_value = "100")]
pub batch_size: usize,

#[clap(short, long, env, default_value = "thread")]
Expand All @@ -29,96 +36,237 @@ pub async fn seed_mpc_db(args: &SeedMPCDb) -> eyre::Result<()> {
return Err(eyre::eyre!("No participant DBs provided"));
}

let mut templates: Vec<Template> = Vec::with_capacity(args.num_templates);

tracing::info!("Generating templates");
let pb = ProgressBar::new(args.num_templates as u64)
.with_message("Generating templates");
let (coordinator_db, participant_dbs) = initialize_dbs(args).await?;

let mut rng = args.rng.to_rng();
let now = std::time::Instant::now();

for _ in 0..args.num_templates {
templates.push(rng.gen());
let latest_serial_id =
get_latest_serial_id(coordinator_db.clone(), participant_dbs.clone())
.await?;

pb.inc(1);
}
let templates = generate_templates(args);
println!("Templates generated in {:?}", now.elapsed());

pb.finish_with_message("done");
let (batched_masks, batched_shares) = generate_shares_and_masks(
args,
templates,
(latest_serial_id + 1) as usize,
);
println!("Shares and masks generated in {:?}", now.elapsed());

let coordinator_db = Db::new(&DbConfig {
url: args.coordinator_db_url.clone(),
migrate: true,
create: true,
})
insert_masks_and_shares(
batched_masks,
batched_shares,
coordinator_db,
participant_dbs,
args.num_templates,
args.batch_size,
)
.await?;

println!("Time elapsed: {:?}", now.elapsed());

Ok(())
}

async fn initialize_dbs(
args: &SeedMPCDb,
) -> eyre::Result<(Arc<Db>, Vec<Arc<Db>>)> {
let coordinator_db = Arc::new(
Db::new(&DbConfig {
url: args.coordinator_db_url.clone(),
migrate: true,
create: true,
})
.await?,
);

let mut participant_dbs = vec![];

for db_config in args.participant_db_url.iter() {
participant_dbs.push(
participant_dbs.push(Arc::new(
Db::new(&DbConfig {
url: db_config.clone(),
migrate: true,
create: true,
})
.await?,
);
));
}

tracing::info!("Seeding databases");
let pb =
ProgressBar::new(args.num_templates as u64).with_message("Seeding DBs");
Ok((coordinator_db, participant_dbs))
}

for (idx, chunk) in templates.chunks(args.batch_size).enumerate() {
tracing::info!(
"Seeding chunk {}/{}",
idx + 1,
(templates.len() / args.batch_size) + 1
);
fn generate_templates(args: &SeedMPCDb) -> Vec<Template> {
let pb = ProgressBar::new(args.num_templates as u64)
.with_message("Generating templates...");

pb.set_style(ProgressStyle::default_bar()
.template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})")
.expect("Could not create progress bar"));

let mut chunk_masks = Vec::with_capacity(chunk.len());
let mut chunk_shares: Vec<_> = (0..participant_dbs.len())
// Generate templates
let templates = (0..args.num_templates)
.into_par_iter()
.map(|_| {
let mut rng = thread_rng();

let template = rng.gen();

pb.inc(1);
template
})
.collect::<Vec<Template>>();

pb.finish_with_message("Created templates");

templates
}

pub type BatchedShares = Vec<Vec<Vec<(u64, EncodedBits)>>>;
pub type BatchedMasks = Vec<Vec<(u64, Bits)>>;

fn generate_shares_and_masks(
args: &SeedMPCDb,
templates: Vec<Template>,
next_serial_id: usize,
) -> (BatchedMasks, BatchedShares) {
// Generate shares and masks
let mut batched_shares = vec![];
let mut batched_masks = vec![];

let num_participants = args.participant_db_url.len();

let pb = ProgressBar::new(args.num_templates as u64)
.with_message("Generating shares and masks...");

pb.set_style(ProgressStyle::default_bar()
.template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})")
.expect("Could not create progress bar"));

for (idx, chunk) in templates.chunks(args.batch_size).enumerate() {
let mut batch_masks = Vec::with_capacity(chunk.len());
let mut batch_shares: Vec<_> = (0..num_participants)
.map(|_| Vec::with_capacity(chunk.len()))
.collect();

tracing::info!("Encoding shares");
let pb = ProgressBar::new(chunk.len() as u64)
.with_message("Encoding shares");
for (offset, template) in chunk.iter().enumerate() {
let shares = mpc::distance::encode(template)
.share(participant_dbs.len(), &mut rng);
let shares_chunk = chunk
.into_par_iter()
.map(|template| {
let mut rng = thread_rng();

let id = offset + (idx * args.batch_size);
let shares = mpc::distance::encode(template)
.share(num_participants, &mut rng);

pb.inc(1);
shares
})
.collect::<Vec<Box<[EncodedBits]>>>();

for (offset, (shares, template)) in
shares_chunk.iter().zip(chunk).enumerate()
{
let id = offset + (idx * args.batch_size) + next_serial_id;

batch_masks.push((id as u64, template.mask));

chunk_masks.push((id as u64, template.mask));
for (idx, share) in shares.iter().enumerate() {
chunk_shares[idx].push((id as u64, *share));
batch_shares[idx].push((id as u64, *share));
}

pb.inc(1);
}

let mut tasks = vec![];
batched_shares.push(batch_shares);
batched_masks.push(batch_masks);
}

pb.finish_with_message("Created shares and masks");

for (idx, db) in participant_dbs.iter().enumerate() {
tracing::info!("Inserting shares into participant DB {idx}");
(batched_masks, batched_shares)
}

tasks.push(db.insert_shares(&chunk_shares[idx]));
}
async fn insert_masks_and_shares(
batched_masks: BatchedMasks,
batched_shares: BatchedShares,
coordinator_db: Arc<Db>,
participant_dbs: Vec<Arc<Db>>,
num_templates: usize,
batch_size: usize,
) -> eyre::Result<()> {
println!("Inserting masks and shares into db...");

tracing::info!("Inserting masks into coordinator DB");
let (coordinator, participants) = tokio::join!(
coordinator_db.insert_masks(&chunk_masks),
futures::future::join_all(tasks),
);
// Commit shares and masks to db
let mpb = MultiProgress::new();
let style = ProgressStyle::default_bar().template("{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.green}] {pos:>7}/{len:7} ({eta})")?;

coordinator?;
participants.into_iter().collect::<Result<_, _>>()?;
let participant_progress_bars = participant_dbs
.iter()
.enumerate()
.map(|(i, _)| {
mpb.add(ProgressBar::new(num_templates as u64).with_message(
format!("Inserting shares for participant {}", i),
))
})
.collect::<Vec<_>>();

pb.inc(args.batch_size as u64);
for pb in participant_progress_bars.iter() {
pb.set_style(style.clone());
}

pb.finish_with_message("done");
let coordinator_progress_bar = mpb.add(
ProgressBar::new(num_templates as u64)
.with_message("Inserting masks for coordinator"),
);
coordinator_progress_bar.set_style(style.clone());

for (shares, masks) in
batched_shares.into_iter().zip(batched_masks.into_iter())
{
let mut tasks = FuturesUnordered::new();

for (idx, db) in participant_dbs.iter().enumerate() {
let shares = shares[idx].clone();
let db = db.clone();
let pb = participant_progress_bars[idx].clone();

tasks.push(tokio::spawn(async move {
db.insert_shares(&shares).await?;
pb.inc(batch_size as u64);
eyre::Result::<()>::Ok(())
}));
}

let coordinator_db = coordinator_db.clone();
let coordinator_progress_bar = coordinator_progress_bar.clone();
tasks.push(tokio::spawn(async move {
coordinator_db.insert_masks(&masks).await?;
coordinator_progress_bar.inc(batch_size as u64);
eyre::Result::<()>::Ok(())
}));

while let Some(result) = tasks.next().await {
result??;
}
}

Ok(())
}

/// Fetches the latest mask serial id from the coordinator DB and verifies
/// that every participant DB has shares up to the same serial id.
///
/// Returns the shared latest serial id, or an error if any participant DB
/// is out of sync with the coordinator. Using an error (rather than a
/// panic) lets the caller surface the mismatch through the normal
/// `eyre::Result` path of this utility.
async fn get_latest_serial_id(
    coordinator_db: Arc<Db>,
    participant_dbs: Vec<Arc<Db>>,
) -> eyre::Result<u64> {
    let coordinator_serial_id = coordinator_db.fetch_latest_mask_id().await?;
    println!("Coordinator serial id: {}", coordinator_serial_id);

    for (i, db) in participant_dbs.iter().enumerate() {
        let participant_id = db.fetch_latest_share_id().await?;

        println!("Participant {} ids: {}", i, participant_id);

        // Propagate a recoverable error instead of panicking via assert_eq! —
        // this function already returns eyre::Result.
        eyre::ensure!(
            coordinator_serial_id == participant_id,
            "Databases are not in sync"
        );
    }

    Ok(coordinator_serial_id)
}
3 changes: 0 additions & 3 deletions bin/utils/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use seed_iris_db::{seed_iris_db, SeedIrisDb};
use seed_mpc_db::{seed_mpc_db, SeedMPCDb};
use sqs_query::{sqs_query, SQSQuery};
use sqs_receive::{sqs_receive, SQSReceive};
use telemetry_batteries::tracing::stdout::StdoutBattery;

mod generate_mock_templates;
mod seed_iris_db;
Expand All @@ -27,8 +26,6 @@ enum Opt {
async fn main() -> eyre::Result<()> {
dotenv::dotenv().ok();

let _shutdown_tracing_provider = StdoutBattery::init();

let args = Opt::parse();

match args {
Expand Down
40 changes: 40 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,46 @@ impl Db {
Ok(filter_sequential_items(masks, 1 + id as i64))
}

#[tracing::instrument(skip(self))]
/// Returns the highest mask id currently stored, or 0 when the `masks`
/// table is empty (so seeding can start from serial id 1).
pub async fn fetch_latest_mask_id(&self) -> eyre::Result<u64> {
    // fetch_optional expresses "zero or one row" directly, avoiding the
    // manual match on sqlx::Error::RowNotFound that fetch_one requires.
    let mask_id: Option<(i64,)> = sqlx::query_as(
        r#"
        SELECT id
        FROM masks
        ORDER BY id DESC
        LIMIT 1
        "#,
    )
    .fetch_optional(&self.pool)
    .await?;

    Ok(mask_id.map(|(id,)| id as u64).unwrap_or(0))
}

#[tracing::instrument(skip(self))]
/// Returns the highest share id currently stored, or 0 when the `shares`
/// table is empty (so seeding can start from serial id 1).
pub async fn fetch_latest_share_id(&self) -> eyre::Result<u64> {
    // fetch_optional expresses "zero or one row" directly, avoiding the
    // manual match on sqlx::Error::RowNotFound that fetch_one requires.
    let share_id: Option<(i64,)> = sqlx::query_as(
        r#"
        SELECT id
        FROM shares
        ORDER BY id DESC
        LIMIT 1
        "#,
    )
    .fetch_optional(&self.pool)
    .await?;

    Ok(share_id.map(|(id,)| id as u64).unwrap_or(0))
}

#[tracing::instrument(skip(self))]
pub async fn insert_masks(
&self,
Expand Down
Loading