Skip to content

Commit

Permalink
Fix incompatibility issue
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger committed Sep 29, 2023
1 parent 2c05e75 commit 102f58c
Showing 1 changed file with 73 additions and 39 deletions.
112 changes: 73 additions & 39 deletions transports/noise/src/io/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
use super::handshake::proto;
use crate::{protocol::PublicKey, Error};
use asynchronous_codec::{Decoder, Encoder, LengthCodec};
use bytes::{Bytes, BytesMut};
use asynchronous_codec::{Decoder, Encoder};
use bytes::{Buf, Bytes, BytesMut};
use log::{debug, error};
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use std::io;
use std::mem::size_of;

/// Max. size of a noise message.
const MAX_NOISE_MSG_LEN: usize = 65535;
Expand All @@ -45,36 +46,30 @@ static_assertions::const_assert! {
/// encoding and decoding length-delimited session messages.
pub(crate) struct Codec<S> {
session: S,
write_buffer: BytesMut,
encrypt_buffer: BytesMut,
decrypt_buffer: BytesMut,
length_codec: LengthCodec,
length_codec: U16LengthCodec,
}

impl<S: SessionState> Codec<S> {
pub(crate) fn new(session: S) -> Self {
Codec {
session,
write_buffer: BytesMut::new(),
encrypt_buffer: BytesMut::new(),
decrypt_buffer: BytesMut::new(),
length_codec: LengthCodec,
length_codec: U16LengthCodec,
}
}

fn encode_bytes(&mut self, item: &[u8], dst: &mut BytesMut) -> Result<(), io::Error> {
self.encrypt_buffer
.resize(item.len() + EXTRA_ENCRYPT_SPACE, 0);
let n = match self.session.write_message(item, &mut self.encrypt_buffer) {
let mut encrypt_buffer = BytesMut::zeroed(item.len() + EXTRA_ENCRYPT_SPACE);

let n = match self.session.write_message(item, &mut encrypt_buffer) {
Ok(n) => n,
Err(e) => {
error!("encryption error: {:?}", e);
return Err(io::ErrorKind::InvalidData.into());
}
};

let msg = self.encrypt_buffer.split_to(n).freeze();
self.length_codec.encode(msg, dst)
self.length_codec
.encode(encrypt_buffer.split_to(n).freeze(), dst)
}

fn decode_bytes(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, io::Error> {
Expand All @@ -84,18 +79,16 @@ impl<S: SessionState> Codec<S> {
None => return Ok(None),
};

self.decrypt_buffer.resize(bytes.len(), 0u8);
let n = match self.session.read_message(&bytes, &mut self.decrypt_buffer) {
let mut decrypt_buffer = BytesMut::zeroed(bytes.len());
let n = match self.session.read_message(&bytes, &mut decrypt_buffer) {
Ok(n) => n,
Err(e) => {
debug!("decryption error {e}");
return Err(io::ErrorKind::InvalidData.into());
}
};

self.decrypt_buffer.truncate(n);

Ok(Some(self.decrypt_buffer.split().freeze()))
Ok(Some(decrypt_buffer.split_to(n).freeze()))
}
}

Expand Down Expand Up @@ -140,34 +133,35 @@ impl Encoder for Codec<snow::HandshakeState> {
type Item<'a> = &'a proto::NoiseHandshakePayload;

fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
self.write_buffer.resize(item.get_size(), 0u8);
let mut writer = Writer::new(&mut self.write_buffer[..]);
let mut write_buffer = BytesMut::zeroed(item.get_size());

let mut writer = Writer::new(&mut write_buffer[..]);
item.write_message(&mut writer)
.expect("Protobuf encoding to succeed");

let pb = self.write_buffer.split().freeze();
self.encode_bytes(&pb, dst)
}
self.encode_bytes(&write_buffer.split_to(item.get_size()).freeze(), dst)
}gst
}
impl Decoder for Codec<snow::HandshakeState> {
type Error = io::Error;
type Item = proto::NoiseHandshakePayload;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.decode_bytes(src)? {
Some(bytes) => {
let mut reader = BytesReader::from_bytes(&bytes[..]);
let pb = proto::NoiseHandshakePayload::from_reader(&mut reader, &bytes[..])
.map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"Failed decoding handshake payload",
)
})?;
Ok(Some(pb))
}
None => Ok(None),
}
let bytes = match self.decode_bytes(src)? {
Some(bytes) => bytes,
None => return Ok(None),
};

let mut reader = BytesReader::from_bytes(&bytes[..]);
let pb =
proto::NoiseHandshakePayload::from_reader(&mut reader, &bytes[..]).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"Failed decoding handshake payload",
)
})?;

Ok(Some(pb))
}
}

Expand Down Expand Up @@ -216,3 +210,43 @@ impl SessionState for snow::TransportState {
self.write_message(msg, buf)
}
}

/// A codec that prefixes messages with their length encoded as a big-endian u16.
struct U16LengthCodec;

const U16_LENGTH: usize = size_of::<u16>();

impl Encoder for U16LengthCodec {
type Item<'a> = Bytes;
type Error = io::Error;

fn encode(&mut self, src: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.reserve(U16_LENGTH + src.len());
dst.extend_from_slice(&(src.len() as u16).to_be_bytes());
dst.extend_from_slice(&src);
Ok(())
}
}

impl Decoder for U16LengthCodec {
type Item = Bytes;
type Error = io::Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < size_of::<u16>() {
return Ok(None);
}

let mut len_bytes = [0u8; U16_LENGTH];
len_bytes.copy_from_slice(&src[..U16_LENGTH]);
let len = u16::from_be_bytes(len_bytes) as usize;

if src.len() - U16_LENGTH >= len {
// Skip the length header we already read.
src.advance(U16_LENGTH);
Ok(Some(src.split_to(len).freeze()))
} else {
Ok(None)
}
}
}

0 comments on commit 102f58c

Please sign in to comment.