Skip to content

Commit

Permalink
Fix encode/decode remaining checks. (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
kixelated authored Mar 22, 2024
1 parent 3ecd433 commit 5966365
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 48 deletions.
9 changes: 9 additions & 0 deletions moq-transport/src/coding/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ use thiserror::Error;

pub trait Decode: Sized {
fn decode<B: bytes::Buf>(buf: &mut B) -> Result<Self, DecodeError>;

// Helper function to make sure we have enough bytes to decode
fn decode_remaining<B: bytes::Buf>(buf: &mut B, required: usize) -> Result<(), DecodeError> {
if required > buf.remaining() {
Err(DecodeError::More(required - buf.remaining()))
} else {
Ok(())
}
}
}

/// A decode error.
Expand Down
9 changes: 9 additions & 0 deletions moq-transport/src/coding/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ use super::BoundsExceeded;

pub trait Encode: Sized {
fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError>;

// Helper function to make sure we have enough bytes to encode
fn encode_remaining<W: bytes::BufMut>(buf: &mut W, required: usize) -> Result<(), EncodeError> {
if required > buf.remaining_mut() {
Err(EncodeError::More(required - buf.remaining_mut()))
} else {
Ok(())
}
}
}

/// An encode error.
Expand Down
11 changes: 2 additions & 9 deletions moq-transport/src/coding/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ impl Decode for Params {
}

let size = usize::decode(&mut r)?;

if r.remaining() < size {
return Err(DecodeError::More(size));
}
Self::decode_remaining(r, size)?;

// Don't allocate the entire requested size to avoid a possible attack
// Instead, we allocate up to 1024 and keep appending as we read further.
Expand All @@ -43,11 +40,7 @@ impl Encode for Params {
for (kind, value) in self.0.iter() {
kind.encode(w)?;
value.len().encode(w)?;

if w.remaining_mut() < value.len() {
return Err(EncodeError::More(value.len()));
}

Self::encode_remaining(w, value.len())?;
w.put_slice(value);
}

Expand Down
9 changes: 2 additions & 7 deletions moq-transport/src/coding/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use super::{Decode, DecodeError, Encode, EncodeError};
impl Encode for String {
fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
self.len().encode(w)?;
if w.remaining_mut() < self.len() {
return Err(EncodeError::More(self.len()));
}

Self::encode_remaining(w, self.len())?;
w.put(self.as_ref());
Ok(())
}
Expand All @@ -16,9 +13,7 @@ impl Decode for String {
/// Decode a string with a varint length prefix.
fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
let size = usize::decode(r)?;
if r.remaining() < size {
return Err(DecodeError::More(size));
}
Self::decode_remaining(r, size)?;

let mut buf = vec![0; size];
r.copy_to_slice(&mut buf);
Expand Down
23 changes: 8 additions & 15 deletions moq-transport/src/coding/varint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ impl fmt::Display for VarInt {
impl Decode for VarInt {
/// Decode a varint from the given reader.
fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
if r.remaining() < 1 {
return Err(DecodeError::More(1));
}
Self::decode_remaining(r, 1)?;

let b = r.get_u8();
let tag = b >> 6;
Expand All @@ -179,26 +177,17 @@ impl Decode for VarInt {
let x = match tag {
0b00 => u64::from(buf[0]),
0b01 => {
if r.remaining() < 1 {
return Err(DecodeError::More(1));
}

Self::decode_remaining(r, 1)?;
r.copy_to_slice(buf[1..2].as_mut());
u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
}
0b10 => {
if r.remaining() < 3 {
return Err(DecodeError::More(3));
}

Self::decode_remaining(r, 3)?;
r.copy_to_slice(buf[1..4].as_mut());
u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
}
0b11 => {
if r.remaining() < 7 {
return Err(DecodeError::More(7));
}

Self::decode_remaining(r, 7)?;
r.copy_to_slice(buf[1..8].as_mut());
u64::from_be_bytes(buf)
}
Expand All @@ -214,12 +203,16 @@ impl Encode for VarInt {
fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
let x = self.0;
if x < 2u64.pow(6) {
Self::encode_remaining(w, 1)?;
w.put_u8(x as u8)
} else if x < 2u64.pow(14) {
Self::encode_remaining(w, 2)?;
w.put_u16(0b01 << 14 | x as u16)
} else if x < 2u64.pow(30) {
Self::encode_remaining(w, 4)?;
w.put_u32(0b10 << 30 | x as u32)
} else if x < 2u64.pow(62) {
Self::encode_remaining(w, 8)?;
w.put_u64(0b11 << 62 | x)
} else {
return Err(BoundsExceeded.into());
Expand Down
5 changes: 1 addition & 4 deletions moq-transport/src/data/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ impl Encode for Datagram {
self.group_id.encode(w)?;
self.object_id.encode(w)?;
self.send_order.encode(w)?;

if w.remaining_mut() < self.payload.len() {
return Err(EncodeError::More(self.payload.len()));
}
Self::encode_remaining(w, self.payload.len())?;
w.put_slice(&self.payload);

Ok(())
Expand Down
9 changes: 2 additions & 7 deletions moq-transport/src/message/subscribe_done.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ impl Decode for SubscribeDone {
let code = u64::decode(r)?;
let reason = String::decode(r)?;

if r.remaining() < 1 {
return Err(DecodeError::More(1));
}

Self::decode_remaining(r, 1)?;
let last = match r.get_u8() {
0 => None,
1 => Some((u64::decode(r)?, u64::decode(r)?)),
Expand All @@ -42,9 +39,7 @@ impl Encode for SubscribeDone {
self.code.encode(w)?;
self.reason.encode(w)?;

if w.remaining_mut() < 1 {
return Err(EncodeError::More(1));
}
Self::encode_remaining(w, 1)?;

if let Some((group, object)) = self.last {
w.put_u8(1);
Expand Down
8 changes: 2 additions & 6 deletions moq-transport/src/message/subscribe_ok.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ impl Decode for SubscribeOk {
expires => Some(expires),
};

if !r.has_remaining() {
return Err(DecodeError::More(1));
}
Self::decode_remaining(r, 1)?;

let latest = match r.get_u8() {
0 => None,
Expand All @@ -40,9 +38,7 @@ impl Encode for SubscribeOk {
self.id.encode(w)?;
self.expires.unwrap_or(0).encode(w)?;

if !w.has_remaining_mut() {
return Err(EncodeError::More(1));
}
Self::encode_remaining(w, 1)?;

match self.latest {
Some((group, object)) => {
Expand Down

0 comments on commit 5966365

Please sign in to comment.