diff --git a/src/chacha20poly1305.rs b/src/chacha20poly1305.rs index 2ae230f..01f3c32 100644 --- a/src/chacha20poly1305.rs +++ b/src/chacha20poly1305.rs @@ -2,7 +2,7 @@ mod chacha20; mod poly1305; use crate::error; -use chacha20::ChaCha20; +pub use chacha20::ChaCha20; use poly1305::Poly1305; use error::ChaCha20Poly1305DecryptionError; @@ -10,6 +10,9 @@ use error::ChaCha20Poly1305EncryptionError; use alloc::string::ToString; +// Zero array for padding slices. +const ZEROES: [u8; 16] = [0u8; 16]; + #[derive(Debug)] pub struct ChaCha20Poly1305 { key: [u8; 32], @@ -39,8 +42,19 @@ impl ChaCha20Poly1305 { .expect("32 is a valid subset of 64."), ); let aad = aad.unwrap_or(&[]); + // AAD and ciphertext are padded if not 16-byte aligned. poly.add(aad); + let aad_overflow = aad.len() % 16; + if aad_overflow > 0 { + poly.add(&ZEROES[0..(16 - aad_overflow)]); + } + poly.add(plaintext); + let text_overflow = plaintext.len() % 16; + if text_overflow > 0 { + poly.add(&ZEROES[0..(16 - text_overflow)]); + } + let aad_len = aad.len().to_le_bytes(); let msg_len = plaintext.len().to_le_bytes(); let mut len_buffer = [0u8; 16]; @@ -79,7 +93,17 @@ impl ChaCha20Poly1305 { if ciphertext.len() >= 16 { let (received_msg, received_tag) = ciphertext.split_at_mut(ciphertext.len() - 16); poly.add(aad); + // AAD and ciphertext are padded if not 16-byte aligned. + let aad_overflow = aad.len() % 16; + if aad_overflow > 0 { + poly.add(&ZEROES[0..(16 - aad_overflow)]); + } poly.add(received_msg); + let msg_overflow = received_msg.len() % 16; + if msg_overflow > 0 { + poly.add(&ZEROES[0..(16 - msg_overflow)]); + } + let aad_len = aad.len().to_le_bytes(); let msg_len = received_msg.len().to_le_bytes(); let mut len_buffer = [0u8; 16]; diff --git a/src/chacha20poly1305/chacha20.rs b/src/chacha20poly1305/chacha20.rs index 1c36eed..a70eeb6 100644 --- a/src/chacha20poly1305/chacha20.rs +++ b/src/chacha20poly1305/chacha20.rs @@ -30,7 +30,7 @@ const CHACHA_BLOCKSIZE: usize = 64; /// The ChaCha20 stream cipher. #[derive(Debug)] -pub(crate) struct ChaCha20 { +pub struct ChaCha20 { /// A 256 bit secret session key shared by the parties communitcating. key: [u8; 32], /// A 96 bit initialization vector (IV), or nonce. A key/nonce pair should only be used once. @@ -423,11 +423,11 @@ mod tests { assert_eq!(binding, to); } - fn gen_garbage(garbage_len: u32) -> Vec { - let mut rng = rand::thread_rng(); - let buffer: Vec = (0..garbage_len).map(|_| rng.gen()).collect(); - buffer - } + // fn gen_garbage(garbage_len: u32) -> Vec { + // let mut rng = rand::thread_rng(); + // let buffer: Vec = (0..garbage_len).map(|_| rng.gen()).collect(); + // buffer + // } // #[test] // fn test_fuzz_other() { diff --git a/src/chacha20poly1305/poly1305.rs b/src/chacha20poly1305/poly1305.rs index 09e78ba..80af23b 100644 --- a/src/chacha20poly1305/poly1305.rs +++ b/src/chacha20poly1305/poly1305.rs @@ -3,26 +3,32 @@ //! Implementation heavily inspired by [this implementation in C](https://github.com/floodyberry/poly1305-donna/blob/master/poly1305-donna-32.h) //! referred to as "Donna". Further reference to [this](https://loup-vaillant.fr/tutorials/poly1305-design) article was used to formulate the multiplication loop. -/// +/// 2^26 for the 26-bit limbs. const BITMASK: u32 = 0x03ffffff; -/// Number is encoded in five 26-bit "limbs". +/// Number is encoded in five 26-bit limbs. const CARRY: u32 = 26; -/// Poly1305 authenticator. +/// Poly1305 authenticator takes a 32-byte one-time key and a message and produces a 16-byte tag. +/// +/// 64-bit constant time multiplication and addition implementation. #[derive(Debug)] pub(crate) struct Poly1305 { - /// + /// r part of the secret key. r: [u32; 5], - /// + /// s part of the secret key. s: [u32; 4], - /// + /// State used to create tag. acc: [u32; 5], + /// Leftovers between adds. + leftovers: [u8; 16], + /// Track relevant leftover bytes. + leftovers_len: usize, } impl Poly1305 { - /// Initialize + /// Initialize authenticator with a 32-byte one-time secret key. pub(crate) fn new(key: [u8; 32]) -> Self { - // taken from donna. assigns R to a 26-bit 5-limb number while simultaneously 'clamping' R + // Taken from donna. Assigns r to a 26-bit 5-limb number while simultaneously 'clamping' r. let r0 = u32::from_le_bytes(key[0..4].try_into().expect("Valid subset of 32.")) & 0x3ffffff; let r1 = u32::from_le_bytes(key[3..7].try_into().expect("Valid subset of 32.")) >> 2 & 0x03ffff03; @@ -39,31 +45,77 @@ impl Poly1305 { let s3 = u32::from_le_bytes(key[28..32].try_into().expect("Valid subset of 32.")); let s = [s0, s1, s2, s3]; let acc = [0; 5]; - Poly1305 { r, s, acc } + + // Initilize leftovers to zero. + let leftovers = [0u8; 16]; + let leftovers_len = 0; + + Poly1305 { + r, + s, + acc, + leftovers, + leftovers_len, + } } - /// Add message to be authenticated. + /// Add message to be authenticated, can be called multiple times before creating tag. pub(crate) fn add(&mut self, message: &[u8]) { - let mut i = 0; - while i < message.len() / 16 { - let msg_slice = prepare_padded_message_slice(&message[i * 16..(i + 1) * 16], false); + // Deal with previous leftovers if message is long enough. + let fill = if self.leftovers_len > 0 && (self.leftovers_len + message.len() >= 16) { + 16 - self.leftovers_len + } else { + 0 + }; + if fill > 0 { + self.leftovers[self.leftovers_len..].copy_from_slice(&message[0..fill]); + + let msg_slice = prepare_padded_message_slice(&self.leftovers, false); for (i, b) in msg_slice.iter().enumerate() { self.acc[i] += *b; } self.r_times_a(); - i += 1; + self.leftovers_len = 0; } - if message.len() % 16 > 0 { - let msg_slice = prepare_padded_message_slice(&message[i * 16..], true); + + // Remove prefix already processed in leftovers. + let remaining_message = &message[fill..]; + + // Add message to accumulator. + let mut i = 0; + while i < remaining_message.len() / 16 { + let msg_slice = + prepare_padded_message_slice(&remaining_message[i * 16..(i + 1) * 16], false); for (i, b) in msg_slice.iter().enumerate() { self.acc[i] += *b; } self.r_times_a(); + i += 1; + } + + // Save any leftovers. + if remaining_message.len() % 16 > 0 { + let message_index = remaining_message.len() - (remaining_message.len() % 16); + let new_len = self.leftovers_len + remaining_message.len() % 16; + self.leftovers[self.leftovers_len..new_len] + .copy_from_slice(&remaining_message[message_index..]); + self.leftovers_len = new_len; } } /// Generate authentication tag. pub(crate) fn tag(&mut self) -> [u8; 16] { + // Add any remaining leftovers to accumulator. + if self.leftovers_len > 0 { + let msg_slice = + prepare_padded_message_slice(&self.leftovers[..self.leftovers_len], true); + for (i, b) in msg_slice.iter().enumerate() { + self.acc[i] += *b; + } + self.r_times_a(); + self.leftovers_len = 0; + } + // Carry and mask. for i in 1..4 { self.acc[i + 1] += self.acc[i] >> CARRY; @@ -73,7 +125,7 @@ impl Poly1305 { for i in 0..self.acc.len() { self.acc[i] &= BITMASK; } - // Reduce + // Reduce. let mut t = self.acc; t[0] += 5; t[4] = t[4].wrapping_sub(1 << CARRY); @@ -101,12 +153,12 @@ impl Poly1305 { tag[i] = a[i] as u64 + self.s[i] as u64; } - // Carry + // Carry. for i in 0..3 { tag[i + 1] += tag[i] >> 32; } - // return the 16 least significant bytes + // Return the 16 least significant bytes. let mut ret: [u8; 16] = [0; 16]; for i in 0..tag.len() { let bytes = (tag[i] as u32).to_le_bytes(); @@ -127,11 +179,11 @@ impl Poly1305 { *t += modulus * self.r[i] as u64 * self.acc[(start + j) % 5] as u64; } } - // Carry + // Carry. for i in 0..4 { t[i + 1] += t[i] >> CARRY; } - // Mask + // Mask. for (i, t) in t.iter().enumerate().take(self.acc.len()) { self.acc[i] = *t as u32 & BITMASK; } @@ -142,11 +194,14 @@ impl Poly1305 { } } +// Encode 16-byte (tag sized), unless is_last flag set to true, piece of message into 5 26-bit limbs. fn prepare_padded_message_slice(msg: &[u8], is_last: bool) -> [u32; 5] { let hi_bit: u32 = if is_last { 0 } else { 1 << 24 }; let mut fmt_msg = [0u8; 17]; fmt_msg[..msg.len()].clone_from_slice(msg); - fmt_msg[16] = 0x01; + // Tack on a 1-byte so messages with buncha zeroes at the end don't have the same MAC. + fmt_msg[msg.len()] = 0x01; + // Encode number in five 26-bit limbs. let m0 = u32::from_le_bytes(fmt_msg[0..4].try_into().expect("Valid subset of 32.")) & BITMASK; let m1 = u32::from_le_bytes(fmt_msg[3..7].try_into().expect("Valid subset of 32.")) >> 2 & BITMASK; @@ -154,11 +209,8 @@ fn prepare_padded_message_slice(msg: &[u8], is_last: bool) -> [u32; 5] { u32::from_le_bytes(fmt_msg[6..10].try_into().expect("Valid subset of 32.")) >> 4 & BITMASK; let m3 = u32::from_le_bytes(fmt_msg[9..13].try_into().expect("Valid subset of 32.")) >> 6 & BITMASK; - let m4: u32 = if is_last { - u32::from_le_bytes(fmt_msg[13..17].try_into().expect("Valid subset of 32.")) | hi_bit - } else { - u32::from_le_bytes(fmt_msg[12..16].try_into().expect("Valid subset of 32.")) >> 8 | hi_bit - }; + let m4 = + u32::from_le_bytes(fmt_msg[12..16].try_into().expect("Valid subset of 32.")) >> 8 | hi_bit; [m0, m1, m2, m3, m4] } @@ -181,7 +233,7 @@ mod tests { use super::*; #[test] - fn test_none_message() { + fn test_rfc7539_none_message() { let key = hex::decode("85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b") .unwrap(); let key = key.as_slice().try_into().unwrap();