From 59663657389954d36790519e3b468ec6c6f3250e Mon Sep 17 00:00:00 2001 From: kixelated Date: Thu, 21 Mar 2024 23:47:00 -0700 Subject: [PATCH] Fix encode/decode remaining checks. (#140) --- moq-transport/src/coding/decode.rs | 9 ++++++++ moq-transport/src/coding/encode.rs | 9 ++++++++ moq-transport/src/coding/params.rs | 11 ++-------- moq-transport/src/coding/string.rs | 9 ++------ moq-transport/src/coding/varint.rs | 23 +++++++-------------- moq-transport/src/data/datagram.rs | 5 +---- moq-transport/src/message/subscribe_done.rs | 9 ++------ moq-transport/src/message/subscribe_ok.rs | 8 ++----- 8 files changed, 35 insertions(+), 48 deletions(-) diff --git a/moq-transport/src/coding/decode.rs b/moq-transport/src/coding/decode.rs index a5f6e87e..84f13a06 100644 --- a/moq-transport/src/coding/decode.rs +++ b/moq-transport/src/coding/decode.rs @@ -4,6 +4,15 @@ use thiserror::Error; pub trait Decode: Sized { fn decode(buf: &mut B) -> Result; + + // Helper function to make sure we have enough bytes to decode + fn decode_remaining(buf: &mut B, required: usize) -> Result<(), DecodeError> { + if required > buf.remaining() { + Err(DecodeError::More(required - buf.remaining())) + } else { + Ok(()) + } + } } /// A decode error. diff --git a/moq-transport/src/coding/encode.rs b/moq-transport/src/coding/encode.rs index e906c7d4..1f2e80c7 100644 --- a/moq-transport/src/coding/encode.rs +++ b/moq-transport/src/coding/encode.rs @@ -4,6 +4,15 @@ use super::BoundsExceeded; pub trait Encode: Sized { fn encode(&self, w: &mut W) -> Result<(), EncodeError>; + + // Helper function to make sure we have enough bytes to encode + fn encode_remaining(buf: &mut W, required: usize) -> Result<(), EncodeError> { + if required > buf.remaining_mut() { + Err(EncodeError::More(required - buf.remaining_mut())) + } else { + Ok(()) + } + } } /// An encode error. diff --git a/moq-transport/src/coding/params.rs b/moq-transport/src/coding/params.rs index d78f753f..92f7fc72 100644 --- a/moq-transport/src/coding/params.rs +++ b/moq-transport/src/coding/params.rs @@ -19,10 +19,7 @@ impl Decode for Params { } let size = usize::decode(&mut r)?; - - if r.remaining() < size { - return Err(DecodeError::More(size)); - } + Self::decode_remaining(r, size)?; // Don't allocate the entire requested size to avoid a possible attack // Instead, we allocate up to 1024 and keep appending as we read further. @@ -43,11 +40,7 @@ impl Encode for Params { for (kind, value) in self.0.iter() { kind.encode(w)?; value.len().encode(w)?; - - if w.remaining_mut() < value.len() { - return Err(EncodeError::More(value.len())); - } - + Self::encode_remaining(w, value.len())?; w.put_slice(value); } diff --git a/moq-transport/src/coding/string.rs b/moq-transport/src/coding/string.rs index 7627bc5d..dd387da9 100644 --- a/moq-transport/src/coding/string.rs +++ b/moq-transport/src/coding/string.rs @@ -3,10 +3,7 @@ use super::{Decode, DecodeError, Encode, EncodeError}; impl Encode for String { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.len().encode(w)?; - if w.remaining_mut() < self.len() { - return Err(EncodeError::More(self.len())); - } - + Self::encode_remaining(w, self.len())?; w.put(self.as_ref()); Ok(()) } @@ -16,9 +13,7 @@ impl Decode for String { /// Decode a string with a varint length prefix. fn decode(r: &mut R) -> Result { let size = usize::decode(r)?; - if r.remaining() < size { - return Err(DecodeError::More(size)); - } + Self::decode_remaining(r, size)?; let mut buf = vec![0; size]; r.copy_to_slice(&mut buf); diff --git a/moq-transport/src/coding/varint.rs b/moq-transport/src/coding/varint.rs index 763af016..8690fba2 100644 --- a/moq-transport/src/coding/varint.rs +++ b/moq-transport/src/coding/varint.rs @@ -166,9 +166,7 @@ impl fmt::Display for VarInt { impl Decode for VarInt { /// Decode a varint from the given reader. fn decode(r: &mut R) -> Result { - if r.remaining() < 1 { - return Err(DecodeError::More(1)); - } + Self::decode_remaining(r, 1)?; let b = r.get_u8(); let tag = b >> 6; @@ -179,26 +177,17 @@ impl Decode for VarInt { let x = match tag { 0b00 => u64::from(buf[0]), 0b01 => { - if r.remaining() < 1 { - return Err(DecodeError::More(1)); - } - + Self::decode_remaining(r, 1)?; r.copy_to_slice(buf[1..2].as_mut()); u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) } 0b10 => { - if r.remaining() < 3 { - return Err(DecodeError::More(3)); - } - + Self::decode_remaining(r, 3)?; r.copy_to_slice(buf[1..4].as_mut()); u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) } 0b11 => { - if r.remaining() < 7 { - return Err(DecodeError::More(7)); - } - + Self::decode_remaining(r, 7)?; r.copy_to_slice(buf[1..8].as_mut()); u64::from_be_bytes(buf) } @@ -214,12 +203,16 @@ impl Encode for VarInt { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { let x = self.0; if x < 2u64.pow(6) { + Self::encode_remaining(w, 1)?; w.put_u8(x as u8) } else if x < 2u64.pow(14) { + Self::encode_remaining(w, 2)?; w.put_u16(0b01 << 14 | x as u16) } else if x < 2u64.pow(30) { + Self::encode_remaining(w, 4)?; w.put_u32(0b10 << 30 | x as u32) } else if x < 2u64.pow(62) { + Self::encode_remaining(w, 8)?; w.put_u64(0b11 << 62 | x) } else { return Err(BoundsExceeded.into()); diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 8bf62c63..970b5220 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -47,10 +47,7 @@ impl Encode for Datagram { self.group_id.encode(w)?; self.object_id.encode(w)?; self.send_order.encode(w)?; - - if w.remaining_mut() < self.payload.len() { - return Err(EncodeError::More(self.payload.len())); - } + Self::encode_remaining(w, self.payload.len())?; w.put_slice(&self.payload); Ok(()) diff --git a/moq-transport/src/message/subscribe_done.rs b/moq-transport/src/message/subscribe_done.rs index 08dcab28..eeb78c3c 100644 --- a/moq-transport/src/message/subscribe_done.rs +++ b/moq-transport/src/message/subscribe_done.rs @@ -22,10 +22,7 @@ impl Decode for SubscribeDone { let code = u64::decode(r)?; let reason = String::decode(r)?; - if r.remaining() < 1 { - return Err(DecodeError::More(1)); - } - + Self::decode_remaining(r, 1)?; let last = match r.get_u8() { 0 => None, 1 => Some((u64::decode(r)?, u64::decode(r)?)), @@ -42,9 +39,7 @@ impl Encode for SubscribeDone { self.code.encode(w)?; self.reason.encode(w)?; - if w.remaining_mut() < 1 { - return Err(EncodeError::More(1)); - } + Self::encode_remaining(w, 1)?; if let Some((group, object)) = self.last { w.put_u8(1); diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index a91bfdce..d4bc16f2 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -21,9 +21,7 @@ impl Decode for SubscribeOk { expires => Some(expires), }; - if !r.has_remaining() { - return Err(DecodeError::More(1)); - } + Self::decode_remaining(r, 1)?; let latest = match r.get_u8() { 0 => None, @@ -40,9 +38,7 @@ impl Encode for SubscribeOk { self.id.encode(w)?; self.expires.unwrap_or(0).encode(w)?; - if !w.has_remaining_mut() { - return Err(EncodeError::More(1)); - } + Self::encode_remaining(w, 1)?; match self.latest { Some((group, object)) => {