diff --git a/target_chains/starknet/contracts/src/lib.cairo b/target_chains/starknet/contracts/src/lib.cairo index 73ae1b38b5..c43ba79dd8 100644 --- a/target_chains/starknet/contracts/src/lib.cairo +++ b/target_chains/starknet/contracts/src/lib.cairo @@ -3,3 +3,4 @@ pub mod wormhole; pub mod reader; pub mod hash; pub mod util; +pub mod merkle_tree; diff --git a/target_chains/starknet/contracts/src/merkle_tree.cairo b/target_chains/starknet/contracts/src/merkle_tree.cairo new file mode 100644 index 0000000000..12633f75ab --- /dev/null +++ b/target_chains/starknet/contracts/src/merkle_tree.cairo @@ -0,0 +1,70 @@ +use super::hash::{Hasher, HasherImpl}; +use super::reader::{Reader, ReaderImpl, ByteArray}; +use super::util::ONE_SHIFT_96; +use core::cmp::{min, max}; + +const MERKLE_LEAF_PREFIX: u8 = 0; +const MERKLE_NODE_PREFIX: u8 = 1; +const MERKLE_EMPTY_LEAF_PREFIX: u8 = 2; + +#[derive(Copy, Drop, Debug, Serde, PartialEq)] +pub enum MerkleVerificationError { + Reader: super::reader::Error, + DigestMismatch, +} + +#[generate_trait] +impl ResultReaderToMerkleVerification of ResultReaderToMerkleVerificationTrait { + fn map_err(self: Result) -> Result { + match self { + Result::Ok(v) => Result::Ok(v), + Result::Err(err) => Result::Err(MerkleVerificationError::Reader(err)), + } + } +} + +fn leaf_hash(mut reader: Reader) -> Result { + let mut hasher = HasherImpl::new(); + hasher.push_u8(MERKLE_LEAF_PREFIX); + hasher.push_reader(ref reader)?; + let hash = hasher.finalize() / ONE_SHIFT_96; + Result::Ok(hash) +} + +fn node_hash(a: u256, b: u256) -> u256 { + let mut hasher = HasherImpl::new(); + hasher.push_u8(MERKLE_NODE_PREFIX); + hasher.push_u160(min(a, b)); + hasher.push_u160(max(a, b)); + hasher.finalize() / ONE_SHIFT_96 +} + +pub fn read_and_verify_proof( + root_digest: u256, message: @ByteArray, ref reader: Reader +) -> Result<(), MerkleVerificationError> { + let mut message_reader = ReaderImpl::new(message.clone()); + let mut current_hash = leaf_hash(message_reader.clone()).map_err()?; + + let proof_size = reader.read_u8().map_err()?; + let mut i = 0; + + let mut result = Result::Ok(()); + while i < proof_size { + match reader.read_u160().map_err() { + Result::Ok(sibling_digest) => { + current_hash = node_hash(current_hash, sibling_digest); + }, + Result::Err(err) => { + result = Result::Err(err); + break; + }, + } + i += 1; + }; + result?; + + if root_digest != current_hash { + return Result::Err(MerkleVerificationError::DigestMismatch); + } + Result::Ok(()) +}