Skip to content

Commit

Permalink
Make tree build system paralellized (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stentonian authored Sep 19, 2023
2 parents 261a0c5 + 2952ed1 commit 7c8dc3d
Show file tree
Hide file tree
Showing 13 changed files with 2,395 additions and 772 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ curve25519-dalek-ng = "4.1.1"

thiserror = "1.0"
displaydoc = "0.2"
rayon = "1.7.0"

[dev-dependencies]
criterion = "0.4.0"
Expand Down
189 changes: 153 additions & 36 deletions src/accumulators/ndm_smt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@
//!
//! TODO more docs

use rand::{rngs::ThreadRng, distributions::Uniform, thread_rng, Rng};
use rand::{distributions::Uniform, rngs::ThreadRng, thread_rng, Rng};
use std::collections::HashMap;
use thiserror::Error;

use crate::binary_tree::{
Coordinate, InputLeafNode, PathError, SparseBinaryTree, SparseBinaryTreeError,
BinaryTree, Coordinate, InputLeafNode, PathError, TreeBuildError, TreeBuilder,
};
use crate::inclusion_proof::{InclusionProof, InclusionProofError, AggregationFactor};
use crate::inclusion_proof::{AggregationFactor, InclusionProof, InclusionProofError};
use crate::kdf::generate_key;
use crate::node_content::FullNodeContent;
use crate::primitives::D256;
use crate::user::{User, UserId};

use std::time::SystemTime;

use rayon::prelude::*;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;

// -------------------------------------------------------------------------------------------------
// NDM-SMT struct and methods

Expand All @@ -28,7 +35,7 @@ pub struct NdmSmt {
master_secret: D256,
salt_b: D256,
salt_s: D256,
tree: SparseBinaryTree<Content>,
tree: BinaryTree<Content>,
user_mapping: HashMap<UserId, u64>,
}

Expand All @@ -47,46 +54,148 @@ impl NdmSmt {
let salt_b_bytes = salt_b.as_bytes();
let salt_s_bytes = salt_s.as_bytes();

// needed for the thread-safe closure
let master_secret_bytes_clone = master_secret_bytes.clone();
let salt_b_bytes_clone = salt_b_bytes.clone();
let salt_s_bytes_clone = salt_s_bytes.clone();

// closure that is used to create new padding nodes
let new_padding_node_content = |coord: &Coordinate| {
let new_padding_node_content = move |coord: &Coordinate| {
// TODO unfortunately we copy data here, maybe there is a way to do without copying
let coord_bytes = coord.as_bytes();
// pad_secret_bytes is given as 'w' in the DAPOL+ paper
let pad_secret = generate_key(master_secret_bytes, &coord_bytes);
// pad_secret is given as 'w' in the DAPOL+ paper
let pad_secret = generate_key(&master_secret_bytes_clone, &coord_bytes);
let pad_secret_bytes: [u8; 32] = pad_secret.into();
let blinding_factor = generate_key(&pad_secret_bytes, salt_b_bytes);
let salt = generate_key(&pad_secret_bytes, salt_s_bytes);
let blinding_factor = generate_key(&pad_secret_bytes, &salt_b_bytes_clone);
let salt = generate_key(&pad_secret_bytes, &salt_s_bytes_clone);
Content::new_pad(blinding_factor.into(), coord, salt.into())
};

let mut x_coord_generator = RandomXCoordGenerator::new(height);
let mut leaves = Vec::with_capacity(users.len());
let mut user_mapping = HashMap::with_capacity(users.len());
let mut i = 0;

for user in users.into_iter() {
let x_coord = x_coord_generator.new_unique_x_coord(i as u64)?;
i = i + 1;

let w = generate_key(master_secret_bytes, &x_coord.to_le_bytes());
let w_bytes: [u8; 32] = w.into();
let blinding_factor = generate_key(&w_bytes, salt_b_bytes);
let user_salt = generate_key(&w_bytes, salt_s_bytes);

leaves.push(InputLeafNode {
content: Content::new_leaf(
user.liability,
blinding_factor.into(),
user.id.clone(),
user_salt.into(),
),
x_coord,
});

user_mapping.insert(user.id, x_coord);
let start = SystemTime::now();
println!(
" ndm start conversion of users to inputleafnode {:?}",
start
);

// [single] first create vec with x_coords
// join with users vec
// [multiple] then generate keys, map to leaf node and
// [single] add into map

let mut x_coords = Vec::<u64>::with_capacity(users.len());
for i in 0..users.len() {
x_coords.push(x_coord_generator.new_unique_x_coord(i as u64)?);
}

let tree = SparseBinaryTree::new(leaves, height, &new_padding_node_content)?;
let tuples = users
.into_iter()
.zip(x_coords.into_iter())
.collect::<Vec<(User, u64)>>();

let leaf_nodes = tuples
.par_iter()
// .into_par_iter()
.map(|(user, x_coord)| {
let w = generate_key(master_secret_bytes, &x_coord.to_le_bytes());
let w_bytes: [u8; 32] = w.into();
let blinding_factor = generate_key(&w_bytes, salt_b_bytes);
let user_salt = generate_key(&w_bytes, salt_s_bytes);

InputLeafNode {
content: Content::new_leaf(
user.liability,
blinding_factor.into(),
user.id.clone(),
user_salt.into(),
),
x_coord: *x_coord,
}
})
.collect::<Vec<InputLeafNode<Content>>>();

let user_mapping = Arc::new(Mutex::new(HashMap::new()));
let user_mapping_ref = Arc::clone(&user_mapping);
let handle = thread::spawn(move || {
let mut my_user_mapping = user_mapping_ref.lock().unwrap();
tuples.into_iter().for_each(|(user, x_coord)| {
my_user_mapping.insert(user.id, x_coord);
});
});
// https://stackoverflow.com/questions/62613488/how-do-i-get-the-runtime-memory-size-of-an-object
use std::mem::size_of_val;
println!(
"The size of `input_leaf_nodes` is {}",
size_of_val(&*leaf_nodes)
);

// let mut leaves = Vec::with_capacity(users.len());
// for user in users.into_iter() {
// let x_coord = x_coord_generator.new_unique_x_coord(i as u64)?;
// i = i + 1;

// let w = generate_key(master_secret_bytes, &x_coord.to_le_bytes());
// let w_bytes: [u8; 32] = w.into();
// let blinding_factor = generate_key(&w_bytes, salt_b_bytes);
// let user_salt = generate_key(&w_bytes, salt_s_bytes);

// leaves.push(InputLeafNode {
// content: Content::new_leaf(
// user.liability,
// blinding_factor.into(),
// user.id.clone(),
// user_salt.into(),
// ),
// x_coord,
// });

// user_mapping.insert(user.id, x_coord);
// }

let end = SystemTime::now();
let dur = end.duration_since(start);
println!(" end {:?}", end);
println!(" duration {:?}", dur);

// println!("leaves len {}", leaves.len());
println!("leaves len {}", leaf_nodes.len());

let start = SystemTime::now();
println!(" ndm start single threaded build {:?}", start);

let tree_2 = TreeBuilder::new()
.with_height(height)
.with_leaf_nodes(leaf_nodes.clone())
.with_single_threaded_build_algorithm()
.with_padding_node_generator(new_padding_node_content)
.build()?;

let end = SystemTime::now();
let dur = end.duration_since(start);
println!(" end {:?}", end);
println!(" duration {:?}", dur);

let start = SystemTime::now();
println!(" ndm start multi threaded build {:?}", start);

let tree = TreeBuilder::new()
.with_height(height)
.with_leaf_nodes(leaf_nodes)
.with_multi_threaded_build_algorithm()
.with_padding_node_generator(new_padding_node_content)
.build()?;

let end = SystemTime::now();
let dur = end.duration_since(start);
println!(" end {:?}", end);
println!(" duration {:?}", dur);

handle.join().unwrap();
let lock = Arc::try_unwrap(user_mapping).expect("Lock still has multiple owners");
let user_mapping = lock.into_inner().expect("Mutex cannot be locked");

assert_eq!(tree.get_root(), tree_2.get_root());

Ok(NdmSmt {
tree,
Expand Down Expand Up @@ -125,7 +234,11 @@ impl NdmSmt {

let path = self.tree.build_path_for(*leaf_x_coord)?;

Ok(InclusionProof::generate(path, aggregation_factor, upper_bound_bit_length)?)
Ok(InclusionProof::generate(
path,
aggregation_factor,
upper_bound_bit_length,
)?)
}

/// Generate an inclusion proof for the given user_id.
Expand All @@ -139,14 +252,18 @@ impl NdmSmt {
) -> Result<InclusionProof<Hash>, NdmSmtError> {
let aggregation_factor = AggregationFactor::Divisor(2u8);
let upper_bound_bit_length = 64u8;
self.generate_inclusion_proof_with_custom_range_proof_params(user_id, aggregation_factor, upper_bound_bit_length)
self.generate_inclusion_proof_with_custom_range_proof_params(
user_id,
aggregation_factor,
upper_bound_bit_length,
)
}
}

#[derive(Error, Debug)]
pub enum NdmSmtError {
#[error("Problem constructing the tree")]
TreeError(#[from] SparseBinaryTreeError),
TreeError(#[from] TreeBuildError),
#[error("Number of users cannot be bigger than 2^height")]
HeightTooSmall(#[from] OutOfBoundsError),
#[error("Inclusion proof generation failed when trying to build the path in the tree")]
Expand Down
Loading

0 comments on commit 7c8dc3d

Please sign in to comment.