diff --git a/src/lib.rs b/src/lib.rs index 2ee1be5..84f38d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ #![no_std] -cfg_if::cfg_if!{ +cfg_if::cfg_if! { if #[cfg(feature = "std")] { extern crate std; use std::io::Result; @@ -25,10 +25,9 @@ pub struct BitWriter { // number of unwritten bits in cache bits: u8, - } -impl BitWriter{ +impl BitWriter { pub fn new() -> BitWriter { BitWriter { data: Vec::new(), @@ -38,15 +37,15 @@ impl BitWriter{ } } /// Read at most 8 bits into a u8. - pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> { + pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> { self.write_unsigned_bits(v as u64, bit_count, 8) } - pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> { + pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> { self.write_unsigned_bits(v as u64, bit_count, 16) } - pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> { + pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> { self.write_unsigned_bits(v as u64, bit_count, 32) } @@ -54,15 +53,15 @@ impl BitWriter{ self.write_unsigned_bits(v, bit_count, 64) } - pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> { + pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> { self.write_signed_bits(v as i64, bit_count, 8) } - pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> { + pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> { self.write_signed_bits(v as i64, bit_count, 16) } - pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> { + pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> { self.write_signed_bits(v as i64, bit_count, 32) } @@ -94,7 +93,7 @@ impl BitWriter{ self.skip(bits_to_skip) } - pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> { + pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> { if n == 0 { return Ok(()); } @@ -114,7 +113,12 @@ impl BitWriter{ return Err(Error::new(ErrorKind::Unsupported, "too many bits to write")); } // mask all upper bits out to be 0 - v &= (1 << n) - 1; + if n == 64 { + // avoid bitshift overflow exception + v &= u64::MAX; + } else { + v &= (1 << n) - 1; + } self.bit_count += n as u64; @@ -125,7 +129,7 @@ impl BitWriter{ self.bits = new_bits; return Ok(()); } - + if new_bits >= 8 { // write all bytes, by first taking the existing buffer, form a complete byte, // and write that first. @@ -140,13 +144,13 @@ impl BitWriter{ self.data.push((v >> n) as u8); } } - + // Whatever is left is smaller than a byte, and will be put into the cache self.cache = 0; self.bits = n; if n > 0 { - let mask = ((1< &Vec { &self.data } - } diff --git a/src/tests.rs b/src/tests.rs index b08b5e5..22423d5 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -5,8 +5,12 @@ fn simple_writing() { let mut writer = BitWriter::new(); writer.write_bool(true).expect("failed to write bool"); - writer.write_u32(178956970, 28).expect("failed to write u28"); - writer.write_i32(-22369622, 28).expect("failed to write i28"); + writer + .write_u32(178956970, 28) + .expect("failed to write u28"); + writer + .write_i32(-22369622, 28) + .expect("failed to write i28"); assert_eq!(writer.bit_count, 1 + 28 + 28); writer.close().expect("failed to close byte vector"); @@ -15,3 +19,15 @@ fn simple_writing() { let expected = Vec::::from([0xD5, 0x55, 0x55, 0x57, 0x55, 0x55, 0x55, 0x00]); assert_eq!(writer.data, expected); } + +#[test] +fn test_bitshift_overflow() { + let mut writer = BitWriter::new(); + writer + .write_u64(0xFFFFFFFFFFFFFFFF, 64) + .expect("failed to u64"); + writer.write_u64(0x0, 64).expect("failed to write u64"); + writer.write_i64(0x0, 64).expect("failed to write i64"); + assert_eq!(writer.bit_count, 3 * 64); + writer.close().expect("failed to close byte vector"); +}