Skip to content

Commit

Permalink
Drop ReceivedMessage in favor of enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
nyonson committed Sep 25, 2024
1 parent 670189e commit cdaf019
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 51 deletions.
95 changes: 54 additions & 41 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,32 +168,32 @@ pub struct SessionKeyMaterial {
}

/// Your role in the handshake.
#[derive(Clone, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Role {
/// You started the handshake with a peer.
Initiator,
/// You are responding to a handshake.
Responder,
}

/// A message or decoy packet from a connected peer.
#[cfg(feature = "alloc")]
#[derive(Clone, Debug)]
pub struct ReceivedMessage {
/// A message to handle or `None` if the peer sent a decoy and the message may be safely ignored.
pub message: Option<Vec<u8>>,
/// A decoy packet contains bogus information, but can be
/// used to hide the shape of the data being communicated
/// over an encrypted channel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PacketType {
/// Genuine packet contains information.
Genuine,
/// Docoy packet contians bogus information.
Decoy,
}

#[cfg(feature = "alloc")]
impl ReceivedMessage {
pub fn new(msg_bytes: &[u8]) -> Result<Self, Error> {
let header = msg_bytes.first().ok_or(Error::CiphertextTooSmall)?;
if header.eq(&DECOY_BYTE) {
Ok(ReceivedMessage { message: None })
impl PacketType {
/// Check if plaintext packet packets are a decoy.
pub fn from_bytes(plaintext: &[u8]) -> Self {
if plaintext.first() == Some(&DECOY_BYTE) {
PacketType::Decoy
} else {
Ok(ReceivedMessage {
message: Some(msg_bytes[1..].to_vec()),
})
PacketType::Genuine
}
}
}
Expand Down Expand Up @@ -240,15 +240,22 @@ impl PacketReader {
/// - `contents` - Mutable buffer to write plaintext.
/// - `aad` - Optional authentication for the peer.
///
/// # Returns
///
/// A `Result` containing:
/// - `Ok(PacketType)`: A flag indicating if the decoded packet is a decoy or not.
/// - `Err(Error)`: An error that occurred during decryption.
///
/// # Errors
///
/// Fails if the packet was not decrypted or authenticated properly.
/// - `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet.
/// - `BufferTooSmall ` - Contents buffer argument is not large enough for plaintext.
pub fn decrypt_contents(
&mut self,
ciphertext: &[u8],
contents: &mut [u8],
aad: Option<&[u8]>,
) -> Result<(), Error> {
) -> Result<PacketType, Error> {
let auth = aad.unwrap_or_default();
// Check minimum size of ciphertext.
if ciphertext.len() < TAG_BYTES {
Expand All @@ -268,7 +275,7 @@ impl PacketReader {
tag.try_into().expect("16 byte tag"),
)?;

Ok(())
Ok(PacketType::from_bytes(contents))
}

/// Decrypt the rest of the message from the peer, excluding the 3 length bytes. This method should only be called after
Expand All @@ -279,18 +286,30 @@ impl PacketReader {
/// - `ciphertext` - The message from the peer.
/// - `aad` - Optional authentication for the peer.
///
/// # Returns
///
/// A `Result` containing:
/// - `Ok(Some(Vec<u8>))`: The plaintext in a byte vector if it is not a decoy packet.
/// - `Err(Error)`: An error that occurred during decryption.
///
/// # Errors
///
/// Fails if the packet was not decrypted or authenticated properly.
/// - `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet.
#[cfg(feature = "alloc")]
pub fn decrypt_contents_with_alloc(
&mut self,
ciphertext: &[u8],
aad: Option<&[u8]>,
) -> Result<Vec<u8>, Error> {
) -> Result<Option<Vec<u8>>, Error> {
let mut contents = vec![0u8; ciphertext.len() - TAG_BYTES];
self.decrypt_contents(ciphertext, &mut contents, aad)?;
Ok(contents)
match self.decrypt_contents(ciphertext, &mut contents, aad)? {
PacketType::Decoy => Ok(None),
PacketType::Genuine => {
// Drop the decoy byte flag.
contents.remove(0);
Ok(Some(contents))
}
}
}
}

Expand Down Expand Up @@ -492,14 +511,9 @@ impl PacketHandler {
&mut self,
ciphertext: &[u8],
aad: Option<&[u8]>,
) -> Result<ReceivedMessage, Error> {
let contents = self
.packet_reader
.decrypt_contents_with_alloc(ciphertext, aad)?;

let message = ReceivedMessage::new(&contents)?;

Ok(message)
) -> Result<Option<Vec<u8>>, Error> {
self.packet_reader
.decrypt_contents_with_alloc(ciphertext, aad)
}
}

Expand Down Expand Up @@ -722,7 +736,7 @@ impl<'a> Handshake<'a> {
}
};

let mut packet_handler = PacketHandler::new(materials, self.role.clone());
let mut packet_handler = PacketHandler::new(materials, self.role);

// TODO: Support sending decoy packets before the version packet.

Expand Down Expand Up @@ -1063,15 +1077,15 @@ mod tests {
let dec = bob_packet_handler
.decrypt_contents_with_alloc(&enc_packet[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(None, dec.message);
assert_eq!(None, dec);
let message = b"Windows sox!".to_vec();
let enc_packet = bob_packet_handler
.prepare_packet_with_alloc(&message, None, false)
.unwrap();
let dec = alice_packet_handler
.decrypt_contents_with_alloc(&enc_packet[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.message.unwrap());
assert_eq!(message, dec.unwrap());
}

#[test]
Expand Down Expand Up @@ -1102,15 +1116,15 @@ mod tests {
let dec_packet = bob_packet_handler
.decrypt_contents_with_alloc(&enc_packet[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec_packet.message.unwrap());
assert_eq!(message, dec_packet.unwrap());
let message = gen_garbage(420, &mut rng);
let enc_packet = bob_packet_handler
.prepare_packet_with_alloc(&message, None, false)
.unwrap();
let dec_packet = alice_packet_handler
.decrypt_contents_with_alloc(&enc_packet[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec_packet.message.unwrap());
assert_eq!(message, dec_packet.unwrap());
}
}

Expand Down Expand Up @@ -1193,15 +1207,15 @@ mod tests {
let dec = alice
.decrypt_contents_with_alloc(&encrypted_message_to_alice[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.message.unwrap());
assert_eq!(message, dec.unwrap());
let message = b"g!".to_vec();
let encrypted_message_to_bob = alice
.prepare_packet_with_alloc(&message, None, false)
.unwrap();
let dec = bob
.decrypt_contents_with_alloc(&encrypted_message_to_bob[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.message.unwrap());
assert_eq!(message, dec.unwrap());
}

#[test]
Expand Down Expand Up @@ -1248,11 +1262,10 @@ mod tests {
let contents = bob
.decrypt_contents_with_alloc(&message_to_bob[3..3 + alice_message_len], None)
.unwrap();
assert_eq!(contents.message.unwrap(), message);
assert_eq!(contents.unwrap(), message);
}

// The rest are sourced from [the BIP324 test vectors](https://github.com/bitcoin/bips/blob/master/bip-0324/packet_encoding_test_vectors.csv).
//

#[test]
#[cfg(feature = "std")]
Expand Down Expand Up @@ -1280,7 +1293,7 @@ mod tests {
let dec_packet = bob_packet_handler
.decrypt_contents_with_alloc(&enc[LENGTH_BYTES..], None)
.unwrap();
assert_eq!(first, dec_packet.message.unwrap());
assert_eq!(first, dec_packet.unwrap());
let contents: Vec<u8> = vec![0x8e];
let enc = alice_packet_handler
.prepare_packet_with_alloc(&contents, None, false)
Expand Down
9 changes: 4 additions & 5 deletions protocol/tests/round_trips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ fn hello_world_happy_path() {
let messages = alice
.decrypt_contents_with_alloc(&encrypted_message_to_alice[3..], None)
.unwrap();
assert_eq!(message, messages.message.unwrap());
assert_eq!(message, messages.unwrap());
let message = b"Goodbye!".to_vec();
let encrypted_message_to_bob = alice
.prepare_packet_with_alloc(&message, None, false)
.unwrap();
let messages = bob
.decrypt_contents_with_alloc(&encrypted_message_to_bob[3..], None)
.unwrap();
assert_eq!(message, messages.message.unwrap());
assert_eq!(message, messages.unwrap());
}

#[test]
Expand All @@ -71,7 +71,7 @@ fn regtest_handshake() {

use bip324::{
serde::{deserialize, serialize, NetworkMessage},
Handshake, ReceivedMessage,
Handshake,
};
use bitcoincore_rpc::{
bitcoin::p2p::{message_network::VersionMessage, Address, ServiceFlags},
Expand Down Expand Up @@ -148,8 +148,7 @@ fn regtest_handshake() {
let msg = decrypter
.decrypt_contents_with_alloc(&response_message, None)
.unwrap();
let message = ReceivedMessage::new(&msg.clone()).unwrap();
let message = deserialize(&message.message.unwrap()).unwrap();
let message = deserialize(&msg.unwrap()).unwrap();
dbg!("{}", message.cmd());
assert_eq!(message.cmd(), "version");
rpc.stop().unwrap();
Expand Down
6 changes: 1 addition & 5 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use std::fmt;
use std::net::SocketAddr;

use bip324::serde::{deserialize, serialize};
use bip324::ReceivedMessage;
use bip324::{PacketReader, PacketWriter};
use bitcoin::consensus::{Decodable, Encodable};
use bitcoin::p2p::message::{NetworkMessage, RawNetworkMessage};
Expand Down Expand Up @@ -142,10 +141,7 @@ pub async fn read_v2<T: AsyncRead + Unpin>(
.decrypt_contents_with_alloc(&packet_bytes, None)
.expect("decrypt");

let contents = ReceivedMessage::new(&raw)
.expect("some bytes")
.message
.expect("not a decoy");
let contents = raw.expect("not a decoy");

let message = deserialize(&contents).map_err(|_| Error::Serde)?;
Ok(message)
Expand Down

0 comments on commit cdaf019

Please sign in to comment.