diff --git a/Cargo.toml b/Cargo.toml index ac48aac..feeee3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bcs" -version = "0.1.3" +version = "0.1.4" authors = ["Diem "] description = "Binary Canonical Serialization (BCS)" repository = "https://github.com/diem/bcs" diff --git a/src/de.rs b/src/de.rs index 8b94b56..9240cc4 100644 --- a/src/de.rs +++ b/src/de.rs @@ -369,7 +369,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.leave_named_container(); r } - + #[allow(clippy::needless_borrow)] fn deserialize_seq(mut self, visitor: V) -> Result where V: Visitor<'de>, @@ -377,14 +377,14 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { let len = self.parse_length()?; visitor.visit_seq(SeqDeserializer::new(&mut self, len)) } - + #[allow(clippy::needless_borrow)] fn deserialize_tuple(mut self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(SeqDeserializer::new(&mut self, len)) } - + #[allow(clippy::needless_borrow)] fn deserialize_tuple_struct( mut self, name: &'static str, @@ -399,7 +399,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { self.leave_named_container(); r } - + #[allow(clippy::needless_borrow)] fn deserialize_map(mut self, visitor: V) -> Result where V: Visitor<'de>, @@ -407,7 +407,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { let len = self.parse_length()?; visitor.visit_map(MapDeserializer::new(&mut self, len)) } - + #[allow(clippy::needless_borrow)] fn deserialize_struct( mut self, name: &'static str, @@ -464,7 +464,7 @@ struct SeqDeserializer<'a, 'de: 'a> { de: &'a mut Deserializer<'de>, remaining: usize, } - +#[allow(clippy::needless_borrow)] impl<'a, 'de> SeqDeserializer<'a, 'de> { fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self { Self { de, remaining } diff --git a/src/error.rs b/src/error.rs index 4a57084..3a913e4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,7 +6,7 @@ use std::fmt; use thiserror::Error; pub type Result = std::result::Result; - +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Debug, Error, PartialEq)] pub enum Error { #[error("unexpected end of input")] diff --git a/src/ser.rs b/src/ser.rs index ca26fb0..b119926 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -260,10 +260,11 @@ where fn serialize_unit_variant( mut self, - _name: &'static str, + name: &'static str, variant_index: u32, _variant: &'static str, ) -> Result<()> { + self.enter_named_container(name)?; self.output_variant_index(variant_index) } diff --git a/tests/serde.rs b/tests/serde.rs index 2e6f095..77053a9 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -570,26 +570,38 @@ fn serde_known_vector() { } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] -struct List { - next: Option<(usize, Box)>, +struct List { + value: T, + next: Option>>, } - -impl List { - fn empty() -> Self { - Self { next: None } +impl List { + fn head(value: T) -> Self { + Self { value, next: None } } - fn cons(value: usize, tail: List) -> Self { + fn cons(value: T, tail: List) -> Self { Self { - next: Some((value, Box::new(tail))), + value, + next: Some(Box::new(tail)), + } + } +} +impl List { + fn repeat(len: usize, value: T) -> Self { + if len == 0 { + Self::head(value) + } else { + Self::cons(value.clone(), Self::repeat(len - 1, value)) } } +} +impl List { fn integers(len: usize) -> Self { if len == 0 { - Self::empty() + Self::head(0) } else { - Self::cons(len - 1, Self::integers(len - 1)) + Self::cons(len, Self::integers(len - 1)) } } } @@ -601,31 +613,30 @@ fn test_recursion_limit() { assert_eq!( b1, vec![ - 1, 3, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0 + 4, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] ); - assert_eq!(from_bytes::(&b1).unwrap(), l1); + assert_eq!(from_bytes::>(&b1).unwrap(), l1); let l2 = List::integers(MAX_CONTAINER_DEPTH - 1); let b2 = to_bytes(&l2).unwrap(); - assert_eq!(from_bytes::(&b2).unwrap(), l2); - + assert_eq!(from_bytes::>(&b2).unwrap(), l2); let l3 = List::integers(MAX_CONTAINER_DEPTH); assert_eq!( to_bytes(&l3), Err(Error::ExceededContainerDepthLimit("List")) ); - let mut b3 = vec![1, 243, 1, 0, 0, 0, 0, 0, 0]; + let mut b3 = vec![244, 1, 0, 0, 0, 0, 0, 0, 1]; b3.extend(b2); assert_eq!( - from_bytes::(&b3), + from_bytes::>(&b3), Err(Error::ExceededContainerDepthLimit("List")) ); let b2_pair = to_bytes(&(&l2, &l2)).unwrap(); assert_eq!( - from_bytes::<(List, List)>(&b2_pair).unwrap(), + from_bytes::<(List<_>, List<_>)>(&b2_pair).unwrap(), (l2.clone(), l2.clone()) ); assert_eq!( @@ -641,3 +652,31 @@ fn test_recursion_limit() { Err(Error::ExceededContainerDepthLimit("List")) ); } +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, Debug)] +enum EnumA { + ValueA, +} + +#[test] +fn test_recursion_limit_enum() { + let l1 = List::repeat(6, EnumA::ValueA); + let b1 = to_bytes(&l1).unwrap(); + assert_eq!(b1, vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0],); + assert_eq!(from_bytes::>(&b1).unwrap(), l1); + + let l2 = List::repeat(MAX_CONTAINER_DEPTH - 2, EnumA::ValueA); + let b2 = to_bytes(&l2).unwrap(); + assert_eq!(from_bytes::>(&b2).unwrap(), l2); + + let l3 = List::repeat(MAX_CONTAINER_DEPTH - 1, EnumA::ValueA); + assert_eq!( + to_bytes(&l3), + Err(Error::ExceededContainerDepthLimit("EnumA")) + ); + let mut b3 = vec![0, 1]; + b3.extend(b2); + assert_eq!( + from_bytes::>(&b3), + Err(Error::ExceededContainerDepthLimit("EnumA")) + ); +}