Skip to content

Commit

Permalink
Update read_exact to read for cancellation safety
Browse files Browse the repository at this point in the history
  • Loading branch information
nyonson committed Oct 17, 2024
1 parent 31c7449 commit facd501
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1098,15 +1098,28 @@ where

/// State machine of an asynchronous packet read.
#[cfg(feature = "async")]
#[derive(Default, Debug)]
#[derive(Debug)]
enum DecryptState {
#[default]
ReadingLength,
ReadingLength {
length_bytes: [u8; 3],
bytes_read: usize,
},
ReadingPayload {
packet_bytes: Vec<u8>,
bytes_read: usize,
},
}

#[cfg(feature = "async")]
impl Default for DecryptState {
fn default() -> Self {
DecryptState::ReadingLength {
length_bytes: [0u8; 3],
bytes_read: 0,
}
}
}

/// Manages an async buffer to automatically decrypt contents of received packets.
#[cfg(feature = "async")]
pub struct AsyncProtocolReader<R>
Expand All @@ -1131,20 +1144,37 @@ where
/// * `Ok(Payload)`: A decrypted payload.
/// * `Err(ProtocolError)`: An error that occurred during the read or decryption.
pub async fn decrypt(&mut self) -> Result<Payload, ProtocolError> {
// Storing state between async read_exacts to make function more cancellation safe.
// Storing state between async reads to make function cancellation safe.
loop {
match &mut self.state {
DecryptState::ReadingLength => {
let mut length_bytes = [0u8; 3];
self.buffer.read_exact(&mut length_bytes).await?;
let packet_bytes_len = self.packet_reader.decypt_len(length_bytes);
DecryptState::ReadingLength {
length_bytes,
bytes_read,
} => {
while *bytes_read < 3 {
*bytes_read += self.buffer.read(&mut length_bytes[*bytes_read..]).await?;
}

let packet_bytes_len = self.packet_reader.decypt_len(*length_bytes);
let packet_bytes = vec![0u8; packet_bytes_len];
self.state = DecryptState::ReadingPayload { packet_bytes };
self.state = DecryptState::ReadingPayload {
packet_bytes,
bytes_read: 0,
};
}
DecryptState::ReadingPayload { packet_bytes } => {
self.buffer.read_exact(packet_bytes).await?;
DecryptState::ReadingPayload {
packet_bytes,
bytes_read,
} => {
while *bytes_read < packet_bytes.len() {
*bytes_read += self.buffer.read(&mut packet_bytes[*bytes_read..]).await?;
}

let payload = self.packet_reader.decrypt_payload(packet_bytes, None)?;
self.state = DecryptState::ReadingLength;
self.state = DecryptState::ReadingLength {
length_bytes: [0u8; 3],
bytes_read: 0,
};
return Ok(payload);
}
}
Expand Down

0 comments on commit facd501

Please sign in to comment.