Skip to content

Commit

Permalink
Implementation of Parallel Merkle Tree (#125)
Browse files Browse the repository at this point in the history
Co-authored-by: f50033134 <francesco.intoci@huawei.com>
Co-authored-by: Giacomo Fenzi <giacomofenzi@outlook.com>
Co-authored-by: Pratyush Mishra <pratyushmishra@berkeley.edu>
  • Loading branch information
4 people committed Dec 29, 2023
1 parent 955b333 commit 8bbb482
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Cargo.lock
params
*.swp
*.swo

.vscode
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ path = "benches/signature.rs"
harness = false
required-features = [ "signature" ]

[[bench]]
name = "merkle_tree"
path = "benches/merkle_tree.rs"
harness = false
required-features = [ "merkle_tree" ]

[patch.crates-io]
ark-r1cs-std = { git = "https://github.com/arkworks-rs/r1cs-std/" }
ark-ff = { git = "https://github.com/arkworks-rs/algebra/" }
Expand Down
66 changes: 66 additions & 0 deletions benches/merkle_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#[macro_use]
extern crate criterion;

static NUM_LEAVES: i32 = 1 << 20;

mod bytes_mt_benches {
use ark_crypto_primitives::crh::*;
use ark_crypto_primitives::merkle_tree::*;
use ark_crypto_primitives::to_uncompressed_bytes;
use ark_ff::BigInteger256;
use ark_serialize::CanonicalSerialize;
use ark_std::{test_rng, UniformRand};
use criterion::Criterion;
use std::borrow::Borrow;

use crate::NUM_LEAVES;

type LeafH = sha2::Sha256;
type CompressH = sha2::Sha256;

struct Sha256MerkleTreeParams;

impl Config for Sha256MerkleTreeParams {
type Leaf = [u8];

type LeafDigest = <LeafH as CRHScheme>::Output;
type LeafInnerDigestConverter = ByteDigestConverter<Self::LeafDigest>;
type InnerDigest = <CompressH as TwoToOneCRHScheme>::Output;

type LeafHash = LeafH;
type TwoToOneHash = CompressH;
}
type Sha256MerkleTree = MerkleTree<Sha256MerkleTreeParams>;

pub fn merkle_tree_create(c: &mut Criterion) {
let mut rng = test_rng();
let leaves: Vec<_> = (0..NUM_LEAVES)
.map(|_| {
let rnd = BigInteger256::rand(&mut rng);
to_uncompressed_bytes!(rnd).unwrap()
})
.collect();
let leaf_crh_params = <LeafH as CRHScheme>::setup(&mut rng).unwrap();
let two_to_one_params = <CompressH as TwoToOneCRHScheme>::setup(&mut rng)
.unwrap()
.clone();
c.bench_function("Merkle Tree Create (Leaves as [u8])", move |b| {
b.iter(|| {
Sha256MerkleTree::new(
&leaf_crh_params.clone(),
&two_to_one_params.clone(),
&leaves,
)
.unwrap();
})
});
}

criterion_group! {
name = mt_create;
config = Criterion::default().sample_size(10);
targets = merkle_tree_create
}
}

criterion_main!(crate::bytes_mt_benches::mt_create,);
4 changes: 2 additions & 2 deletions src/crh/bowe_hopwood/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use ark_ec::{
twisted_edwards::Projective as TEProjective, twisted_edwards::TECurveConfig, AdditiveGroup,
CurveGroup,
};
use ark_ff::{biginteger::BigInteger, fields::PrimeField};
use ark_ff::fields::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::borrow::Borrow;
use ark_std::cfg_chunks;
Expand Down Expand Up @@ -82,7 +82,7 @@ impl<P: TECurveConfig, W: pedersen::Window> CRHScheme for CRH<P, W> {
let mut c = 0;
let mut range = F::BigInt::from(2_u64);
while range < upper_limit {
range.muln(4);
range <<= 4;
c += 1;
}

Expand Down
6 changes: 3 additions & 3 deletions src/crh/injective_map/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{CryptoError, Error};
use crate::Error;
use ark_std::rand::Rng;
use ark_std::{fmt::Debug, hash::Hash, marker::PhantomData};

Expand All @@ -16,15 +16,15 @@ pub mod constraints;
pub trait InjectiveMap<C: CurveGroup> {
type Output: Clone + Eq + Hash + Default + Debug + CanonicalSerialize + CanonicalDeserialize;

fn injective_map(ge: &C::Affine) -> Result<Self::Output, CryptoError>;
fn injective_map(ge: &C::Affine) -> Result<Self::Output, Error>;
}

pub struct TECompressor;

impl<P: TECurveConfig> InjectiveMap<TEProjective<P>> for TECompressor {
type Output = <P as CurveConfig>::BaseField;

fn injective_map(ge: &TEAffine<P>) -> Result<Self::Output, CryptoError> {
fn injective_map(ge: &TEAffine<P>) -> Result<Self::Output, Error> {
debug_assert!(ge.is_in_correct_subgroup_assuming_on_curve());
Ok(ge.x)
}
Expand Down
2 changes: 1 addition & 1 deletion src/crh/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use constraints::*;
/// Interface to CRH. Note that in this release, while all implementations of `CRH` have fixed length,
/// variable length CRH may also implement this trait in future.
pub trait CRHScheme {
type Input: ?Sized;
type Input: ?Sized + Send;
type Output: Clone
+ Eq
+ core::fmt::Debug
Expand Down
27 changes: 17 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,29 @@ pub mod snark;
#[cfg(feature = "sponge")]
pub mod sponge;

pub type Error = Box<dyn ark_std::error::Error>;

#[derive(Debug)]
pub enum CryptoError {
pub enum Error {
IncorrectInputLength(usize),
NotPrimeOrder,
GenericError(Box<dyn ark_std::error::Error + Send>),
SerializationError(ark_serialize::SerializationError),
}

impl core::fmt::Display for CryptoError {
impl core::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let msg = match self {
CryptoError::IncorrectInputLength(len) => format!("input length is wrong: {}", len),
CryptoError::NotPrimeOrder => "element is not prime order".to_owned(),
};
write!(f, "{}", msg)
match self {
Self::IncorrectInputLength(len) => write!(f, "incorrect input length: {len}"),
Self::NotPrimeOrder => write!(f, "element is not prime order"),
Self::GenericError(e) => write!(f, "{e}"),
Self::SerializationError(e) => write!(f, "{e}"),
}
}
}

impl ark_std::error::Error for CryptoError {}
impl ark_std::error::Error for Error {}

impl From<ark_serialize::SerializationError> for Error {
fn from(e: ark_serialize::SerializationError) -> Self {
Self::SerializationError(e)
}
}
115 changes: 73 additions & 42 deletions src/merkle_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ mod tests;
#[cfg(feature = "r1cs")]
pub mod constraints;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Convert the hash digest in different layers by converting previous layer's output to
/// `TargetType`, which is a `Borrow` to next layer's input.
pub trait DigestConverter<From, To: ?Sized> {
Expand Down Expand Up @@ -52,15 +55,16 @@ impl<T: CanonicalSerialize> DigestConverter<T, [u8]> for ByteDigestConverter<T>
/// * `LeafHash`: Convert leaf to leaf digest
/// * `TwoToOneHash`: Compress two inner digests to one inner digest
pub trait Config {
type Leaf: ?Sized; // merkle tree does not store the leaf
// leaf layer
type Leaf: ?Sized + Send; // merkle tree does not store the leaf
// leaf layer
type LeafDigest: Clone
+ Eq
+ core::fmt::Debug
+ Hash
+ Default
+ CanonicalSerialize
+ CanonicalDeserialize;
+ CanonicalDeserialize
+ Send;
// transition between leaf layer to inner layer
type LeafInnerDigestConverter: DigestConverter<
Self::LeafDigest,
Expand All @@ -73,7 +77,8 @@ pub trait Config {
+ Hash
+ Default
+ CanonicalSerialize
+ CanonicalDeserialize;
+ CanonicalDeserialize
+ Send;

// Tom's Note: in the future, if we want different hash function, we can simply add more
// types of digest here and specify a digest converter. Same for constraints.
Expand Down Expand Up @@ -229,32 +234,30 @@ impl<P: Config> MerkleTree<P> {
height: usize,
) -> Result<Self, crate::Error> {
// use empty leaf digest
let leaves_digest = vec![P::LeafDigest::default(); 1 << (height - 1)];
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digest)
let leaf_digests = vec![P::LeafDigest::default(); 1 << (height - 1)];
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests)
}

/// Returns a new merkle tree. `leaves.len()` should be power of two.
pub fn new<L: Borrow<P::Leaf>>(
pub fn new<L: AsRef<P::Leaf> + Send>(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves: impl IntoIterator<Item = L>,
#[cfg(not(feature = "parallel"))] leaves: impl IntoIterator<Item = L>,
#[cfg(feature = "parallel")] leaves: impl IntoParallelIterator<Item = L>,
) -> Result<Self, crate::Error> {
let mut leaves_digests = Vec::new();

// compute and store hash values for each leaf
for leaf in leaves.into_iter() {
leaves_digests.push(P::LeafHash::evaluate(leaf_hash_param, leaf)?)
}
let leaf_digests: Vec<_> = cfg_into_iter!(leaves)
.map(|input| P::LeafHash::evaluate(leaf_hash_param, input.as_ref()))
.collect::<Result<Vec<_>, _>>()?;

Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digests)
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaf_digests)
}

pub fn new_with_leaf_digest(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves_digest: Vec<P::LeafDigest>,
leaf_digests: Vec<P::LeafDigest>,
) -> Result<Self, crate::Error> {
let leaf_nodes_size = leaves_digest.len();
let leaf_nodes_size = leaf_digests.len();
assert!(
leaf_nodes_size.is_power_of_two() && leaf_nodes_size > 1,
"`leaves.len() should be power of two and greater than one"
Expand All @@ -266,7 +269,7 @@ impl<P: Config> MerkleTree<P> {
let hash_of_empty: P::InnerDigest = P::InnerDigest::default();

// initialize the merkle tree as array of nodes in level order
let mut non_leaf_nodes: Vec<P::InnerDigest> = (0..non_leaf_nodes_size)
let mut non_leaf_nodes: Vec<P::InnerDigest> = cfg_into_iter!(0..non_leaf_nodes_size)
.map(|_| hash_of_empty.clone())
.collect();

Expand All @@ -282,39 +285,67 @@ impl<P: Config> MerkleTree<P> {
{
let start_index = level_indices.pop().unwrap();
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
// `left_child(current_index)` and `right_child(current_index) returns the position of
// leaf in the whole tree (represented as a list in level order). We need to shift it
// by `-upper_bound` to get the index in `leaf_nodes` list.
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;
// compute hash
non_leaf_nodes[current_index] = P::TwoToOneHash::evaluate(
&two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(leaves_digest[left_leaf_index].clone())?,
P::LeafInnerDigestConverter::convert(leaves_digest[right_leaf_index].clone())?,
)?
}

cfg_iter_mut!(non_leaf_nodes[start_index..upper_bound])
.enumerate()
.try_for_each(|(i, n)| {
// `left_child(current_index)` and `right_child(current_index) returns the position of
// leaf in the whole tree (represented as a list in level order). We need to shift it
// by `-upper_bound` to get the index in `leaf_nodes` list.

//similarly, we need to rescale i by start_index
//to get the index outside the slice and in the level-ordered list of nodes

let current_index = i + start_index;
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;

*n = P::TwoToOneHash::evaluate(
two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(
leaf_digests[left_leaf_index].clone(),
)?,
P::LeafInnerDigestConverter::convert(
leaf_digests[right_leaf_index].clone(),
)?,
)?;
Ok::<(), crate::Error>(())
})?;
}

// compute the hash values for nodes in every other layer in the tree
level_indices.reverse();
for &start_index in &level_indices {
// The layer beginning `start_index` ends at `upper_bound` (exclusive).
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
let left_index = left_child(current_index);
let right_index = right_child(current_index);
non_leaf_nodes[current_index] = P::TwoToOneHash::compress(
&two_to_one_hash_param,
non_leaf_nodes[left_index].clone(),
non_leaf_nodes[right_index].clone(),
)?
}
}

let (nodes_at_level, nodes_at_prev_level) =
non_leaf_nodes[..].split_at_mut(upper_bound);
// Iterate over the nodes at the current level, and compute the hash of each node
cfg_iter_mut!(nodes_at_level[start_index..])
.enumerate()
.try_for_each(|(i, n)| {
// `left_child(current_index)` and `right_child(current_index) returns the position of
// leaf in the whole tree (represented as a list in level order). We need to shift it
// by `-upper_bound` to get the index in `leaf_nodes` list.

//similarly, we need to rescale i by start_index
//to get the index outside the slice and in the level-ordered list of nodes
let current_index = i + start_index;
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;

// need for unwrap as Box<Error> does not implement trait Send
*n = P::TwoToOneHash::compress(
two_to_one_hash_param,
nodes_at_prev_level[left_leaf_index].clone(),
nodes_at_prev_level[right_leaf_index].clone(),
)?;
Ok::<_, crate::Error>(())
})?;
}
Ok(MerkleTree {
leaf_nodes: leaves_digest,
leaf_nodes: leaf_digests,
non_leaf_nodes,
height: tree_height,
leaf_hash_param: leaf_hash_param.clone(),
Expand Down
15 changes: 3 additions & 12 deletions src/merkle_tree/tests/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,8 @@ mod byte_mt_tests {

let leaf_crh_params = <LeafH as CRHScheme>::setup(&mut rng).unwrap();
let two_to_one_crh_params = <CompressH as TwoToOneCRHScheme>::setup(&mut rng).unwrap();
let mut tree = JubJubMerkleTree::new(
&leaf_crh_params,
&two_to_one_crh_params,
leaves.iter().map(|v| v.as_slice()),
)
.unwrap();
let mut tree =
JubJubMerkleTree::new(&leaf_crh_params, &two_to_one_crh_params, leaves).unwrap();
let root = tree.root();
for (i, leaf) in leaves.iter().enumerate() {
let cs = ConstraintSystem::<Fq>::new_ref();
Expand Down Expand Up @@ -288,12 +284,7 @@ mod field_mt_tests {
) {
let leaf_crh_params = poseidon_parameters();
let two_to_one_params = leaf_crh_params.clone();
let mut tree = FieldMT::new(
&leaf_crh_params,
&two_to_one_params,
leaves.iter().map(|x| x.as_slice()),
)
.unwrap();
let mut tree = FieldMT::new(&leaf_crh_params, &two_to_one_params, leaves).unwrap();
let root = tree.root();
for (i, leaf) in leaves.iter().enumerate() {
let cs = ConstraintSystem::<F>::new_ref();
Expand Down
Loading

0 comments on commit 8bbb482

Please sign in to comment.