diff --git a/CHANGELOG.md b/CHANGELOG.md index e0790f2..095dd0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# [Unreleased] + +- Add versionize proc macro support for HashMap, HashSet and VecDeque. + # v0.1.9 - Implement Versionize for i128 and u128 diff --git a/src/lib.rs b/src/lib.rs index e9460da..d35b934 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,7 @@ //! `Versionize` trait is implemented for the following primitives: //! u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, char, f32, f64, //! String, Vec, Arrays up to 32 elements, Box, Wrapping, Option, -//! FamStructWrapper, and (T, U). +//! FamStructWrapper, VecDeque, HashMap, HashSet and (T, U). //! //! Known issues and limitations: //! - Union serialization is not supported via the `Versionize` proc macro. @@ -59,6 +59,10 @@ pub enum VersionizeError { StringLength(usize), /// Vector length exceeded. VecLength(usize), + /// HashMap length exceeded. + HashMapLength(usize), + /// HashSet length exceeded. + HashSetLength(usize), } impl std::fmt::Display for VersionizeError { @@ -82,6 +86,18 @@ impl std::fmt::Display for VersionizeError { bad_len, primitives::MAX_VEC_SIZE ), + HashMapLength(bad_len) => write!( + f, + "HashMap of length exceeded {} > {} bytes", + bad_len, + primitives::MAX_HASH_MAP_LEN + ), + HashSetLength(bad_len) => write!( + f, + "HashSet of length exceeded {} > {} bytes", + bad_len, + primitives::MAX_HASH_SET_LEN + ), } } } diff --git a/src/primitives.rs b/src/primitives.rs index 3d75167..b8f3c23 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -3,6 +3,9 @@ //! Serialization support for primitive data types. #![allow(clippy::float_cmp)] +use std::collections::{HashMap, HashSet, VecDeque}; +use std::hash::Hash; + use self::super::{VersionMap, Versionize, VersionizeError, VersionizeResult}; use vmm_sys_util::fam::{FamStruct, FamStructWrapper}; @@ -12,6 +15,10 @@ pub const MAX_STRING_LEN: usize = 16384; /// Maximum allowed vec size in bytes (10MB). /// Calling `serialize()` or `deserialiaze()` will fail beyond this limit. pub const MAX_VEC_SIZE: usize = 10_485_760; +/// Maximum hashmap len in bytes (20MB). +pub const MAX_HASH_MAP_LEN: usize = 20_971_520; +/// Maximum hashset len in bytes (10MB). +pub const MAX_HASH_SET_LEN: usize = 10_485_760; /// A macro that implements the Versionize trait for primitive types using the /// serde bincode backed. @@ -281,9 +288,75 @@ where } } -impl Versionize for Vec +macro_rules! impl_versionize_vec_like_type { + ($VecType:ident <$GenericParam:ident>) => { + impl<$GenericParam> Versionize for $VecType<$GenericParam> + where + $GenericParam: Versionize, + { + #[inline] + fn serialize( + &self, + mut writer: &mut W, + version_map: &VersionMap, + app_version: u16, + ) -> VersionizeResult<()> { + if self.len() > MAX_VEC_SIZE / std::mem::size_of::<$GenericParam>() { + return Err(VersionizeError::VecLength(self.len())); + } + + // Serialize in the same fashion as bincode: + // Write len. + bincode::serialize_into(&mut writer, &self.len()) + .map_err(|ref err| VersionizeError::Serialize(format!("{:?}", err)))?; + // Walk the vec and write each element. + for element in self { + element.serialize(writer, version_map, app_version)?; + } + Ok(()) + } + + #[inline] + fn deserialize( + mut reader: &mut R, + version_map: &VersionMap, + app_version: u16, + ) -> VersionizeResult { + let len: usize = bincode::deserialize_from(&mut reader) + .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; + + if len > MAX_VEC_SIZE / std::mem::size_of::<$GenericParam>() { + return Err(VersionizeError::VecLength(len)); + } + + let mut v = Vec::with_capacity(len); + + for _ in 0..len { + let element: $GenericParam = + $GenericParam::deserialize(reader, version_map, app_version).map_err( + |ref err| VersionizeError::Deserialize(format!("{:?}", err)), + )?; + v.push(element); + } + Ok(v.into()) + } + + // Not used yet. + fn version() -> u16 { + 1 + } + } + }; +} + +impl_versionize_vec_like_type!(Vec); +impl_versionize_vec_like_type!(VecDeque); + +// Implement versioning for FAM structures by using the FamStructWrapper interface. +impl Versionize for FamStructWrapper where - T: Versionize, + ::Entry: Versionize, + T: std::fmt::Debug, { #[inline] fn serialize( @@ -292,16 +365,67 @@ where version_map: &VersionMap, app_version: u16, ) -> VersionizeResult<()> { - if self.len() > MAX_VEC_SIZE / std::mem::size_of::() { - return Err(VersionizeError::VecLength(self.len())); + // Write the fixed size header. + self.as_fam_struct_ref() + .serialize(&mut writer, version_map, app_version)?; + // Write the array. + self.as_slice() + .to_vec() + .serialize(&mut writer, version_map, app_version)?; + + Ok(()) + } + + #[inline] + fn deserialize( + reader: &mut R, + version_map: &VersionMap, + app_version: u16, + ) -> VersionizeResult { + let header = T::deserialize(reader, version_map, app_version) + .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; + let entries: Vec<::Entry> = + Vec::deserialize(reader, version_map, app_version) + .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; + // Construct the object from the array items. + // Header(T) fields will be initialized by Default trait impl. + let mut object = FamStructWrapper::from_entries(&entries) + .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; + // Update Default T with the deserialized header. + *object.as_mut_fam_struct() = header; + Ok(object) + } + + // Not used. + fn version() -> u16 { + 1 + } +} + +impl Versionize for HashMap +where + K: Versionize + Eq + Hash + Clone + std::fmt::Debug, + V: Versionize + Clone + std::fmt::Debug, +{ + #[inline] + fn serialize( + &self, + mut writer: &mut W, + version_map: &VersionMap, + app_version: u16, + ) -> VersionizeResult<()> { + let bytes_len = self.len() * (std::mem::size_of::() + std::mem::size_of::()); + if bytes_len > MAX_HASH_MAP_LEN { + return Err(VersionizeError::HashMapLength(bytes_len)); } - // Serialize in the same fashion as bincode: - // Write len. + + // Write len bincode::serialize_into(&mut writer, &self.len()) .map_err(|ref err| VersionizeError::Serialize(format!("{:?}", err)))?; - // Walk the vec and write each element. - for element in self { - element.serialize(writer, version_map, app_version)?; + // Walk the hash map and write each element. + for (k, v) in self.iter() { + k.serialize(writer, version_map, app_version)?; + v.serialize(writer, version_map, app_version)?; } Ok(()) } @@ -312,20 +436,24 @@ where version_map: &VersionMap, app_version: u16, ) -> VersionizeResult { - let mut v = Vec::new(); let len: usize = bincode::deserialize_from(&mut reader) .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; - if len > MAX_VEC_SIZE / std::mem::size_of::() { - return Err(VersionizeError::VecLength(len)); + let bytes_len = len * (std::mem::size_of::() + std::mem::size_of::()); + if bytes_len > MAX_HASH_MAP_LEN { + return Err(VersionizeError::HashMapLength(bytes_len)); } + let mut map = HashMap::with_capacity(len); + for _ in 0..len { - let element: T = T::deserialize(reader, version_map, app_version) + let k = K::deserialize(reader, version_map, app_version) .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; - v.push(element); + let v = V::deserialize(reader, version_map, app_version) + .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; + map.insert(k, v); } - Ok(v) + Ok(map) } // Not used yet. @@ -334,11 +462,9 @@ where } } -// Implement versioning for FAM structures by using the FamStructWrapper interface. -impl Versionize for FamStructWrapper +impl Versionize for HashSet where - ::Entry: Versionize, - T: std::fmt::Debug, + T: Versionize + Hash + Eq, { #[inline] fn serialize( @@ -347,38 +473,46 @@ where version_map: &VersionMap, app_version: u16, ) -> VersionizeResult<()> { - // Write the fixed size header. - self.as_fam_struct_ref() - .serialize(&mut writer, version_map, app_version)?; - // Write the array. - self.as_slice() - .to_vec() - .serialize(&mut writer, version_map, app_version)?; - + let bytes_len = self.len() * std::mem::size_of::(); + if bytes_len > MAX_HASH_SET_LEN { + return Err(VersionizeError::HashSetLength(bytes_len)); + } + // Serialize in the same fashion as bincode: + // Write len. + bincode::serialize_into(&mut writer, &self.len()) + .map_err(|ref err| VersionizeError::Serialize(format!("{:?}", err)))?; + // Walk the vec and write each element. + for element in self.iter() { + element.serialize(writer, version_map, app_version)?; + } Ok(()) } #[inline] fn deserialize( - reader: &mut R, + mut reader: &mut R, version_map: &VersionMap, app_version: u16, ) -> VersionizeResult { - let header = T::deserialize(reader, version_map, app_version) + let len: usize = bincode::deserialize_from(&mut reader) .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; - let entries: Vec<::Entry> = - Vec::deserialize(reader, version_map, app_version) + + let bytes_len = len * std::mem::size_of::(); + if bytes_len > MAX_HASH_SET_LEN { + return Err(VersionizeError::HashSetLength(bytes_len)); + } + + let mut set = HashSet::with_capacity(len); + + for _ in 0..len { + let element: T = T::deserialize(reader, version_map, app_version) .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; - // Construct the object from the array items. - // Header(T) fields will be initialized by Default trait impl. - let mut object = FamStructWrapper::from_entries(&entries) - .map_err(|ref err| VersionizeError::Deserialize(format!("{:?}", err)))?; - // Update Default T with the deserialized header. - *object.as_mut_fam_struct() = header; - Ok(object) + set.insert(element); + } + Ok(set) } - // Not used. + // Not used yet. fn version() -> u16 { 1 } @@ -818,4 +952,254 @@ mod tests { "String length exceeded 16385 > 16384 bytes" ); } + + #[test] + fn test_ser_de_vec_deque() { + let vm = VersionMap::new(); + let mut snapshot_mem = vec![0u8; 64]; + + let mut store = VecDeque::new(); + store.push_back("test 1".to_owned()); + store.push_back("test 2".to_owned()); + store.push_back("test 3".to_owned()); + + store + .serialize(&mut snapshot_mem.as_mut_slice(), &vm, 1) + .unwrap(); + let restore = + as Versionize>::deserialize(&mut snapshot_mem.as_slice(), &vm, 1).unwrap(); + + assert_eq!(store, restore); + } + + #[test] + fn test_corrupted_vec_deque_len() { + let vm = VersionMap::new(); + let mut buffer = vec![0u8; 1024]; + + let string = String::from("Test string1"); + let vec_deque = VecDeque::from(string.into_bytes()); + + vec_deque + .serialize(&mut buffer.as_mut_slice(), &vm, 1) + .unwrap(); + + // Test corrupt length field. + assert_eq!( + as Versionize>::deserialize( + &mut buffer.as_slice().split_first().unwrap().1, + &vm, + 1 + ) + .unwrap_err(), + VersionizeError::VecLength(6052837899185946624) + ); + + // Test incomplete Vec. + assert_eq!( + as Versionize>::deserialize(&mut buffer.as_slice().split_at(6).0, &vm, 1) + .unwrap_err(), + VersionizeError::Deserialize( + "Io(Error { kind: UnexpectedEof, message: \"failed to fill whole buffer\" })" + .to_owned() + ) + ); + + // Test NULL Vec len. + buffer[0] = 0; + assert_eq!( + as Versionize>::deserialize(&mut buffer.as_slice(), &vm, 1).unwrap(), + VecDeque::new() + ); + } + + #[test] + fn test_vec_deque_limit() { + // We need extra 8 bytes for vector len. + let mut snapshot_mem = vec![0u8; MAX_VEC_SIZE + 8]; + let err = VecDeque::from(vec![123u8; MAX_VEC_SIZE + 1]) + .serialize(&mut snapshot_mem.as_mut_slice(), &VersionMap::new(), 1) + .unwrap_err(); + assert_eq!(err, VersionizeError::VecLength(MAX_VEC_SIZE + 1)); + assert_eq!( + format!("{}", err), + "Vec of length 10485761 exceeded maximum size of 10485760 bytes" + ); + } + + #[test] + fn test_ser_de_hash_map() { + let vm = VersionMap::new(); + let mut snapshot_mem = vec![0u8; 128]; + + let mut store = HashMap::new(); + store.insert(1, "test 1".to_owned()); + store.insert(2, "test 2".to_owned()); + store.insert(3, "test 3".to_owned()); + + store + .serialize(&mut snapshot_mem.as_mut_slice(), &vm, 1) + .unwrap(); + let restore = as Versionize>::deserialize( + &mut snapshot_mem.as_slice(), + &vm, + 1, + ) + .unwrap(); + + assert_eq!(store, restore); + } + + #[test] + fn test_corrupted_hash_map_len() { + let vm = VersionMap::new(); + let mut buffer = vec![0u8; 1024]; + + let mut hash_map: HashMap = HashMap::new(); + hash_map.insert(1, 'a'); + hash_map.insert(2, 'b'); + hash_map.insert(3, 'c'); + + hash_map + .serialize(&mut buffer.as_mut_slice(), &vm, 1) + .unwrap(); + + // Test corrupt length field. + // + // Because of the order of hash_map may different, the error length may + // also be different + matches!( + as Versionize>::deserialize( + &mut buffer.as_slice().split_first().unwrap().1, + &vm, + 1 + ) + .unwrap_err(), + VersionizeError::HashMapLength(..) + ); + + // Test incomplete HashMap. + assert_eq!( + as Versionize>::deserialize( + &mut buffer.as_slice().split_at(6).0, + &vm, + 1 + ) + .unwrap_err(), + VersionizeError::Deserialize( + "Io(Error { kind: UnexpectedEof, message: \"failed to fill whole buffer\" })" + .to_owned() + ) + ); + + // Test NULL HashMap len. + buffer[0] = 0; + assert_eq!( + as Versionize>::deserialize(&mut buffer.as_slice(), &vm, 1) + .unwrap(), + HashMap::new() + ); + } + + #[test] + fn test_hash_map_limit() { + // We need extra 8 bytes for HashMap's len. + let mut snapshot_mem = vec![0u8; MAX_HASH_MAP_LEN / 16 + 1]; + let mut err = HashMap::with_capacity(MAX_HASH_MAP_LEN / 16 + 1); + for i in 0..(MAX_HASH_MAP_LEN / 16 + 1) { + err.insert(i, i); + } + let err = err + .serialize(&mut snapshot_mem.as_mut_slice(), &VersionMap::new(), 1) + .unwrap_err(); + assert_eq!(err, VersionizeError::HashMapLength(MAX_HASH_MAP_LEN + 16)); + assert_eq!( + format!("{}", err), + "HashMap of length exceeded 20971536 > 20971520 bytes" + ); + } + + #[test] + fn test_ser_de_hash_set() { + let vm = VersionMap::new(); + let mut snapshot_mem = vec![0u8; 64]; + + let mut store = HashSet::new(); + store.insert("test 1".to_owned()); + store.insert("test 2".to_owned()); + store.insert("test 3".to_owned()); + + store + .serialize(&mut snapshot_mem.as_mut_slice(), &vm, 1) + .unwrap(); + let restore = + as Versionize>::deserialize(&mut snapshot_mem.as_slice(), &vm, 1) + .unwrap(); + + assert_eq!(store, restore); + } + + #[test] + fn test_corrupted_hash_set_len() { + let vm = VersionMap::new(); + let mut buffer = vec![0u8; 1024]; + + let mut hash_set: HashSet = HashSet::new(); + hash_set.insert(1); + hash_set.insert(2); + hash_set.insert(3); + + hash_set + .serialize(&mut buffer.as_mut_slice(), &vm, 1) + .unwrap(); + + // Test corrupt length field. + // + // Because of the order of hash_set may different, the error length may + // also be different + matches!( + as Versionize>::deserialize( + &mut buffer.as_slice().split_first().unwrap().1, + &vm, + 1 + ) + .unwrap_err(), + VersionizeError::HashSetLength(..) + ); + + // Test incomplete HashSet. + assert_eq!( + as Versionize>::deserialize(&mut buffer.as_slice().split_at(6).0, &vm, 1) + .unwrap_err(), + VersionizeError::Deserialize( + "Io(Error { kind: UnexpectedEof, message: \"failed to fill whole buffer\" })" + .to_owned() + ) + ); + + // Test NULL HashSet len. + buffer[0] = 0; + assert_eq!( + as Versionize>::deserialize(&mut buffer.as_slice(), &vm, 1).unwrap(), + HashSet::new() + ); + } + + #[test] + fn test_hash_set_limit() { + // We need extra 8 bytes for HashSet's len. + let mut snapshot_mem = vec![0u8; MAX_HASH_SET_LEN / 8 + 1]; + let mut err = HashSet::with_capacity(MAX_HASH_SET_LEN / 8 + 1); + for i in 0..(MAX_HASH_SET_LEN / 8 + 1) { + err.insert(i); + } + let err = err + .serialize(&mut snapshot_mem.as_mut_slice(), &VersionMap::new(), 1) + .unwrap_err(); + assert_eq!(err, VersionizeError::HashSetLength(MAX_HASH_SET_LEN + 8)); + assert_eq!( + format!("{}", err), + "HashSet of length exceeded 10485768 > 10485760 bytes" + ); + } }