diff --git a/protocol/src/fschacha20poly1305.rs b/protocol/src/fschacha20poly1305.rs index df3720f..5c9c837 100644 --- a/protocol/src/fschacha20poly1305.rs +++ b/protocol/src/fschacha20poly1305.rs @@ -1,12 +1,13 @@ -use alloc::{fmt, vec::Vec}; +use alloc::fmt; use crate::chacha20poly1305::chacha20::ChaCha20; use crate::chacha20poly1305::ChaCha20Poly1305; const CHACHA_BLOCKS_USED: u32 = 3; -pub(crate) const REKEY_INTERVAL: u32 = 224; +const REKEY_INTERVAL: u32 = 224; const REKEY_INITIAL_NONCE: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF]; +/// Errors encrypting and decrypting with FSChaCha20Poly1305. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Error { Encryption, @@ -32,11 +33,6 @@ impl std::error::Error for Error { } } -pub enum CryptType { - Encrypt, - Decrypt, -} - /// A wrapper over ChaCha20Poly1305 AEAD stream cipher which handles automatically changing /// nonces and re-keying. /// @@ -55,74 +51,75 @@ impl FSChaCha20Poly1305 { } } - fn crypt( - &mut self, - aad: Vec, - contents: Vec, - crypt_type: CryptType, - ) -> Result, Error> { - let mut counter_div = (self.message_counter / REKEY_INTERVAL) - .to_le_bytes() - .to_vec(); - counter_div.extend([0u8; 4]); // ok? invalid for 4 billion messages + /// Derive current nonce. + fn nonce(&self) -> [u8; 12] { + let counter_div = (self.message_counter / REKEY_INTERVAL).to_le_bytes(); let counter_mod = (self.message_counter % REKEY_INTERVAL).to_le_bytes(); - let mut nonce = counter_mod.to_vec(); - nonce.extend(counter_div); // mod slice then div slice - let cipher = - ChaCha20Poly1305::new(self.key, nonce.try_into().expect("Nonce is malformed.")); - let converted_ciphertext: Vec = match crypt_type { - CryptType::Encrypt => { - let mut buffer = contents.clone(); - let tag = cipher - .encrypt(&mut buffer, Some(&aad)) - .map_err(|_| Error::Encryption)?; - buffer.extend(tag); - buffer - } - CryptType::Decrypt => { - let mut ciphertext = contents.clone(); - let ciphertext_len = ciphertext.len(); - let (mut ciphertext, tag) = ciphertext.split_at_mut(ciphertext_len - 16); - cipher - .decrypt( - &mut ciphertext, - tag.try_into().expect("16 byte tag"), - Some(&aad), - ) - .map_err(|_| Error::Decryption)?; - ciphertext.to_vec() - } - }; + let mut nonce = [0u8; 12]; + nonce[0..4].copy_from_slice(&counter_mod); + nonce[4..8].copy_from_slice(&counter_div); + + nonce + } + + /// Increment the message counter and rekey if necessary. + fn rekey(&mut self, aad: &[u8]) -> Result<(), Error> { if (self.message_counter + 1) % REKEY_INTERVAL == 0 { - let mut rekey_nonce = REKEY_INITIAL_NONCE.to_vec(); - let mut counter_div = (self.message_counter / REKEY_INTERVAL) - .to_le_bytes() - .to_vec(); - counter_div.extend([0u8; 4]); - let counter_mod = (self.message_counter % REKEY_INTERVAL).to_le_bytes(); - let mut nonce = counter_mod.to_vec(); - nonce.extend(counter_div); - rekey_nonce.extend(nonce[4..].to_vec()); + let mut rekey_nonce = [0u8; 12]; + rekey_nonce[0..4].copy_from_slice(&REKEY_INITIAL_NONCE); + rekey_nonce[4..].copy_from_slice(&self.nonce()[4..]); + let mut plaintext = [0u8; 32]; - let cipher = ChaCha20Poly1305::new( - self.key, - rekey_nonce.try_into().expect("Nonce is malformed."), - ); + let cipher = ChaCha20Poly1305::new(self.key, rekey_nonce); cipher - .encrypt(&mut plaintext, Some(&aad)) + .encrypt(&mut plaintext, Some(aad)) .map_err(|_| Error::Encryption)?; self.key = plaintext; } + self.message_counter += 1; - Ok(converted_ciphertext) + Ok(()) } - pub fn encrypt(&mut self, aad: Vec, contents: Vec) -> Result, Error> { - self.crypt(aad, contents, CryptType::Encrypt) + /// Encrypt the contents in place and return the 16-byte authentication tag. + /// + /// # Arguments + /// + /// - `content` - Plaintext to be encrypted in place. + /// - `aad` - Optional metadata covered by the authentication tag. + /// + /// # Returns + /// + /// The 16-byte authentication tag. + pub fn encrypt(&mut self, aad: &[u8], content: &mut [u8]) -> Result<[u8; 16], Error> { + let cipher = ChaCha20Poly1305::new(self.key, self.nonce()); + + let tag = cipher + .encrypt(content, Some(aad)) + .map_err(|_| Error::Encryption)?; + + self.rekey(aad)?; + + Ok(tag) } - pub fn decrypt(&mut self, aad: Vec, contents: Vec) -> Result, Error> { - self.crypt(aad, contents, CryptType::Decrypt) + /// Decrypt the contents in place. + /// + /// # Arguments + /// + /// - `content` - Ciphertext to be decrypted in place. + /// - `tag` - 16-byte authentication tag. + /// - `aad` - Optional metadata covered by the authentication tag. + pub fn decrypt(&mut self, aad: &[u8], content: &mut [u8], tag: [u8; 16]) -> Result<(), Error> { + let cipher = ChaCha20Poly1305::new(self.key, self.nonce()); + + cipher + .decrypt(content, tag, Some(aad)) + .map_err(|_| Error::Decryption)?; + + self.rekey(aad)?; + + Ok(()) } } @@ -147,6 +144,7 @@ impl FSChaCha20 { } } + /// Encrypt or decrypt the 3-byte length encodings. pub fn crypt(&mut self, chunk: &mut [u8; 3]) -> Result<(), Error> { let counter_mod = (self.chunk_counter / REKEY_INTERVAL).to_le_bytes(); let mut nonce = [0u8; 12]; diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 707b041..7ce07ec 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -140,7 +140,7 @@ impl PacketReader { /// /// # Arguments /// - /// `len_bytes` - The first three bytes of the ciphertext. + /// - `len_bytes` - The first three bytes of the ciphertext. /// /// # Returns /// @@ -163,9 +163,8 @@ impl PacketReader { /// /// # Arguments /// - /// `contents` - The message from the peer. - /// - /// `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// - `contents` - The message from the peer. + /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. /// /// # Returns /// @@ -180,14 +179,21 @@ impl PacketReader { aad: Option>, ) -> Result { let auth = aad.unwrap_or_default(); - let plaintext = self.packet_decoding_cipher.decrypt(auth, contents)?; - let header = *plaintext + let mut contents = contents.clone(); + let contents_len = contents.len(); + let (ciphertext, tag) = contents.split_at_mut(contents_len - 16); + self.packet_decoding_cipher.decrypt( + &auth, + ciphertext, + tag.try_into().expect("16 bytes"), + )?; + let header = *ciphertext .first() .expect("All contents should include a header."); if header.eq(&DECOY) { return Ok(ReceivedMessage { message: None }); } - let message = plaintext[1..].to_vec(); + let message = ciphertext[1..].to_vec(); Ok(ReceivedMessage { message: Some(message), }) @@ -201,26 +207,24 @@ pub struct PacketWriter { } impl PacketWriter { - /// Prepare a vector of bytes to be encrypted and sent over the wire. + /// Encrypt plaintext bytes to be sent over the wire. /// /// # Arguments /// - /// `contents` - The Bitcoin P2P protocol message to send. - /// - /// `aad` - Optional authentication for the peer, currently only used for the first round of messages. - /// - /// `decoy` - Should the peer ignore this message. + /// - `plaintext` - Plaintext to be encrypted. + /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// - `decoy` - Should the peer ignore this message. /// /// # Returns /// - /// A ciphertext to send over the wire. + /// An encrypted packet to send over the wire. /// /// # Errors /// /// Fails if the packet was not encrypted properly. pub fn prepare_v2_packet( &mut self, - contents: Vec, + plaintext: Vec, aad: Option>, decoy: bool, ) -> Result, Error> { @@ -230,16 +234,17 @@ impl PacketWriter { header = DECOY; } let mut content_len = [0u8; 3]; - content_len.copy_from_slice(&(contents.len() as u32).to_le_bytes()[0..LENGTH_FIELD_LEN]); - let mut plaintext = vec![header]; - plaintext.extend(contents); + content_len.copy_from_slice(&(plaintext.len() as u32).to_le_bytes()[0..LENGTH_FIELD_LEN]); + let mut content = vec![header]; + content.extend(plaintext); let auth = aad.unwrap_or_default(); self.length_encoding_cipher .crypt(&mut content_len) .expect("encrypt length"); - let enc_packet = self.packet_encoding_cipher.encrypt(auth, plaintext)?; + let tag = self.packet_encoding_cipher.encrypt(&auth, &mut content)?; packet.extend(&content_len); - packet.extend(enc_packet); + packet.extend(content); + packet.extend(tag); Ok(packet) } } @@ -417,18 +422,23 @@ impl PacketHandler { if start_index as u32 + aead_len + 3 < ciphertext.len() as u32 { next_content = Some((start_index as u32 + aead_len + 3) as usize); } - let aead = ciphertext[start_index + 3..start_index + (aead_len as usize) + 3].to_vec(); - let plaintext = self - .packet_reader - .packet_decoding_cipher - .decrypt(auth.to_vec(), aead)?; - let header = *plaintext + + let mut aead = ciphertext[start_index + 3..start_index + (aead_len as usize) + 3].to_vec(); + let aead_len = aead.len(); + let (ciphertext, tag) = aead.split_at_mut(aead_len - 16); + + self.packet_reader.packet_decoding_cipher.decrypt( + auth, + ciphertext, + tag.try_into().expect("16 bytes"), + )?; + let header = *ciphertext .first() .expect("All contents should include a header."); if header.eq(&DECOY) { return Ok((None, next_content)); } - let message = plaintext[1..].to_vec(); + let message = ciphertext[1..].to_vec(); Ok((Some(message), next_content)) } } @@ -887,7 +897,8 @@ mod tests { ); let mut alice_packet_handler = PacketHandler::new(session_keys.clone(), Role::Initiator); let mut bob_packet_handler = PacketHandler::new(session_keys, Role::Responder); - for _ in 0..fschacha20poly1305::REKEY_INTERVAL + 100 { + // Force a rekey under the hood. + for _ in 0..(224 + 100) { let message = gen_garbage(4095, &mut rng); let enc_packet = alice_packet_handler .prepare_v2_packet(message.clone(), None, false) @@ -1067,7 +1078,8 @@ mod tests { .unwrap(); let mut message_to_bob = Vec::new(); - for _ in 0..fschacha20poly1305::REKEY_INTERVAL + 100 { + // Force a rekey under the hood. + for _ in 0..(224 + 100) { let message = gen_garbage(420, &mut rng); let enc_packet = alice .prepare_v2_packet(message.clone(), None, false)