diff --git a/benches/reversedbitreader_bench.rs b/benches/reversedbitreader_bench.rs index 7b8d1a42..dd8b9569 100644 --- a/benches/reversedbitreader_bench.rs +++ b/benches/reversedbitreader_bench.rs @@ -5,7 +5,7 @@ use ruzstd::decoding::bit_reader_reverse::BitReaderReversed; fn do_all_accesses(br: &mut BitReaderReversed, accesses: &[u8]) -> u64 { let mut sum = 0; for x in accesses { - sum += br.get_bits(*x).unwrap(); + sum += br.get_bits(*x); } let _ = black_box(br); sum @@ -24,7 +24,7 @@ fn criterion_benchmark(c: &mut Criterion) { let mut br = BitReaderReversed::new(&rand_vec); while br.bits_remaining() > 0 { let x = rng.gen_range(0..20); - br.get_bits(x).unwrap(); + br.get_bits(x); access_vec.push(x); } diff --git a/src/decoding/bit_reader_reverse.rs b/src/decoding/bit_reader_reverse.rs index 5bc5a2a7..13c9c75e 100644 --- a/src/decoding/bit_reader_reverse.rs +++ b/src/decoding/bit_reader_reverse.rs @@ -103,40 +103,33 @@ impl<'s> BitReaderReversed<'s> { } #[inline(always)] - pub fn get_bits(&mut self, n: u8) -> Result { + pub fn get_bits(&mut self, n: u8) -> u64 { if n == 0 { - return Ok(0); + return 0; } if self.bits_in_container >= n { - return Ok(self.get_bits_unchecked(n)); + return self.get_bits_unchecked(n); } self.get_bits_cold(n) } #[cold] - fn get_bits_cold(&mut self, n: u8) -> Result { - if n > 56 { - return Err(GetBitsError::TooManyBits { - num_requested_bits: usize::from(n), - limit: 56, - }); - } - + fn get_bits_cold(&mut self, n: u8) -> u64 { let signed_n = n as isize; if self.bits_remaining() <= 0 { self.idx -= signed_n; - return Ok(0); + return 0; } if self.bits_remaining() < signed_n { let emulated_read_shift = signed_n - self.bits_remaining(); - let v = self.get_bits(self.bits_remaining() as u8)?; + let v = self.get_bits(self.bits_remaining() as u8); debug_assert!(self.idx == 0); let value = v << emulated_read_shift; self.idx -= emulated_read_shift; - return Ok(value); + return value; } while (self.bits_in_container < n) && self.idx > 0 { @@ -147,23 +140,18 @@ impl<'s> BitReaderReversed<'s> { //if we reach this point there are enough bits in the container - Ok(self.get_bits_unchecked(n)) + self.get_bits_unchecked(n) } #[inline(always)] - pub fn get_bits_triple( - &mut self, - n1: u8, - n2: u8, - n3: u8, - ) -> Result<(u64, u64, u64), GetBitsError> { + pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) { let sum = n1 as usize + n2 as usize + n3 as usize; if sum == 0 { - return Ok((0, 0, 0)); + return (0, 0, 0); } if sum > 56 { // try and get the values separatly - return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?)); + return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)); } let sum = sum as u8; @@ -184,29 +172,23 @@ impl<'s> BitReaderReversed<'s> { self.get_bits_unchecked(n3) }; - return Ok((v1, v2, v3)); + return (v1, v2, v3); } self.get_bits_triple_cold(n1, n2, n3, sum) } #[cold] - fn get_bits_triple_cold( - &mut self, - n1: u8, - n2: u8, - n3: u8, - sum: u8, - ) -> Result<(u64, u64, u64), GetBitsError> { + fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) { let sum_signed = sum as isize; if self.bits_remaining() <= 0 { self.idx -= sum_signed; - return Ok((0, 0, 0)); + return (0, 0, 0); } if self.bits_remaining() < sum_signed { - return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?)); + return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)); } while (self.bits_in_container < sum) && self.idx > 0 { @@ -233,7 +215,7 @@ impl<'s> BitReaderReversed<'s> { self.get_bits_unchecked(n3) }; - Ok((v1, v2, v3)) + (v1, v2, v3) } #[inline(always)] diff --git a/src/decoding/literals_section_decoder.rs b/src/decoding/literals_section_decoder.rs index 50e0c941..f437a58f 100644 --- a/src/decoding/literals_section_decoder.rs +++ b/src/decoding/literals_section_decoder.rs @@ -126,7 +126,7 @@ fn decompress_literals( //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found let mut skipped_bits = 0; loop { - let val = br.get_bits(1)?; + let val = br.get_bits(1); skipped_bits += 1; if val == 1 || skipped_bits > 8 { break; @@ -136,11 +136,11 @@ fn decompress_literals( //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); } - decoder.init_state(&mut br)?; + decoder.init_state(&mut br); while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { target.push(decoder.decode_symbol()); - decoder.next_state(&mut br)?; + decoder.next_state(&mut br); } if br.bits_remaining() != -(scratch.table.max_num_bits as isize) { return Err(DecompressLiteralsError::BitstreamReadMismatch { @@ -158,7 +158,7 @@ fn decompress_literals( let mut br = BitReaderReversed::new(source); let mut skipped_bits = 0; loop { - let val = br.get_bits(1)?; + let val = br.get_bits(1); skipped_bits += 1; if val == 1 || skipped_bits > 8 { break; @@ -168,10 +168,10 @@ fn decompress_literals( //if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data return Err(DecompressLiteralsError::ExtraPadding { skipped_bits }); } - decoder.init_state(&mut br)?; + decoder.init_state(&mut br); while br.bits_remaining() > -(scratch.table.max_num_bits as isize) { target.push(decoder.decode_symbol()); - decoder.next_state(&mut br)?; + decoder.next_state(&mut br); } bytes_read += source.len() as u32; } diff --git a/src/decoding/sequence_section_decoder.rs b/src/decoding/sequence_section_decoder.rs index e95e9e60..7806b4a8 100644 --- a/src/decoding/sequence_section_decoder.rs +++ b/src/decoding/sequence_section_decoder.rs @@ -58,7 +58,7 @@ pub fn decode_sequences( //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found let mut skipped_bits = 0; loop { - let val = br.get_bits(1)?; + let val = br.get_bits(1); skipped_bits += 1; if val == 1 || skipped_bits > 8 { break; @@ -137,7 +137,7 @@ fn decode_sequences_with_rle( }); } - let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?; + let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits); let offset = obits as u32 + (1u32 << of_code); if offset == 0 { @@ -157,13 +157,13 @@ fn decode_sequences_with_rle( // br.bits_remaining() / 8, //); if scratch.ll_rle.is_none() { - ll_dec.update_state(br)?; + ll_dec.update_state(br); } if scratch.ml_rle.is_none() { - ml_dec.update_state(br)?; + ml_dec.update_state(br); } if scratch.of_rle.is_none() { - of_dec.update_state(br)?; + of_dec.update_state(br); } } @@ -212,7 +212,7 @@ fn decode_sequences_without_rle( }); } - let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?; + let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits); let offset = obits as u32 + (1u32 << of_code); if offset == 0 { @@ -231,9 +231,9 @@ fn decode_sequences_without_rle( // br.bits_remaining(), // br.bits_remaining() / 8, //); - ll_dec.update_state(br)?; - ml_dec.update_state(br)?; - of_dec.update_state(br)?; + ll_dec.update_state(br); + ml_dec.update_state(br); + of_dec.update_state(br); } if br.bits_remaining() < 0 { diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index 969768f1..ede22fc2 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -89,23 +89,19 @@ impl<'t> FSEDecoder<'t> { if self.table.accuracy_log == 0 { return Err(FSEDecoderError::TableIsUninitialized); } - self.state = self.table.decode[bits.get_bits(self.table.accuracy_log)? as usize]; + self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize]; Ok(()) } - pub fn update_state( - &mut self, - bits: &mut BitReaderReversed<'_>, - ) -> Result<(), FSEDecoderError> { + pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; - let add = bits.get_bits(num_bits)?; + let add = bits.get_bits(num_bits); let base_line = self.state.base_line; let new_state = base_line + add as u32; self.state = self.table.decode[new_state as usize]; //println!("Update: {}, {} -> {}", base_line, add, self.state); - Ok(()) } } diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index cc12476a..d918d72b 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -109,26 +109,20 @@ impl<'t> HuffmanDecoder<'t> { self.table.decode[self.state as usize].symbol } - pub fn init_state( - &mut self, - br: &mut BitReaderReversed<'_>, - ) -> Result { + pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { let num_bits = self.table.max_num_bits; - let new_bits = br.get_bits(num_bits)?; + let new_bits = br.get_bits(num_bits); self.state = new_bits; - Ok(num_bits) + num_bits } - pub fn next_state( - &mut self, - br: &mut BitReaderReversed<'_>, - ) -> Result { + pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { let num_bits = self.table.decode[self.state as usize].num_bits; - let new_bits = br.get_bits(num_bits)?; + let new_bits = br.get_bits(num_bits); self.state <<= num_bits; self.state &= self.table.decode.len() as u64 - 1; self.state |= new_bits; - Ok(num_bits) + num_bits } } @@ -235,7 +229,7 @@ impl HuffmanTable { //skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found let mut skipped_bits = 0; loop { - let val = br.get_bits(1)?; + let val = br.get_bits(1); skipped_bits += 1; if val == 1 || skipped_bits > 8 { break; @@ -254,7 +248,7 @@ impl HuffmanTable { loop { let w = dec1.decode_symbol(); self.weights.push(w); - dec1.update_state(&mut br)?; + dec1.update_state(&mut br); if br.bits_remaining() <= -1 { //collect final states @@ -264,7 +258,7 @@ impl HuffmanTable { let w = dec2.decode_symbol(); self.weights.push(w); - dec2.update_state(&mut br)?; + dec2.update_state(&mut br); if br.bits_remaining() <= -1 { //collect final states diff --git a/src/tests/bit_reader.rs b/src/tests/bit_reader.rs index 84f5ad5f..06097ee0 100644 --- a/src/tests/bit_reader.rs +++ b/src/tests/bit_reader.rs @@ -22,7 +22,7 @@ fn test_bitreader_reversed() { num_bits = 128 - bits_read; } - let bits = br.get_bits(num_bits).unwrap(); + let bits = br.get_bits(num_bits); bits_read += num_bits; accumulator |= u128::from(bits) << (128 - bits_read); if bits_read >= 128 {