Skip to content

Commit

Permalink
introduce checks for maximum symbol in the FSE table decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed May 30, 2024
1 parent 0037d69 commit e70edb5
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 24 deletions.
4 changes: 4 additions & 0 deletions src/blocks/sequence_section.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/// Largest literal length code the sequence decoder accepts; it is the last code mapped by `lookup_ll_code`.
pub(crate) const MAX_LITERAL_LENGTH_CODE: u8 = 35;
/// Largest match length code the sequence decoder accepts; it is the last code mapped by `lookup_ml_code`.
pub(crate) const MAX_MATCH_LENGTH_CODE: u8 = 52;
/// Largest offset code the sequence decoder accepts; larger codes are rejected with `DecodeSequenceError::UnsupportedOffset`.
pub(crate) const MAX_OFFSET_CODE: u8 = 31;

pub struct SequencesHeader {
pub num_sequences: u32,
pub modes: Option<CompressionModes>,
Expand Down
16 changes: 10 additions & 6 deletions src/decoding/scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ use crate::fse::FSETable;
use crate::huff0::HuffmanTable;
use alloc::vec::Vec;

use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};

pub struct DecoderScratch {
pub huf: HuffmanScratch,
pub fse: FSEScratch,
Expand All @@ -23,11 +27,11 @@ impl DecoderScratch {
table: HuffmanTable::new(),
},
fse: FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
},
buffer: Decodebuffer::new(window_size),
Expand Down Expand Up @@ -98,11 +102,11 @@ pub struct FSEScratch {
impl FSEScratch {
pub fn new() -> FSEScratch {
FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
}
}
Expand Down
22 changes: 17 additions & 5 deletions src/decoding/sequence_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use super::super::blocks::sequence_section::Sequence;
use super::super::blocks::sequence_section::SequencesHeader;
use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use super::scratch::FSEScratch;
use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};
use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError};
use alloc::vec::Vec;

Expand Down Expand Up @@ -131,7 +134,7 @@ fn decode_sequences_with_rle(
//println!("ml Code: {}", ml_value);
//println!("");

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
Expand Down Expand Up @@ -206,7 +209,7 @@ fn decode_sequences_without_rle(
let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
Expand Down Expand Up @@ -273,7 +276,7 @@ fn lookup_ll_code(code: u8) -> (u32, u8) {
33 => (16384, 14),
34 => (32768, 15),
35 => (65536, 16),
_ => (0, 255),
_ => unreachable!("Illegal literal length code was: {}", code),
}
}

Expand Down Expand Up @@ -301,7 +304,7 @@ fn lookup_ml_code(code: u8) -> (u32, u8) {
50 => (16387, 14),
51 => (32771, 15),
52 => (65539, 16),
_ => (0, 255),
_ => unreachable!("Illegal match length code was: {}", code),
}
}

Expand Down Expand Up @@ -335,6 +338,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleLlTable);
}
bytes_read += 1;
// The RLE byte is stored as the literal length code for every sequence, so it must not
// exceed the largest code `lookup_ll_code` can map.
// NOTE(review): no dedicated "symbol out of range" variant is visible from here; report the
// failure against the literal-length table instead of the match-length one.
if source[0] > MAX_LITERAL_LENGTH_CODE {
    return Err(DecodeSequenceError::MissingByteForRleLlTable);
}
scratch.ll_rle = Some(source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -367,6 +373,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleOfTable);
}
bytes_read += 1;
// The RLE byte is stored as the offset code for every sequence, so it must not exceed
// the largest offset code the decoder supports.
// NOTE(review): no dedicated "symbol out of range" variant is visible from here; report the
// failure against the offset table instead of the match-length one.
if of_source[0] > MAX_OFFSET_CODE {
    return Err(DecodeSequenceError::MissingByteForRleOfTable);
}
scratch.of_rle = Some(of_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -399,6 +408,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
bytes_read += 1;
if ml_source[0] > MAX_MATCH_LENGTH_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.ml_rle = Some(ml_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -437,7 +449,7 @@ const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [

#[test]
fn test_ll_default() {
let mut table = crate::fse::FSETable::new();
let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
table
.build_from_probabilities(
LL_DEFAULT_ACC_LOG,
Expand Down
26 changes: 14 additions & 12 deletions src/fse/fse_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,14 @@ use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use alloc::vec::Vec;

/// Decoding table for FSE-coded symbols, together with the state used while building it.
pub struct FSETable {
/// Largest symbol value this table may describe. Reading or building the table fails with
/// `FSETableError::TooManySymbols` when more than `max_symbol + 1` probabilities are present.
max_symbol: u8,
pub decode: Vec<Entry>, //used to decode symbols, and calculate the next state

/// The decoding table is built with a size of `1 << accuracy_log`.
pub accuracy_log: u8,
pub symbol_probabilities: Vec<i32>, //used while building the decode Vector
symbol_counter: Vec<u32>,
}

impl Default for FSETable {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug, derive_more::Display, derive_more::From)]
#[cfg_attr(feature = "std", derive(derive_more::Error))]
#[non_exhaustive]
Expand Down Expand Up @@ -106,8 +101,9 @@ impl<'t> FSEDecoder<'t> {
}

impl FSETable {
pub fn new() -> FSETable {
pub fn new(max_symbol: u8) -> FSETable {
FSETable {
max_symbol,
symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8
symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8
decode: Vec::new(), //depending on acc_log.
Expand Down Expand Up @@ -136,7 +132,7 @@ impl FSETable {
self.accuracy_log = 0;

let bytes_read = self.read_probabilities(source, max_log)?;
self.build_decoding_table();
self.build_decoding_table()?;

Ok(bytes_read)
}
Expand All @@ -151,11 +147,15 @@ impl FSETable {
}
self.symbol_probabilities = probs.to_vec();
self.accuracy_log = acc_log;
self.build_decoding_table();
Ok(())
self.build_decoding_table()
}

fn build_decoding_table(&mut self) {
fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
return Err(FSETableError::TooManySymbols {
got: self.symbol_probabilities.len(),
});
}
self.decode.clear();

let table_size = 1 << self.accuracy_log;
Expand Down Expand Up @@ -227,6 +227,7 @@ impl FSETable {
entry.base_line = bl;
entry.num_bits = nb;
}
Ok(())
}

fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
Expand Down Expand Up @@ -299,7 +300,7 @@ impl FSETable {
symbol_probabilities: self.symbol_probabilities.clone(),
});
}
if self.symbol_probabilities.len() > 256 {
if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
return Err(FSETableError::TooManySymbols {
got: self.symbol_probabilities.len(),
});
Expand All @@ -310,6 +311,7 @@ impl FSETable {
} else {
(br.bits_read() / 8) + 1
};

Ok(bytes_read)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/huff0/huff0_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl HuffmanTable {
bits: Vec::with_capacity(256),
bit_ranks: Vec::with_capacity(11),
rank_indexes: Vec::with_capacity(11),
fse_table: FSETable::new(),
fse_table: FSETable::new(100),
}
}

Expand Down

0 comments on commit e70edb5

Please sign in to comment.