From 314c213527d35fe926aff28cfc6d17ae859c12cf Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Fri, 12 Apr 2024 08:40:19 -0700 Subject: [PATCH] Add split packet handler function --- protocol/src/lib.rs | 191 +++++++++++++++++++++++++++++++++----------- 1 file changed, 144 insertions(+), 47 deletions(-) diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index f55187c..a206821 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -126,15 +126,123 @@ pub struct ReceivedMessage { pub message: Option>, } -/// Encrypt and decrypt messages with a peer. #[derive(Clone, Debug)] -pub struct PacketHandler { - length_encoding_cipher: FSChaCha20, +pub struct PacketReader { length_decoding_cipher: FSChaCha20, - packet_encoding_cipher: FSChaCha20Poly1305, packet_decoding_cipher: FSChaCha20Poly1305, } +impl PacketReader { + /// Decode the length, in bytes, of the of the rest imbound message. Intended for use with `TcpStream` and `read_exact`. + /// Note that this does not decode to the length of contents described in BIP324, and is meant to represent the entire imbound message. + /// + /// # Arguments + /// + /// `len_slice` - The first three bytes of the message. + /// + /// # Returns + /// + /// The length to be read into the buffer next to receive the full message from the peer. + pub fn decypt_len(&mut self, len_slice: [u8; 3]) -> usize { + let mut enc_content_len = self.length_decoding_cipher.crypt(len_slice.to_vec()); + enc_content_len.push(0u8); + let content_slice: [u8; 4] = enc_content_len + .try_into() + .expect("Length of slice should be 4."); + let content_len = u32::from_le_bytes(content_slice); + content_len as usize + 17 + } + + /// Decrypt the rest of the message from the peer, excluding the 3 length bytes. This method should only be called after + /// calling `decrypt_len` on the first three bytes of the buffer. + /// + /// # Arguments + /// + /// `contents` - The message from the peer. + /// + /// `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// + /// # Returns + /// + /// The message from the peer. + /// + /// # Errors + /// + /// Fails if the packet was not decrypted or authenticated properly. + pub fn decrypt_contents( + &mut self, + contents: Vec, + aad: Option>, + ) -> Result { + let auth = aad.unwrap_or_default(); + let plaintext = self.packet_decoding_cipher.decrypt(auth, contents)?; + let header = *plaintext + .first() + .expect("All contents should include a header."); + if header.eq(&DECOY) { + return Ok(ReceivedMessage { message: None }); + } + let message = plaintext[1..].to_vec(); + Ok(ReceivedMessage { + message: Some(message), + }) + } +} + +#[derive(Clone, Debug)] +pub struct PacketWriter { + length_encoding_cipher: FSChaCha20, + packet_encoding_cipher: FSChaCha20Poly1305, +} + +impl PacketWriter { + /// Prepare a vector of bytes to be encrypted and sent over the wire. + /// + /// # Arguments + /// + /// `contents` - The Bitcoin P2P protocol message to send. + /// + /// `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// + /// `decoy` - Should the peer ignore this message. + /// + /// # Returns + /// + /// A ciphertext to send over the wire. + /// + /// # Errors + /// + /// Fails if the packet was not encrypted properly. + pub fn prepare_v2_packet( + &mut self, + contents: Vec, + aad: Option>, + decoy: bool, + ) -> Result, Error> { + let mut packet: Vec = Vec::new(); + let mut header: u8 = 0; + if decoy { + header = DECOY; + } + let content_len = (contents.len() as u32).to_le_bytes()[0..LENGTH_FIELD_LEN].to_vec(); + let mut plaintext = vec![header]; + plaintext.extend(contents); + let auth = aad.unwrap_or_default(); + let enc_len = self.length_encoding_cipher.crypt(content_len); + let enc_packet = self.packet_encoding_cipher.encrypt(auth, plaintext)?; + packet.extend(enc_len); + packet.extend(enc_packet); + Ok(packet) + } +} + +/// Encrypt and decrypt messages with a peer. +#[derive(Clone, Debug)] +pub struct PacketHandler { + packet_reader: PacketReader, + packet_writer: PacketWriter, +} + impl PacketHandler { fn new(materials: SessionKeyMaterial, role: Role) -> Self { match role { @@ -146,10 +254,14 @@ impl PacketHandler { let packet_decoding_cipher = FSChaCha20Poly1305::new(materials.responder_packet_key); PacketHandler { - length_encoding_cipher, - length_decoding_cipher, - packet_encoding_cipher, - packet_decoding_cipher, + packet_reader: PacketReader { + length_decoding_cipher, + packet_decoding_cipher, + }, + packet_writer: PacketWriter { + length_encoding_cipher, + packet_encoding_cipher, + }, } } Role::Responder => { @@ -160,15 +272,24 @@ impl PacketHandler { let packet_decoding_cipher = FSChaCha20Poly1305::new(materials.initiator_packet_key); PacketHandler { - length_encoding_cipher, - length_decoding_cipher, - packet_encoding_cipher, - packet_decoding_cipher, + packet_reader: PacketReader { + length_decoding_cipher, + packet_decoding_cipher, + }, + packet_writer: PacketWriter { + length_encoding_cipher, + packet_encoding_cipher, + }, } } } } + /// Split the handler into separate reader and a writer. + pub fn split(self) -> (PacketReader, PacketWriter) { + (self.packet_reader, self.packet_writer) + } + /// Prepare a vector of bytes to be encrypted and sent over the wire. /// /// # Arguments @@ -192,20 +313,7 @@ impl PacketHandler { aad: Option>, decoy: bool, ) -> Result, Error> { - let mut packet: Vec = Vec::new(); - let mut header: u8 = 0; - if decoy { - header = DECOY; - } - let content_len = (contents.len() as u32).to_le_bytes()[0..LENGTH_FIELD_LEN].to_vec(); - let mut plaintext = vec![header]; - plaintext.extend(contents); - let auth = aad.unwrap_or_default(); - let enc_len = self.length_encoding_cipher.crypt(content_len); - let enc_packet = self.packet_encoding_cipher.encrypt(auth, plaintext)?; - packet.extend(enc_len); - packet.extend(enc_packet); - Ok(packet) + self.packet_writer.prepare_v2_packet(contents, aad, decoy) } /// Decode the length, in bytes, of the of the rest imbound message. Intended for use with `TcpStream` and `read_exact`. @@ -219,13 +327,7 @@ impl PacketHandler { /// /// The length to be read into the buffer next to receive the full message from the peer. pub fn decypt_len(&mut self, len_slice: [u8; 3]) -> usize { - let mut enc_content_len = self.length_decoding_cipher.crypt(len_slice.to_vec()); - enc_content_len.push(0u8); - let content_slice: [u8; 4] = enc_content_len - .try_into() - .expect("Length of slice should be 4."); - let content_len = u32::from_le_bytes(content_slice); - content_len as usize + 17 + self.packet_reader.decypt_len(len_slice) } /// Decrypt the rest of the message from the peer, excluding the 3 length bytes. This method should only be called after @@ -249,18 +351,7 @@ impl PacketHandler { contents: Vec, aad: Option>, ) -> Result { - let auth = aad.unwrap_or_default(); - let plaintext = self.packet_decoding_cipher.decrypt(auth, contents)?; - let header = *plaintext - .first() - .expect("All contents should include a header."); - if header.eq(&DECOY) { - return Ok(ReceivedMessage { message: None }); - } - let message = plaintext[1..].to_vec(); - Ok(ReceivedMessage { - message: Some(message), - }) + self.packet_reader.decrypt_contents(contents, aad) } /// Decrypt the one or more messages from bytes received by a V2 peer. @@ -302,7 +393,10 @@ impl PacketHandler { start_index: usize, ) -> Result<(Option>, Option), Error> { let enc_content_len = ciphertext[start_index..LENGTH_FIELD_LEN + start_index].to_vec(); - let mut content_len = self.length_decoding_cipher.crypt(enc_content_len); + let mut content_len = self + .packet_reader + .length_decoding_cipher + .crypt(enc_content_len); content_len.push(0u8); let content_slice: [u8; 4] = content_len .try_into() @@ -317,7 +411,10 @@ impl PacketHandler { next_content = Some((start_index as u32 + aead_len + 3) as usize); } let aead = ciphertext[start_index + 3..start_index + (aead_len as usize) + 3].to_vec(); - let plaintext = self.packet_decoding_cipher.decrypt(auth.to_vec(), aead)?; + let plaintext = self + .packet_reader + .packet_decoding_cipher + .decrypt(auth.to_vec(), aead)?; let header = *plaintext .first() .expect("All contents should include a header.");