Skip to content

Commit

Permalink
Merge pull request #78 from nyonson/cancellation-safety
Browse files Browse the repository at this point in the history
Improve cancellation safety by switching read_exacts to reads
  • Loading branch information
rustaceanrob authored Oct 17, 2024
2 parents 50892fd + a3f16bb commit 943733f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 28 deletions.
53 changes: 41 additions & 12 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1118,15 +1118,28 @@ where

/// State machine of an asynchronous packet read.
#[cfg(feature = "async")]
#[derive(Default, Debug)]
#[derive(Debug)]
enum DecryptState {
#[default]
ReadingLength,
ReadingLength {
length_bytes: [u8; 3],
bytes_read: usize,
},
ReadingPayload {
packet_bytes: Vec<u8>,
bytes_read: usize,
},
}

#[cfg(feature = "async")]
impl Default for DecryptState {
fn default() -> Self {
DecryptState::ReadingLength {
length_bytes: [0u8; 3],
bytes_read: 0,
}
}
}

/// Manages an async buffer to automatically decrypt contents of received packets.
#[cfg(feature = "async")]
pub struct AsyncProtocolReader<R>
Expand All @@ -1145,26 +1158,42 @@ where
{
/// Decrypt contents of received packet from buffer.
///
/// This function is cancellation safe.
///
/// # Returns
///
/// A `Result` containing:
/// * `Ok(Payload)`: A decrypted payload.
/// * `Err(ProtocolError)`: An error that occurred during the read or decryption.
pub async fn decrypt(&mut self) -> Result<Payload, ProtocolError> {
// Storing state between async read_exacts to make function more cancellation safe.
// Storing state between async reads to make function cancellation safe.
loop {
match &mut self.state {
DecryptState::ReadingLength => {
let mut length_bytes = [0u8; 3];
self.buffer.read_exact(&mut length_bytes).await?;
let packet_bytes_len = self.packet_reader.decypt_len(length_bytes);
DecryptState::ReadingLength {
length_bytes,
bytes_read,
} => {
while *bytes_read < 3 {
*bytes_read += self.buffer.read(&mut length_bytes[*bytes_read..]).await?;
}

let packet_bytes_len = self.packet_reader.decypt_len(*length_bytes);
let packet_bytes = vec![0u8; packet_bytes_len];
self.state = DecryptState::ReadingPayload { packet_bytes };
self.state = DecryptState::ReadingPayload {
packet_bytes,
bytes_read: 0,
};
}
DecryptState::ReadingPayload { packet_bytes } => {
self.buffer.read_exact(packet_bytes).await?;
DecryptState::ReadingPayload {
packet_bytes,
bytes_read,
} => {
while *bytes_read < packet_bytes.len() {
*bytes_read += self.buffer.read(&mut packet_bytes[*bytes_read..]).await?;
}

let payload = self.packet_reader.decrypt_payload(packet_bytes, None)?;
self.state = DecryptState::ReadingLength;
self.state = DecryptState::default();
return Ok(payload);
}
}
Expand Down
56 changes: 40 additions & 16 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,28 @@ pub async fn peek_addr(client: &TcpStream, network: Network) -> Result<SocketAdd
Ok(socket_addr)
}

/// State machine of an asynchronous helps make functions more robust to cancellation.
#[derive(Default, Debug)]
/// State machine of an asynchronous helps make functions cancellation safe.
#[derive(Debug)]
enum ReadState {
#[default]
ReadingLength,
ReadingLength {
header_bytes: [u8; V1_HEADER_BYTES],
bytes_read: usize,
},
ReadingPayload {
packet_bytes: Vec<u8>,
bytes_read: usize,
},
}

impl Default for ReadState {
fn default() -> Self {
ReadState::ReadingLength {
header_bytes: [0u8; V1_HEADER_BYTES],
bytes_read: 0,
}
}
}

/// Read messages on the V1 protocol.
pub struct V1ProtocolReader<T: AsyncRead + Unpin> {
input: T,
Expand All @@ -119,9 +131,14 @@ impl<T: AsyncRead + Unpin> V1ProtocolReader<T> {
pub async fn read(&mut self) -> Result<NetworkMessage, Error> {
loop {
match &mut self.state {
ReadState::ReadingLength => {
let mut header_bytes = [0u8; V1_HEADER_BYTES];
self.input.read_exact(&mut header_bytes).await?;
ReadState::ReadingLength {
header_bytes,
bytes_read,
} => {
while *bytes_read < V1_HEADER_BYTES {
let n = self.input.read(&mut header_bytes[*bytes_read..]).await?;
*bytes_read += n;
}

let payload_len = u32::from_le_bytes(
header_bytes[16..20]
Expand All @@ -130,21 +147,28 @@ impl<T: AsyncRead + Unpin> V1ProtocolReader<T> {
) as usize;

let mut packet_bytes = vec![0u8; V1_HEADER_BYTES + payload_len];
packet_bytes[..V1_HEADER_BYTES].copy_from_slice(&header_bytes);
packet_bytes[..V1_HEADER_BYTES].copy_from_slice(header_bytes);

self.state = ReadState::ReadingPayload { packet_bytes };
self.state = ReadState::ReadingPayload {
packet_bytes,
bytes_read: V1_HEADER_BYTES,
};
}
ReadState::ReadingPayload { packet_bytes } => {
self.input
.read_exact(&mut packet_bytes[V1_HEADER_BYTES..])
.await?;
ReadState::ReadingPayload {
packet_bytes,
bytes_read,
} => {
while *bytes_read < packet_bytes.len() {
let n = self.input.read(&mut packet_bytes[*bytes_read..]).await?;
*bytes_read += n;
}

let message = RawNetworkMessage::consensus_decode(&mut &packet_bytes[..])
.expect("decode v1");

// Reset state for next read.
self.state = ReadState::ReadingLength;

self.state = ReadState::default();
// The RawNetworkMessage type doesn't have a nice way to pull
// out the payload, so using a clone here.
return Ok(message.payload().clone());
}
}
Expand Down

0 comments on commit 943733f

Please sign in to comment.