From b9245d158250170ef31d2940ff2e3df77f208fa6 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Wed, 29 Nov 2023 14:57:47 -0800 Subject: [PATCH] messages: Generalize length prefixed item encoding Instead of encoding a single encodable item, use a call back to encode an arbitrary squence of items. --- daphne/src/messages/mod.rs | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index 7cc7f6e1d..198d190e6 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -155,7 +155,9 @@ impl ParameterizedEncode for Extension { Self::Taskprov { draft02_payload } => { EXTENSION_TASKPROV.encode(bytes); match (version, draft02_payload) { - (DapVersion::DraftLatest, None) => encode_u16_item(bytes, *version, &()), + (DapVersion::DraftLatest, None) => { + encode_u16_prefixed(*version, bytes, |_, _| ()); + } (DapVersion::Draft02, Some(payload)) => encode_u16_bytes(bytes, payload), _ => unreachable!("unhandled version {version:?}"), } @@ -176,7 +178,7 @@ impl ParameterizedDecode for Extension { let typ = u16::decode(bytes)?; match (version, typ) { (DapVersion::DraftLatest, EXTENSION_TASKPROV) => { - decode_u16_item::<()>(*version, bytes)?; + decode_u16_prefixed(*version, bytes, |_version, inner| <()>::decode(inner))?; Ok(Self::Taskprov { draft02_payload: None, }) @@ -1240,16 +1242,16 @@ pub fn decode_base64url_vec>(input: T) -> Option> { } // Cribbed from `decode_u16_items()` from libprio. -fn encode_u16_item>( - bytes: &mut Vec, +fn encode_u16_prefixed( version: DapVersion, - item: &E, + bytes: &mut Vec, + e: impl Fn(DapVersion, &mut Vec), ) { // Reserve space for the length prefix. let len_offset = bytes.len(); 0_u16.encode(bytes); - item.encode_with_param(&version, bytes); + e(version, bytes); let len_bytes = std::mem::size_of::(); let len = bytes.len() - len_offset - len_bytes; bytes[len_offset..len_offset + len_bytes] @@ -1257,10 +1259,11 @@ fn encode_u16_item>( } // Cribbed from `decode_u16_items()` from libprio. -fn decode_u16_item>( +fn decode_u16_prefixed( version: DapVersion, bytes: &mut Cursor<&[u8]>, -) -> Result { + d: impl Fn(DapVersion, &mut Cursor<&[u8]>) -> Result, +) -> Result { // Read the length prefix. let len = usize::from(u16::decode(bytes)?); @@ -1271,7 +1274,13 @@ fn decode_u16_item>( .checked_add(len) .ok_or_else(|| CodecError::LengthPrefixTooBig(len))?; - let decoded = D::get_decoded_with_param(&version, &bytes.get_ref()[item_start..item_end])?; + let mut inner = Cursor::new(&bytes.get_ref()[item_start..item_end]); + let decoded = d(version, &mut inner)?; + + let num_bytes_left_over = item_end - item_start - usize::try_from(inner.position()).unwrap(); + if num_bytes_left_over > 0 { + return Err(CodecError::BytesLeftOver(num_bytes_left_over)); + } // Advance outer cursor by the amount read in the inner cursor. bytes.set_position(item_end.try_into().unwrap()); @@ -1285,7 +1294,9 @@ fn encode_u16_item_for_version>( item: &E, ) { match version { - DapVersion::DraftLatest => encode_u16_item(bytes, version, item), + DapVersion::DraftLatest => encode_u16_prefixed(version, bytes, |version, bytes| { + item.encode_with_param(&version, bytes); + }), DapVersion::Draft02 => item.encode_with_param(&version, bytes), } } @@ -1295,7 +1306,9 @@ fn decode_u16_item_for_version>( bytes: &mut Cursor<&[u8]>, ) -> Result { match version { - DapVersion::DraftLatest => decode_u16_item(version, bytes), + DapVersion::DraftLatest => decode_u16_prefixed(version, bytes, |version, inner| { + D::decode_with_param(&version, inner) + }), DapVersion::Draft02 => D::decode_with_param(&version, bytes), } }