Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't return errors on too large requests on a reversed bitreader #58

Merged
merged 6 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ This document records the changes made between versions, starting with version 0
* The FrameDecoder is now Send + Sync (RingBuffer impls these traits now)

# After 0.6.0
* Small fix in the zstd binary, progress tracking was slighty off for skippable frames resulting in an error only when the last frame in a file was skippable
* Small fix in the zstd binary, progress tracking was slighty off for skippable frames resulting in an error only when the last frame in a file was skippable
* Small performance improvement by reorganizing code with `#[cold]` annotations
* Documentation for `StreamDecoder` mentioning the limitations around multiple frames (https://github.com/Sorseg)
* Documentation around skippable frames (https://github.com/Sorseg)
4 changes: 2 additions & 2 deletions benches/reversedbitreader_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ruzstd::decoding::bit_reader_reverse::BitReaderReversed;
fn do_all_accesses(br: &mut BitReaderReversed, accesses: &[u8]) -> u64 {
let mut sum = 0;
for x in accesses {
sum += br.get_bits(*x).unwrap();
sum += br.get_bits(*x);
}
let _ = black_box(br);
sum
Expand All @@ -24,7 +24,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut br = BitReaderReversed::new(&rand_vec);
while br.bits_remaining() > 0 {
let x = rng.gen_range(0..20);
br.get_bits(x).unwrap();
br.get_bits(x);
access_vec.push(x);
}

Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extern crate ruzstd;
use ruzstd::frame_decoder;

fuzz_target!(|data: &[u8]| {
let mut content = data.clone();
let mut content = data;
let mut frame_dec = frame_decoder::FrameDecoder::new();

match frame_dec.reset(&mut content){
Expand Down
4 changes: 4 additions & 0 deletions src/blocks/sequence_section.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Utilities and representations for the second half of a block, the sequence section.
//! This section copies literals from the literals section into the decompressed output.

pub(crate) const MAX_LITERAL_LENGTH_CODE: u8 = 35;
pub(crate) const MAX_MATCH_LENGTH_CODE: u8 = 52;
pub(crate) const MAX_OFFSET_CODE: u8 = 31;

pub struct SequencesHeader {
pub num_sequences: u32,
pub modes: Option<CompressionModes>,
Expand Down
53 changes: 18 additions & 35 deletions src/decoding/bit_reader_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,40 +111,34 @@ impl<'s> BitReaderReversed<'s> {
/// Read `n` number of bits from the source. Returns an error if the reader
/// requests more bits than remain for reading.
#[inline(always)]
pub fn get_bits(&mut self, n: u8) -> Result<u64, GetBitsError> {
pub fn get_bits(&mut self, n: u8) -> u64 {
if n == 0 {
return Ok(0);
return 0;
}
if self.bits_in_container >= n {
return Ok(self.get_bits_unchecked(n));
return self.get_bits_unchecked(n);
}

self.get_bits_cold(n)
}

#[cold]
fn get_bits_cold(&mut self, n: u8) -> Result<u64, GetBitsError> {
if n > 56 {
return Err(GetBitsError::TooManyBits {
num_requested_bits: usize::from(n),
limit: 56,
});
}

fn get_bits_cold(&mut self, n: u8) -> u64 {
let n = u8::min(n, 56);
let signed_n = n as isize;

if self.bits_remaining() <= 0 {
self.idx -= signed_n;
return Ok(0);
return 0;
}

if self.bits_remaining() < signed_n {
let emulated_read_shift = signed_n - self.bits_remaining();
let v = self.get_bits(self.bits_remaining() as u8)?;
let v = self.get_bits(self.bits_remaining() as u8);
debug_assert!(self.idx == 0);
let value = v << emulated_read_shift;
let value = v.wrapping_shl(emulated_read_shift as u32);
self.idx -= emulated_read_shift;
return Ok(value);
return value;
}

while (self.bits_in_container < n) && self.idx > 0 {
Expand All @@ -155,23 +149,18 @@ impl<'s> BitReaderReversed<'s> {

//if we reach this point there are enough bits in the container

Ok(self.get_bits_unchecked(n))
self.get_bits_unchecked(n)
}

#[inline(always)]
pub fn get_bits_triple(
&mut self,
n1: u8,
n2: u8,
n3: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
let sum = n1 as usize + n2 as usize + n3 as usize;
if sum == 0 {
return Ok((0, 0, 0));
return (0, 0, 0);
}
if sum > 56 {
// try and get the values separately
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}
let sum = sum as u8;

Expand All @@ -192,29 +181,23 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

return Ok((v1, v2, v3));
return (v1, v2, v3);
}

self.get_bits_triple_cold(n1, n2, n3, sum)
}

#[cold]
fn get_bits_triple_cold(
&mut self,
n1: u8,
n2: u8,
n3: u8,
sum: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
let sum_signed = sum as isize;

if self.bits_remaining() <= 0 {
self.idx -= sum_signed;
return Ok((0, 0, 0));
return (0, 0, 0);
}

if self.bits_remaining() < sum_signed {
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}

while (self.bits_in_container < sum) && self.idx > 0 {
Expand All @@ -241,7 +224,7 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

Ok((v1, v2, v3))
(v1, v2, v3)
}

#[inline(always)]
Expand Down
12 changes: 6 additions & 6 deletions src/decoding/literals_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ fn decompress_literals(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -208,11 +208,11 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);

while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
return Err(DecompressLiteralsError::BitstreamReadMismatch {
Expand All @@ -230,7 +230,7 @@ fn decompress_literals(
let mut br = BitReaderReversed::new(source);
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -240,10 +240,10 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);
while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
bytes_read += source.len() as u32;
}
Expand Down
16 changes: 10 additions & 6 deletions src/decoding/scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use crate::fse::FSETable;
use crate::huff0::HuffmanTable;
use alloc::vec::Vec;

use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};

/// A block level decoding buffer.
pub struct DecoderScratch {
/// The decoder used for Huffman blocks.
Expand All @@ -29,11 +33,11 @@ impl DecoderScratch {
table: HuffmanTable::new(),
},
fse: FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
},
buffer: DecodeBuffer::new(window_size),
Expand Down Expand Up @@ -104,11 +108,11 @@ pub struct FSEScratch {
impl FSEScratch {
pub fn new() -> FSEScratch {
FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
}
}
Expand Down
40 changes: 26 additions & 14 deletions src/decoding/sequence_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use super::super::blocks::sequence_section::Sequence;
use super::super::blocks::sequence_section::SequencesHeader;
use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use super::scratch::FSEScratch;
use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};
use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError};
use alloc::vec::Vec;

Expand Down Expand Up @@ -116,7 +119,7 @@ pub fn decode_sequences(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand Down Expand Up @@ -189,13 +192,13 @@ fn decode_sequences_with_rle(
//println!("ml Code: {}", ml_value);
//println!("");

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -215,13 +218,13 @@ fn decode_sequences_with_rle(
// br.bits_remaining() / 8,
//);
if scratch.ll_rle.is_none() {
ll_dec.update_state(br)?;
ll_dec.update_state(br);
}
if scratch.ml_rle.is_none() {
ml_dec.update_state(br)?;
ml_dec.update_state(br);
}
if scratch.of_rle.is_none() {
of_dec.update_state(br)?;
of_dec.update_state(br);
}
}

Expand Down Expand Up @@ -264,13 +267,13 @@ fn decode_sequences_without_rle(
let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -289,9 +292,9 @@ fn decode_sequences_without_rle(
// br.bits_remaining(),
// br.bits_remaining() / 8,
//);
ll_dec.update_state(br)?;
ml_dec.update_state(br)?;
of_dec.update_state(br)?;
ll_dec.update_state(br);
ml_dec.update_state(br);
of_dec.update_state(br);
}

if br.bits_remaining() < 0 {
Expand Down Expand Up @@ -335,7 +338,7 @@ fn lookup_ll_code(code: u8) -> (u32, u8) {
33 => (16384, 14),
34 => (32768, 15),
35 => (65536, 16),
_ => (0, 255),
_ => unreachable!("Illegal literal length code was: {}", code),
}
}

Expand Down Expand Up @@ -367,7 +370,7 @@ fn lookup_ml_code(code: u8) -> (u32, u8) {
50 => (16387, 14),
51 => (32771, 15),
52 => (65539, 16),
_ => (0, 255),
_ => unreachable!("Illegal match length code was: {}", code),
}
}

Expand Down Expand Up @@ -405,6 +408,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleLlTable);
}
bytes_read += 1;
if source[0] > MAX_LITERAL_LENGTH_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.ll_rle = Some(source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -437,6 +443,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleOfTable);
}
bytes_read += 1;
if of_source[0] > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.of_rle = Some(of_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -469,6 +478,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
bytes_read += 1;
if ml_source[0] > MAX_MATCH_LENGTH_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.ml_rle = Some(ml_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -522,7 +534,7 @@ const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [

#[test]
fn test_ll_default() {
let mut table = crate::fse::FSETable::new();
let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
table
.build_from_probabilities(
LL_DEFAULT_ACC_LOG,
Expand Down
Loading