Skip to content

Commit

Permalink
messages: Generalize length prefixed item encoding
Browse files Browse the repository at this point in the history
Instead of encoding a single encodable item, use a call back to encode
an arbitrary squence of items.
  • Loading branch information
cjpatton committed Nov 29, 2023
1 parent ce84a04 commit b9245d1
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions daphne/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ impl ParameterizedEncode<DapVersion> for Extension {
Self::Taskprov { draft02_payload } => {
EXTENSION_TASKPROV.encode(bytes);
match (version, draft02_payload) {
(DapVersion::DraftLatest, None) => encode_u16_item(bytes, *version, &()),
(DapVersion::DraftLatest, None) => {
encode_u16_prefixed(*version, bytes, |_, _| ());
}
(DapVersion::Draft02, Some(payload)) => encode_u16_bytes(bytes, payload),
_ => unreachable!("unhandled version {version:?}"),
}
Expand All @@ -176,7 +178,7 @@ impl ParameterizedDecode<DapVersion> for Extension {
let typ = u16::decode(bytes)?;
match (version, typ) {
(DapVersion::DraftLatest, EXTENSION_TASKPROV) => {
decode_u16_item::<()>(*version, bytes)?;
decode_u16_prefixed(*version, bytes, |_version, inner| <()>::decode(inner))?;
Ok(Self::Taskprov {
draft02_payload: None,
})
Expand Down Expand Up @@ -1240,27 +1242,28 @@ pub fn decode_base64url_vec<T: AsRef<[u8]>>(input: T) -> Option<Vec<u8>> {
}

// Cribbed from `decode_u16_items()` from libprio.
fn encode_u16_item<E: ParameterizedEncode<DapVersion>>(
bytes: &mut Vec<u8>,
fn encode_u16_prefixed(
version: DapVersion,
item: &E,
bytes: &mut Vec<u8>,
e: impl Fn(DapVersion, &mut Vec<u8>),
) {
// Reserve space for the length prefix.
let len_offset = bytes.len();
0_u16.encode(bytes);

item.encode_with_param(&version, bytes);
e(version, bytes);
let len_bytes = std::mem::size_of::<u16>();
let len = bytes.len() - len_offset - len_bytes;
bytes[len_offset..len_offset + len_bytes]
.copy_from_slice(&u16::to_be_bytes(len.try_into().unwrap()));
}

// Cribbed from `decode_u16_items()` from libprio.
fn decode_u16_item<D: ParameterizedDecode<DapVersion>>(
fn decode_u16_prefixed<O>(
version: DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<D, CodecError> {
d: impl Fn(DapVersion, &mut Cursor<&[u8]>) -> Result<O, CodecError>,
) -> Result<O, CodecError> {
// Read the length prefix.
let len = usize::from(u16::decode(bytes)?);

Expand All @@ -1271,7 +1274,13 @@ fn decode_u16_item<D: ParameterizedDecode<DapVersion>>(
.checked_add(len)
.ok_or_else(|| CodecError::LengthPrefixTooBig(len))?;

let decoded = D::get_decoded_with_param(&version, &bytes.get_ref()[item_start..item_end])?;
let mut inner = Cursor::new(&bytes.get_ref()[item_start..item_end]);
let decoded = d(version, &mut inner)?;

let num_bytes_left_over = item_end - item_start - usize::try_from(inner.position()).unwrap();
if num_bytes_left_over > 0 {
return Err(CodecError::BytesLeftOver(num_bytes_left_over));
}

// Advance outer cursor by the amount read in the inner cursor.
bytes.set_position(item_end.try_into().unwrap());
Expand All @@ -1285,7 +1294,9 @@ fn encode_u16_item_for_version<E: ParameterizedEncode<DapVersion>>(
item: &E,
) {
match version {
DapVersion::DraftLatest => encode_u16_item(bytes, version, item),
DapVersion::DraftLatest => encode_u16_prefixed(version, bytes, |version, bytes| {
item.encode_with_param(&version, bytes);
}),
DapVersion::Draft02 => item.encode_with_param(&version, bytes),
}
}
Expand All @@ -1295,7 +1306,9 @@ fn decode_u16_item_for_version<D: ParameterizedDecode<DapVersion>>(
bytes: &mut Cursor<&[u8]>,
) -> Result<D, CodecError> {
match version {
DapVersion::DraftLatest => decode_u16_item(version, bytes),
DapVersion::DraftLatest => decode_u16_prefixed(version, bytes, |version, inner| {
D::decode_with_param(&version, inner)
}),
DapVersion::Draft02 => D::decode_with_param(&version, bytes),
}
}
Expand Down

0 comments on commit b9245d1

Please sign in to comment.