From 8bbb4824ede5f2cf115f1a37e08410b57944a4f6 Mon Sep 17 00:00:00 2001 From: intx4 <65897068+intx4@users.noreply.github.com> Date: Fri, 29 Dec 2023 19:56:38 +0100 Subject: [PATCH] Implementation of Parallel Merkle Tree (#125) Co-authored-by: f50033134 Co-authored-by: Giacomo Fenzi Co-authored-by: Pratyush Mishra --- .gitignore | 2 +- Cargo.toml | 6 ++ benches/merkle_tree.rs | 66 +++++++++++++++ src/crh/bowe_hopwood/mod.rs | 4 +- src/crh/injective_map/mod.rs | 6 +- src/crh/mod.rs | 2 +- src/lib.rs | 27 ++++--- src/merkle_tree/mod.rs | 115 +++++++++++++++++---------- src/merkle_tree/tests/constraints.rs | 15 +--- src/merkle_tree/tests/mod.rs | 23 ++---- src/prf/blake2s/mod.rs | 4 +- src/prf/mod.rs | 4 +- 12 files changed, 184 insertions(+), 90 deletions(-) create mode 100644 benches/merkle_tree.rs diff --git a/.gitignore b/.gitignore index 448a8bb6..9d1f1106 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ Cargo.lock params *.swp *.swo - +.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 909b59a8..aa001515 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,6 +108,12 @@ path = "benches/signature.rs" harness = false required-features = [ "signature" ] +[[bench]] +name = "merkle_tree" +path = "benches/merkle_tree.rs" +harness = false +required-features = [ "merkle_tree" ] + [patch.crates-io] ark-r1cs-std = { git = "https://github.com/arkworks-rs/r1cs-std/" } ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } diff --git a/benches/merkle_tree.rs b/benches/merkle_tree.rs new file mode 100644 index 00000000..b32168e8 --- /dev/null +++ b/benches/merkle_tree.rs @@ -0,0 +1,66 @@ +#[macro_use] +extern crate criterion; + +static NUM_LEAVES: i32 = 1 << 20; + +mod bytes_mt_benches { + use ark_crypto_primitives::crh::*; + use ark_crypto_primitives::merkle_tree::*; + use ark_crypto_primitives::to_uncompressed_bytes; + use ark_ff::BigInteger256; + use ark_serialize::CanonicalSerialize; + use ark_std::{test_rng, UniformRand}; + use criterion::Criterion; + use std::borrow::Borrow; + + use crate::NUM_LEAVES; + + type LeafH = sha2::Sha256; + type CompressH = sha2::Sha256; + + struct Sha256MerkleTreeParams; + + impl Config for Sha256MerkleTreeParams { + type Leaf = [u8]; + + type LeafDigest = ::Output; + type LeafInnerDigestConverter = ByteDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; + } + type Sha256MerkleTree = MerkleTree; + + pub fn merkle_tree_create(c: &mut Criterion) { + let mut rng = test_rng(); + let leaves: Vec<_> = (0..NUM_LEAVES) + .map(|_| { + let rnd = BigInteger256::rand(&mut rng); + to_uncompressed_bytes!(rnd).unwrap() + }) + .collect(); + let leaf_crh_params = ::setup(&mut rng).unwrap(); + let two_to_one_params = ::setup(&mut rng) + .unwrap() + .clone(); + c.bench_function("Merkle Tree Create (Leaves as [u8])", move |b| { + b.iter(|| { + Sha256MerkleTree::new( + &leaf_crh_params.clone(), + &two_to_one_params.clone(), + &leaves, + ) + .unwrap(); + }) + }); + } + + criterion_group! { + name = mt_create; + config = Criterion::default().sample_size(10); + targets = merkle_tree_create + } +} + +criterion_main!(crate::bytes_mt_benches::mt_create,); diff --git a/src/crh/bowe_hopwood/mod.rs b/src/crh/bowe_hopwood/mod.rs index e74386b0..6dd758e5 100644 --- a/src/crh/bowe_hopwood/mod.rs +++ b/src/crh/bowe_hopwood/mod.rs @@ -17,7 +17,7 @@ use ark_ec::{ twisted_edwards::Projective as TEProjective, twisted_edwards::TECurveConfig, AdditiveGroup, CurveGroup, }; -use ark_ff::{biginteger::BigInteger, fields::PrimeField}; +use ark_ff::fields::PrimeField; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::borrow::Borrow; use ark_std::cfg_chunks; @@ -82,7 +82,7 @@ impl CRHScheme for CRH { let mut c = 0; let mut range = F::BigInt::from(2_u64); while range < upper_limit { - range.muln(4); + range <<= 4; c += 1; } diff --git a/src/crh/injective_map/mod.rs b/src/crh/injective_map/mod.rs index e138149d..fbd99fd1 100644 --- a/src/crh/injective_map/mod.rs +++ b/src/crh/injective_map/mod.rs @@ -1,4 +1,4 @@ -use crate::{CryptoError, Error}; +use crate::Error; use ark_std::rand::Rng; use ark_std::{fmt::Debug, hash::Hash, marker::PhantomData}; @@ -16,7 +16,7 @@ pub mod constraints; pub trait InjectiveMap { type Output: Clone + Eq + Hash + Default + Debug + CanonicalSerialize + CanonicalDeserialize; - fn injective_map(ge: &C::Affine) -> Result; + fn injective_map(ge: &C::Affine) -> Result; } pub struct TECompressor; @@ -24,7 +24,7 @@ pub struct TECompressor; impl InjectiveMap> for TECompressor { type Output =

::BaseField; - fn injective_map(ge: &TEAffine

) -> Result { + fn injective_map(ge: &TEAffine

) -> Result { debug_assert!(ge.is_in_correct_subgroup_assuming_on_curve()); Ok(ge.x) } diff --git a/src/crh/mod.rs b/src/crh/mod.rs index 4a6e5174..e66be16e 100644 --- a/src/crh/mod.rs +++ b/src/crh/mod.rs @@ -21,7 +21,7 @@ pub use constraints::*; /// Interface to CRH. Note that in this release, while all implementations of `CRH` have fixed length, /// variable length CRH may also implement this trait in future. pub trait CRHScheme { - type Input: ?Sized; + type Input: ?Sized + Send; type Output: Clone + Eq + core::fmt::Debug diff --git a/src/lib.rs b/src/lib.rs index c2083553..31ae920b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,22 +43,29 @@ pub mod snark; #[cfg(feature = "sponge")] pub mod sponge; -pub type Error = Box; - #[derive(Debug)] -pub enum CryptoError { +pub enum Error { IncorrectInputLength(usize), NotPrimeOrder, + GenericError(Box), + SerializationError(ark_serialize::SerializationError), } -impl core::fmt::Display for CryptoError { +impl core::fmt::Display for Error { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let msg = match self { - CryptoError::IncorrectInputLength(len) => format!("input length is wrong: {}", len), - CryptoError::NotPrimeOrder => "element is not prime order".to_owned(), - }; - write!(f, "{}", msg) + match self { + Self::IncorrectInputLength(len) => write!(f, "incorrect input length: {len}"), + Self::NotPrimeOrder => write!(f, "element is not prime order"), + Self::GenericError(e) => write!(f, "{e}"), + Self::SerializationError(e) => write!(f, "{e}"), + } } } -impl ark_std::error::Error for CryptoError {} +impl ark_std::error::Error for Error {} + +impl From for Error { + fn from(e: ark_serialize::SerializationError) -> Self { + Self::SerializationError(e) + } +} diff --git a/src/merkle_tree/mod.rs b/src/merkle_tree/mod.rs index d8655bbb..2b624903 100644 --- a/src/merkle_tree/mod.rs +++ b/src/merkle_tree/mod.rs @@ -14,6 +14,9 @@ mod tests; #[cfg(feature = "r1cs")] pub mod constraints; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + /// Convert the hash digest in different layers by converting previous layer's output to /// `TargetType`, which is a `Borrow` to next layer's input. pub trait DigestConverter { @@ -52,15 +55,16 @@ impl DigestConverter for ByteDigestConverter /// * `LeafHash`: Convert leaf to leaf digest /// * `TwoToOneHash`: Compress two inner digests to one inner digest pub trait Config { - type Leaf: ?Sized; // merkle tree does not store the leaf - // leaf layer + type Leaf: ?Sized + Send; // merkle tree does not store the leaf + // leaf layer type LeafDigest: Clone + Eq + core::fmt::Debug + Hash + Default + CanonicalSerialize - + CanonicalDeserialize; + + CanonicalDeserialize + + Send; // transition between leaf layer to inner layer type LeafInnerDigestConverter: DigestConverter< Self::LeafDigest, @@ -73,7 +77,8 @@ pub trait Config { + Hash + Default + CanonicalSerialize - + CanonicalDeserialize; + + CanonicalDeserialize + + Send; // Tom's Note: in the future, if we want different hash function, we can simply add more // types of digest here and specify a digest converter. Same for constraints. @@ -229,32 +234,30 @@ impl MerkleTree

{ height: usize, ) -> Result { // use empty leaf digest - let leaves_digest = vec![P::LeafDigest::default(); 1 << (height - 1)]; - Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digest) + let leaf_digests = vec![P::LeafDigest::default(); 1 << (height - 1)]; + Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests) } /// Returns a new merkle tree. `leaves.len()` should be power of two. - pub fn new>( + pub fn new + Send>( leaf_hash_param: &LeafParam

, two_to_one_hash_param: &TwoToOneParam

, - leaves: impl IntoIterator, + #[cfg(not(feature = "parallel"))] leaves: impl IntoIterator, + #[cfg(feature = "parallel")] leaves: impl IntoParallelIterator, ) -> Result { - let mut leaves_digests = Vec::new(); - - // compute and store hash values for each leaf - for leaf in leaves.into_iter() { - leaves_digests.push(P::LeafHash::evaluate(leaf_hash_param, leaf)?) - } + let leaf_digests: Vec<_> = cfg_into_iter!(leaves) + .map(|input| P::LeafHash::evaluate(leaf_hash_param, input.as_ref())) + .collect::, _>>()?; - Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digests) + Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests) } pub fn new_with_leaf_digest( leaf_hash_param: &LeafParam

, two_to_one_hash_param: &TwoToOneParam

, - leaves_digest: Vec, + leaf_digests: Vec, ) -> Result { - let leaf_nodes_size = leaves_digest.len(); + let leaf_nodes_size = leaf_digests.len(); assert!( leaf_nodes_size.is_power_of_two() && leaf_nodes_size > 1, "`leaves.len() should be power of two and greater than one" @@ -266,7 +269,7 @@ impl MerkleTree

{ let hash_of_empty: P::InnerDigest = P::InnerDigest::default(); // initialize the merkle tree as array of nodes in level order - let mut non_leaf_nodes: Vec = (0..non_leaf_nodes_size) + let mut non_leaf_nodes: Vec = cfg_into_iter!(0..non_leaf_nodes_size) .map(|_| hash_of_empty.clone()) .collect(); @@ -282,19 +285,32 @@ impl MerkleTree

{ { let start_index = level_indices.pop().unwrap(); let upper_bound = left_child(start_index); - for current_index in start_index..upper_bound { - // `left_child(current_index)` and `right_child(current_index) returns the position of - // leaf in the whole tree (represented as a list in level order). We need to shift it - // by `-upper_bound` to get the index in `leaf_nodes` list. - let left_leaf_index = left_child(current_index) - upper_bound; - let right_leaf_index = right_child(current_index) - upper_bound; - // compute hash - non_leaf_nodes[current_index] = P::TwoToOneHash::evaluate( - &two_to_one_hash_param, - P::LeafInnerDigestConverter::convert(leaves_digest[left_leaf_index].clone())?, - P::LeafInnerDigestConverter::convert(leaves_digest[right_leaf_index].clone())?, - )? - } + + cfg_iter_mut!(non_leaf_nodes[start_index..upper_bound]) + .enumerate() + .try_for_each(|(i, n)| { + // `left_child(current_index)` and `right_child(current_index) returns the position of + // leaf in the whole tree (represented as a list in level order). We need to shift it + // by `-upper_bound` to get the index in `leaf_nodes` list. + + //similarly, we need to rescale i by start_index + //to get the index outside the slice and in the level-ordered list of nodes + + let current_index = i + start_index; + let left_leaf_index = left_child(current_index) - upper_bound; + let right_leaf_index = right_child(current_index) - upper_bound; + + *n = P::TwoToOneHash::evaluate( + two_to_one_hash_param, + P::LeafInnerDigestConverter::convert( + leaf_digests[left_leaf_index].clone(), + )?, + P::LeafInnerDigestConverter::convert( + leaf_digests[right_leaf_index].clone(), + )?, + )?; + Ok::<(), crate::Error>(()) + })?; } // compute the hash values for nodes in every other layer in the tree @@ -302,19 +318,34 @@ impl MerkleTree

{ for &start_index in &level_indices { // The layer beginning `start_index` ends at `upper_bound` (exclusive). let upper_bound = left_child(start_index); - for current_index in start_index..upper_bound { - let left_index = left_child(current_index); - let right_index = right_child(current_index); - non_leaf_nodes[current_index] = P::TwoToOneHash::compress( - &two_to_one_hash_param, - non_leaf_nodes[left_index].clone(), - non_leaf_nodes[right_index].clone(), - )? - } - } + let (nodes_at_level, nodes_at_prev_level) = + non_leaf_nodes[..].split_at_mut(upper_bound); + // Iterate over the nodes at the current level, and compute the hash of each node + cfg_iter_mut!(nodes_at_level[start_index..]) + .enumerate() + .try_for_each(|(i, n)| { + // `left_child(current_index)` and `right_child(current_index) returns the position of + // leaf in the whole tree (represented as a list in level order). We need to shift it + // by `-upper_bound` to get the index in `leaf_nodes` list. + + //similarly, we need to rescale i by start_index + //to get the index outside the slice and in the level-ordered list of nodes + let current_index = i + start_index; + let left_leaf_index = left_child(current_index) - upper_bound; + let right_leaf_index = right_child(current_index) - upper_bound; + + // need for unwrap as Box does not implement trait Send + *n = P::TwoToOneHash::compress( + two_to_one_hash_param, + nodes_at_prev_level[left_leaf_index].clone(), + nodes_at_prev_level[right_leaf_index].clone(), + )?; + Ok::<_, crate::Error>(()) + })?; + } Ok(MerkleTree { - leaf_nodes: leaves_digest, + leaf_nodes: leaf_digests, non_leaf_nodes, height: tree_height, leaf_hash_param: leaf_hash_param.clone(), diff --git a/src/merkle_tree/tests/constraints.rs b/src/merkle_tree/tests/constraints.rs index fbe5217c..8f1602d7 100644 --- a/src/merkle_tree/tests/constraints.rs +++ b/src/merkle_tree/tests/constraints.rs @@ -61,12 +61,8 @@ mod byte_mt_tests { let leaf_crh_params = ::setup(&mut rng).unwrap(); let two_to_one_crh_params = ::setup(&mut rng).unwrap(); - let mut tree = JubJubMerkleTree::new( - &leaf_crh_params, - &two_to_one_crh_params, - leaves.iter().map(|v| v.as_slice()), - ) - .unwrap(); + let mut tree = + JubJubMerkleTree::new(&leaf_crh_params, &two_to_one_crh_params, leaves).unwrap(); let root = tree.root(); for (i, leaf) in leaves.iter().enumerate() { let cs = ConstraintSystem::::new_ref(); @@ -288,12 +284,7 @@ mod field_mt_tests { ) { let leaf_crh_params = poseidon_parameters(); let two_to_one_params = leaf_crh_params.clone(); - let mut tree = FieldMT::new( - &leaf_crh_params, - &two_to_one_params, - leaves.iter().map(|x| x.as_slice()), - ) - .unwrap(); + let mut tree = FieldMT::new(&leaf_crh_params, &two_to_one_params, leaves).unwrap(); let root = tree.root(); for (i, leaf) in leaves.iter().enumerate() { let cs = ConstraintSystem::::new_ref(); diff --git a/src/merkle_tree/tests/mod.rs b/src/merkle_tree/tests/mod.rs index d328b352..a4968917 100644 --- a/src/merkle_tree/tests/mod.rs +++ b/src/merkle_tree/tests/mod.rs @@ -39,20 +39,18 @@ mod bytes_mt_tests { /// Pedersen only takes bytes as leaf, so we use `ToBytes` trait. fn merkle_tree_test(leaves: &[L], update_query: &[(usize, L)]) -> () { let mut rng = ark_std::test_rng(); + let mut leaves: Vec<_> = leaves .iter() .map(|leaf| crate::to_uncompressed_bytes!(leaf).unwrap()) .collect(); + let leaf_crh_params = ::setup(&mut rng).unwrap(); - let two_to_one_params = ::setup(&mut rng) - .unwrap() - .clone(); - let mut tree = JubJubMerkleTree::new( - &leaf_crh_params.clone(), - &two_to_one_params.clone(), - leaves.iter().map(|x| x.as_slice()), - ) - .unwrap(); + let two_to_one_params = ::setup(&mut rng).unwrap(); + + let mut tree = + JubJubMerkleTree::new(&leaf_crh_params, &two_to_one_params, &leaves).unwrap(); + let mut root = tree.root(); // test merkle tree functionality without update for (i, leaf) in leaves.iter().enumerate() { @@ -145,12 +143,7 @@ mod field_mt_tests { let leaf_crh_params = poseidon_parameters(); let two_to_one_params = leaf_crh_params.clone(); - let mut tree = FieldMT::new( - &leaf_crh_params, - &two_to_one_params, - leaves.iter().map(|x| x.as_slice()), - ) - .unwrap(); + let mut tree = FieldMT::new(&leaf_crh_params, &two_to_one_params, &leaves).unwrap(); let mut root = tree.root(); diff --git a/src/prf/blake2s/mod.rs b/src/prf/blake2s/mod.rs index 6add07d9..7455e18a 100644 --- a/src/prf/blake2s/mod.rs +++ b/src/prf/blake2s/mod.rs @@ -3,7 +3,7 @@ use blake2::{Blake2s256 as B2s, Blake2sMac}; use digest::Digest; use super::PRF; -use crate::CryptoError; +use crate::Error; #[cfg(feature = "r1cs")] pub mod constraints; @@ -16,7 +16,7 @@ impl PRF for Blake2s { type Output = [u8; 32]; type Seed = [u8; 32]; - fn evaluate(seed: &Self::Seed, input: &Self::Input) -> Result { + fn evaluate(seed: &Self::Seed, input: &Self::Input) -> Result { let eval_time = start_timer!(|| "Blake2s::Eval"); let mut h = B2s::new(); h.update(seed.as_ref()); diff --git a/src/prf/mod.rs b/src/prf/mod.rs index d8f870cf..fa3da3a4 100644 --- a/src/prf/mod.rs +++ b/src/prf/mod.rs @@ -2,7 +2,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use core::{fmt::Debug, hash::Hash}; -use crate::CryptoError; +use crate::Error; #[cfg(feature = "r1cs")] pub mod constraints; @@ -17,5 +17,5 @@ pub trait PRF { type Output: CanonicalSerialize + Eq + Clone + Debug + Default + Hash; type Seed: CanonicalDeserialize + CanonicalSerialize + Clone + Default + Debug; - fn evaluate(seed: &Self::Seed, input: &Self::Input) -> Result; + fn evaluate(seed: &Self::Seed, input: &Self::Input) -> Result; }