diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 5a5b61f..1ea5d32 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -1118,15 +1118,28 @@ where /// State machine of an asynchronous packet read. #[cfg(feature = "async")] -#[derive(Default, Debug)] +#[derive(Debug)] enum DecryptState { - #[default] - ReadingLength, + ReadingLength { + length_bytes: [u8; 3], + bytes_read: usize, + }, ReadingPayload { packet_bytes: Vec, + bytes_read: usize, }, } +#[cfg(feature = "async")] +impl Default for DecryptState { + fn default() -> Self { + DecryptState::ReadingLength { + length_bytes: [0u8; 3], + bytes_read: 0, + } + } +} + /// Manages an async buffer to automatically decrypt contents of received packets. #[cfg(feature = "async")] pub struct AsyncProtocolReader @@ -1145,26 +1158,42 @@ where { /// Decrypt contents of received packet from buffer. /// + /// This function is cancellation safe. + /// /// # Returns /// /// A `Result` containing: /// * `Ok(Payload)`: A decrypted payload. /// * `Err(ProtocolError)`: An error that occurred during the read or decryption. pub async fn decrypt(&mut self) -> Result { - // Storing state between async read_exacts to make function more cancellation safe. + // Storing state between async reads to make function cancellation safe. loop { match &mut self.state { - DecryptState::ReadingLength => { - let mut length_bytes = [0u8; 3]; - self.buffer.read_exact(&mut length_bytes).await?; - let packet_bytes_len = self.packet_reader.decypt_len(length_bytes); + DecryptState::ReadingLength { + length_bytes, + bytes_read, + } => { + while *bytes_read < 3 { + *bytes_read += self.buffer.read(&mut length_bytes[*bytes_read..]).await?; + } + + let packet_bytes_len = self.packet_reader.decypt_len(*length_bytes); let packet_bytes = vec![0u8; packet_bytes_len]; - self.state = DecryptState::ReadingPayload { packet_bytes }; + self.state = DecryptState::ReadingPayload { + packet_bytes, + bytes_read: 0, + }; } - DecryptState::ReadingPayload { packet_bytes } => { - self.buffer.read_exact(packet_bytes).await?; + DecryptState::ReadingPayload { + packet_bytes, + bytes_read, + } => { + while *bytes_read < packet_bytes.len() { + *bytes_read += self.buffer.read(&mut packet_bytes[*bytes_read..]).await?; + } + let payload = self.packet_reader.decrypt_payload(packet_bytes, None)?; - self.state = DecryptState::ReadingLength; + self.state = DecryptState::default(); return Ok(payload); } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 941594c..a8fffb2 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -90,16 +90,28 @@ pub async fn peek_addr(client: &TcpStream, network: Network) -> Result, + bytes_read: usize, }, } +impl Default for ReadState { + fn default() -> Self { + ReadState::ReadingLength { + header_bytes: [0u8; V1_HEADER_BYTES], + bytes_read: 0, + } + } +} + /// Read messages on the V1 protocol. pub struct V1ProtocolReader { input: T, @@ -119,9 +131,14 @@ impl V1ProtocolReader { pub async fn read(&mut self) -> Result { loop { match &mut self.state { - ReadState::ReadingLength => { - let mut header_bytes = [0u8; V1_HEADER_BYTES]; - self.input.read_exact(&mut header_bytes).await?; + ReadState::ReadingLength { + header_bytes, + bytes_read, + } => { + while *bytes_read < V1_HEADER_BYTES { + let n = self.input.read(&mut header_bytes[*bytes_read..]).await?; + *bytes_read += n; + } let payload_len = u32::from_le_bytes( header_bytes[16..20] @@ -130,21 +147,28 @@ impl V1ProtocolReader { ) as usize; let mut packet_bytes = vec![0u8; V1_HEADER_BYTES + payload_len]; - packet_bytes[..V1_HEADER_BYTES].copy_from_slice(&header_bytes); + packet_bytes[..V1_HEADER_BYTES].copy_from_slice(header_bytes); - self.state = ReadState::ReadingPayload { packet_bytes }; + self.state = ReadState::ReadingPayload { + packet_bytes, + bytes_read: V1_HEADER_BYTES, + }; } - ReadState::ReadingPayload { packet_bytes } => { - self.input - .read_exact(&mut packet_bytes[V1_HEADER_BYTES..]) - .await?; + ReadState::ReadingPayload { + packet_bytes, + bytes_read, + } => { + while *bytes_read < packet_bytes.len() { + let n = self.input.read(&mut packet_bytes[*bytes_read..]).await?; + *bytes_read += n; + } let message = RawNetworkMessage::consensus_decode(&mut &packet_bytes[..]) .expect("decode v1"); - // Reset state for next read. - self.state = ReadState::ReadingLength; - + self.state = ReadState::default(); + // The RawNetworkMessage type doesn't have a nice way to pull + // out the payload, so using a clone here. return Ok(message.payload().clone()); } }