Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bitshift overflow panic, and cargo fmt #1

Merged
merged 1 commit into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![no_std]
cfg_if::cfg_if!{
cfg_if::cfg_if! {
if #[cfg(feature = "std")] {
extern crate std;
use std::io::Result;
Expand All @@ -25,10 +25,9 @@ pub struct BitWriter {

// number of unwritten bits in cache
bits: u8,

}

impl BitWriter{
impl BitWriter {
pub fn new() -> BitWriter {
BitWriter {
data: Vec::new(),
Expand All @@ -38,31 +37,31 @@ impl BitWriter{
}
}
/// Read at most 8 bits into a u8.
pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> {
pub fn write_u8(&mut self, v: u8, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 8)
}

pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> {
pub fn write_u16(&mut self, v: u16, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 16)
}

pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> {
pub fn write_u32(&mut self, v: u32, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v as u64, bit_count, 32)
}

pub fn write_u64(&mut self, v: u64, bit_count: u8) -> Result<()> {
self.write_unsigned_bits(v, bit_count, 64)
}

pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> {
pub fn write_i8(&mut self, v: i8, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 8)
}

pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> {
pub fn write_i16(&mut self, v: i16, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 16)
}

pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> {
pub fn write_i32(&mut self, v: i32, bit_count: u8) -> Result<()> {
self.write_signed_bits(v as i64, bit_count, 32)
}

Expand Down Expand Up @@ -94,7 +93,7 @@ impl BitWriter{
self.skip(bits_to_skip)
}

pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> {
pub fn write_signed_bits(&mut self, mut v: i64, n: u8, maximum_count: u8) -> Result<()> {
if n == 0 {
return Ok(());
}
Expand All @@ -114,7 +113,12 @@ impl BitWriter{
return Err(Error::new(ErrorKind::Unsupported, "too many bits to write"));
}
// mask all upper bits out to be 0
v &= (1 << n) - 1;
if n == 64 {
// avoid bitshift overflow exception
v &= u64::MAX;
} else {
v &= (1 << n) - 1;
}

self.bit_count += n as u64;

Expand All @@ -125,7 +129,7 @@ impl BitWriter{
self.bits = new_bits;
return Ok(());
}

if new_bits >= 8 {
// write all bytes, by first taking the existing buffer, form a complete byte,
// and write that first.
Expand All @@ -140,13 +144,13 @@ impl BitWriter{
self.data.push((v >> n) as u8);
}
}

// Whatever is left is smaller than a byte, and will be put into the cache
self.cache = 0;
self.bits = n;
if n > 0 {
let mask = ((1<<n) as u8) - 1;
self.cache = ((v as u8) & mask) << (8-n);
let mask = ((1 << n) as u8) - 1;
self.cache = ((v as u8) & mask) << (8 - n);
}
Ok(())
}
Expand All @@ -165,5 +169,4 @@ impl BitWriter{
pub fn data(&self) -> &Vec<u8> {
&self.data
}

}
20 changes: 18 additions & 2 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ fn simple_writing() {
let mut writer = BitWriter::new();

writer.write_bool(true).expect("failed to write bool");
writer.write_u32(178956970, 28).expect("failed to write u28");
writer.write_i32(-22369622, 28).expect("failed to write i28");
writer
.write_u32(178956970, 28)
.expect("failed to write u28");
writer
.write_i32(-22369622, 28)
.expect("failed to write i28");
assert_eq!(writer.bit_count, 1 + 28 + 28);

writer.close().expect("failed to close byte vector");
Expand All @@ -15,3 +19,15 @@ fn simple_writing() {
let expected = Vec::<u8>::from([0xD5, 0x55, 0x55, 0x57, 0x55, 0x55, 0x55, 0x00]);
assert_eq!(writer.data, expected);
}

#[test]
fn test_bitshift_overflow() {
let mut writer = BitWriter::new();
writer
.write_u64(0xFFFFFFFFFFFFFFFF, 64)
.expect("failed to u64");
writer.write_u64(0x0, 64).expect("failed to write u64");
writer.write_i64(0x0, 64).expect("failed to write i64");
assert_eq!(writer.bit_count, 3 * 64);
writer.close().expect("failed to close byte vector");
}