diff --git a/Cargo.toml b/Cargo.toml index 61a318f..c4305e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,12 +17,13 @@ features = ["serde"] rustdoc-args = ["--cfg", "docsrs"] [features] -encode = ["dep:sha2"] +encode = ["dep:sha2", "dep:varint-simd"] serde = ["encode", "dep:serde"] [dependencies] serde = { version = "1.0", optional = true } sha2 = { version = "0.10", optional = true } +varint-simd = { version = "0.4", optional = true } [dev-dependencies] bincode = "1.3" diff --git a/src/encode.rs b/src/encode.rs index ffc2dff..92eb489 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -79,77 +79,6 @@ impl core::fmt::Display for BoolDecodeError { } } -// When encoding integers we use a variable-length encoding scheme that aims to -// minimize the number of bytes used to encode small integers, where small is -// approx < 2 ^ 16. -// -// There are 3 separate branches based on how big the integer is: -// -// - if the integer is between 0 and 63, we can encode it with 6 bits. In this -// case we use the lower 6 bits of the first byte, and set the 2 highest bits -// to `01`; -// -// - if the integer is between 64 and 2 ^ 15 - 1, we can encode it with 15 -// bits. In this case we use the lower 7 bits of the first byte and all the -// bits of the second byte. The highest bit of the first byte is set to `1`; -// -// - if the integer is greater than 2 ^ 15 - 1, we use the first byte to encode -// the number of bytes used to encode the integer, and then we encode the -// integer itself in little endian, throwing away any trailing zeros. -// -// With this scheme we can encode integers up to 63 with 1 byte, integers up -// to 2 ^ 15 - 1 with 2 bytes, and integers greater than 2 ^ 15 - 1 with 3 -// or more bytes. - -const ENCODE_ONE_BYTE_MASK: u8 = 0b0100_0000; - -const ENCODE_TWO_BYTES_MASK: u8 = 0b1000_0000; - -const LAST_BIT_MASK: u8 = 0b1000_0000; - -#[inline(always)] -fn encode_one_byte(int: u8) -> u8 { - debug_assert!(int < 1 << 6); - int | ENCODE_ONE_BYTE_MASK -} - -#[inline(always)] -fn decode_one_byte(int: u8) -> u8 { - int & !ENCODE_ONE_BYTE_MASK -} - -#[inline(always)] -fn encode_two_bytes(int: u16) -> (u8, u8) { - debug_assert!((1 << 6..1 << 15).contains(&int)); - - let [mut lo, mut hi] = int.to_le_bytes(); - - // Move the last bit of the low byte to the last bit of the high byte. - // - // We know this doesn't lose any information because the int is less than - // 2 ^ 15, so the last bit of the high byte is 0. - hi |= lo & LAST_BIT_MASK; - - // Set the last bit of the low byte to 1 to indicate that this number is - // encoded with 2 bytes. - lo |= ENCODE_TWO_BYTES_MASK; - - (lo, hi) -} - -#[inline(always)] -fn decode_two_bytes(mut lo: u8, mut hi: u8) -> u16 { - lo &= !ENCODE_TWO_BYTES_MASK; - - // Move the last bit of the high byte to the last bit of the low byte. - lo |= hi & LAST_BIT_MASK; - - // Reset the last bit of the high byte to 0. - hi &= !LAST_BIT_MASK; - - u16::from_le_bytes([lo, hi]) -} - impl_int_encode!(u16); impl_int_encode!(u32); impl_int_encode!(u64); @@ -181,35 +110,8 @@ macro_rules! impl_int_encode { impl Encode for $ty { #[inline] fn encode(&self, buf: &mut Vec) { - let int = *self; - - if int < 1 << 6 { - buf.push(encode_one_byte(int as u8)); - return; - } else if int < 1 << 15 { - let (first, second) = encode_two_bytes(int as u16); - buf.push(first); - buf.push(second); - return; - } - - let array = int.to_le_bytes(); - - let num_trailing_zeros = array - .iter() - .rev() - .copied() - .take_while(|&byte| byte == 0) - .count(); - - let len = array.len() - num_trailing_zeros; - - // Make sure that the first 2 bits are 0. - debug_assert_eq!(len & 0b1100_0000, 0); - - buf.push(len as u8); - - buf.extend_from_slice(&array[..len]); + let (array, len) = varint_simd::encode(*self); + buf.extend_from_slice(&array[..len as usize]); } } }; @@ -226,45 +128,18 @@ macro_rules! impl_int_decode { #[inline] fn decode(buf: &[u8]) -> Result<($ty, &[u8]), Self::Error> { - let (&first, buf) = - buf.split_first().ok_or(IntDecodeError::EmptyBuffer)?; - - if first & ENCODE_TWO_BYTES_MASK != 0 { - let lo = first; - - let (&hi, buf) = buf.split_first().ok_or( - IntDecodeError::LengthLessThanPrefix { - prefix: 2, - actual: 1, - }, - )?; - - let int = decode_two_bytes(lo, hi) as $ty; - - return Ok((int, buf)); - } else if first & ENCODE_ONE_BYTE_MASK != 0 { - let int = decode_one_byte(first) as $ty; - return Ok((int, buf)); - } - - let len = first; - - if len as usize > buf.len() { - return Err(IntDecodeError::LengthLessThanPrefix { - prefix: len, - actual: buf.len() as u8, - }); - } - - let mut array = [0u8; ::core::mem::size_of::<$ty>()]; - - let (bytes, buf) = buf.split_at(len as usize); - - array[..bytes.len()].copy_from_slice(bytes); - - let int = <$ty>::from_le_bytes(array); - - Ok((int, buf)) + let (decoded, len) = varint_simd::decode::(buf) + .map_err(IntDecodeError)?; + + // TODO: this check shouldn't be necessary, `decode` should + // fail. Open an issue. + let Some(rest) = buf.get(len as usize..) else { + return Err(IntDecodeError( + varint_simd::VarIntDecodeError::NotEnoughBytes, + )); + }; + + Ok((decoded, rest)) } } }; @@ -273,32 +148,12 @@ macro_rules! impl_int_decode { use impl_int_decode; /// An error that can occur when decoding an [`Int`]. -#[cfg_attr(test, derive(PartialEq, Eq))] -pub(crate) enum IntDecodeError { - /// The buffer passed to `Int::decode` is empty. This is always an error, - /// even if the integer being decoded is zero. - EmptyBuffer, - - /// The actual byte length of the buffer is less than what was specified - /// in the prefix. - LengthLessThanPrefix { prefix: u8, actual: u8 }, -} +pub(crate) struct IntDecodeError(varint_simd::VarIntDecodeError); impl Display for IntDecodeError { #[inline] fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::EmptyBuffer => f.write_str( - "Int couldn't be decoded because the buffer is empty", - ), - Self::LengthLessThanPrefix { prefix, actual } => { - write!( - f, - "Int couldn't be decoded because the buffer's length is \ - {actual}, but the prefix specified a length of {prefix}", - ) - }, - } + Display::fmt(&self.0, f) } } @@ -396,8 +251,17 @@ mod serde { mod tests { use super::*; + impl PartialEq for IntDecodeError { + fn eq(&self, other: &Self) -> bool { + use varint_simd::VarIntDecodeError::*; + matches!( + (&self.0, &other.0), + (Overflow, Overflow) | (NotEnoughBytes, NotEnoughBytes) + ) + } + } + impl core::fmt::Debug for IntDecodeError { - #[inline] fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { core::fmt::Display::fmt(self, f) } @@ -418,31 +282,9 @@ mod tests { } } - /// Tests that integers are encoded using the correct number of bytes. + /// Tests the encoding-decoding roundtrip on a number of inputs. #[test] - fn encode_int_num_bytes() { - fn expected_len(int: u64) -> u8 { - if int < 1 << 6 { - 1 - } else if int < 1 << 15 { - 2 - } else if int < 1 << 16 { - 3 - } else if int < 1 << 24 { - 4 - } else if int < 1 << 32 { - 5 - } else if int < 1 << 40 { - 6 - } else if int < 1 << 48 { - 7 - } else if int < 1 << 56 { - 8 - } else { - 9 - } - } - + fn encode_int_roundtrip() { let ints = (1..=8).chain([ 0, (1 << 6) - 1, @@ -463,15 +305,9 @@ mod tests { for int in ints { int.encode(&mut buf); - - assert_eq!(buf.len(), expected_len(int) as usize); - let (decoded, rest) = u64::decode(&buf).unwrap(); - assert_eq!(int, decoded); - assert!(rest.is_empty()); - buf.clear(); } } @@ -487,7 +323,7 @@ mod tests { assert_eq!( u32::decode(&buf).unwrap_err(), - IntDecodeError::EmptyBuffer + IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes), ); } @@ -503,7 +339,7 @@ mod tests { assert_eq!( u32::decode(&buf).unwrap_err(), - IntDecodeError::LengthLessThanPrefix { prefix: 2, actual: 1 } + IntDecodeError(varint_simd::VarIntDecodeError::NotEnoughBytes), ); } } diff --git a/tests/serde.rs b/tests/serde.rs index 631488b..3102c0d 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -2,6 +2,8 @@ mod common; #[cfg(feature = "serde")] mod serde { + use std::io::{self, Write}; + use serde::de::DeserializeOwned; use serde::ser::Serialize; use traces::{ConcurrentTraceInfos, Crdt, Edit, SequentialTrace}; @@ -141,14 +143,22 @@ mod serde { } }; + let mut stdout = io::stdout(); + let replica_size = E::encode(&replica.encode()).len(); - println!("{} | Replica: {}", E::name(), printed_size(replica_size)); + let _ = writeln!( + &mut stdout, + "{} | Replica: {}", + E::name(), + printed_size(replica_size) + ); let total_insertions_size = insertions.iter().map(Vec::len).sum::(); - println!( + let _ = writeln!( + &mut stdout, "{} | Total insertions: {}", E::name(), printed_size(total_insertions_size) @@ -157,7 +167,8 @@ mod serde { let total_deletions_size = deletions.iter().map(Vec::len).sum::(); - println!( + let _ = writeln!( + &mut stdout, "{} | Total deletions: {}", E::name(), printed_size(total_deletions_size)