diff --git a/Cargo.toml b/Cargo.toml index 334a2705..b9600585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ categories = ["compression"] [dependencies] byteorder = { version = "1.5", default-features = false } twox-hash = { version = "1.6", default-features = false, optional = true } -derive_more = { version = "0.99", default-features = false, features = ["display", "from"] } [dev-dependencies] criterion = "0.5" @@ -24,7 +23,7 @@ rand = { version = "0.8.5", features = ["small_rng"] } [features] default = ["hash", "std"] hash = ["dep:twox-hash"] -std = ["derive_more/error"] +std = [] [[bench]] name = "reversedbitreader_bench" diff --git a/src/blocks/block.rs b/src/blocks/block.rs index 078eb44e..c8e63b74 100644 --- a/src/blocks/block.rs +++ b/src/blocks/block.rs @@ -1,8 +1,16 @@ +//! Block header definitions. + +/// There are 4 different kinds of blocks, and the type of block influences the meaning of `Block_Size`. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BlockType { + /// An uncompressed block. Raw, + /// A single byte, repeated `Block_Size` times (Run Length Encoding). RLE, + /// A Zstandard compressed block. `Block_Size` is the length of the compressed data. Compressed, + /// This is not a valid block, and this value should not be used. + /// If this value is present, it should be considered corrupted data. Reserved, } @@ -17,9 +25,18 @@ impl core::fmt::Display for BlockType { } } +/// A representation of a single block header. As well as containing a frame header, +/// each Zstandard frame contains one or more blocks. pub struct BlockHeader { + /// Whether this block is the last block in the frame. + /// It may be followed by an optional `Content_Checksum` if it is. pub last_block: bool, pub block_type: BlockType, + /// The size of the decompressed data. If the block type + /// is [BlockType::Reserved] or [BlockType::Compressed], + /// this value is set to zero and should not be referenced. 
pub decompressed_size: u32, + /// The size of the block. If the block is [BlockType::RLE], + /// this value will be 1. pub content_size: u32, } diff --git a/src/blocks/literals_section.rs b/src/blocks/literals_section.rs index 50d821c1..d7b908dd 100644 --- a/src/blocks/literals_section.rs +++ b/src/blocks/literals_section.rs @@ -1,34 +1,92 @@ +//! Utilities and representations for the first half of a block, the literals section. +//! It contains data that is then copied from by the sequences section. use super::super::decoding::bit_reader::{BitReader, GetBitsError}; +/// A compressed block consists of two sections, a literals section, and a sequences section. +/// This is the first of those two sections. A literal is just any arbitrary data, and it is copied by the sequences section pub struct LiteralsSection { + /// - If this block is of type [LiteralsSectionType::Raw], then the data is `regenerated_bytes` + /// bytes long, and it contains the raw literals data to be used during the second section, + /// the sequences section. + /// - If this block is of type [LiteralsSectionType::RLE], + /// then the literal consists of a single byte repeated `regenerated_size` times. + /// - For types [LiteralsSectionType::Compressed] or [LiteralsSectionType::Treeless], + /// then this is the size of the decompressed data. pub regenerated_size: u32, + /// - For types [LiteralsSectionType::Raw] and [LiteralsSectionType::RLE], this value is not present. + /// - For types [LiteralsSectionType::Compressed] and [LiteralsSectionType::Treeless], this value will + /// be set to the size of the compressed data. pub compressed_size: Option, + /// This value will be either 1 stream or 4 streams if the literal is of type + /// [LiteralsSectionType::Compressed] or [LiteralsSectionType::Treeless], and it + /// is not used for RLE or uncompressed literals. pub num_streams: Option, + /// The type of the literal section. 
pub ls_type: LiteralsSectionType, } +/// The way which a literal section is encoded. pub enum LiteralsSectionType { + /// Literals are stored uncompressed. Raw, + /// Literals consist of a single byte value repeated [LiteralsSection::regenerated_size] times. RLE, + /// This is a standard Huffman-compressed block, starting with a Huffman tree description. + /// In this mode, there are at least *2* different literals represented in the Huffman tree + /// description. Compressed, + /// This is a Huffman-compressed block, + /// using the Huffman tree from the previous [LiteralsSectionType::Compressed] block + /// in the sequence. If this mode is triggered without any previous Huffman-tables in the + /// frame (or dictionary), it should be treated as data corruption. Treeless, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum LiteralsSectionParseError { - #[display(fmt = "Illegal literalssectiontype. Is: {got}, must be in: 0, 1, 2, 3")] IllegalLiteralSectionType { got: u8 }, - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), - #[display( - fmt = "Not enough byte to parse the literals section header. Have: {have}, Need: {need}" - )] NotEnoughBytes { have: usize, need: u8 }, } +#[cfg(feature = "std")] +impl std::error::Error for LiteralsSectionParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + LiteralsSectionParseError::GetBitsError(source) => Some(source), + _ => None, + } + } +} +impl core::fmt::Display for LiteralsSectionParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + LiteralsSectionParseError::IllegalLiteralSectionType { got } => { + write!( + f, + "Illegal literalssectiontype. 
Is: {}, must be in: 0, 1, 2, 3", + got + ) + } + LiteralsSectionParseError::GetBitsError(e) => write!(f, "{:?}", e), + LiteralsSectionParseError::NotEnoughBytes { have, need } => { + write!( + f, + "Not enough byte to parse the literals section header. Have: {}, Need: {}", + have, need, + ) + } + } + } +} + +impl From for LiteralsSectionParseError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + impl core::fmt::Display for LiteralsSectionType { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { match self { @@ -47,6 +105,7 @@ impl Default for LiteralsSection { } impl LiteralsSection { + /// Create a new [LiteralsSection]. pub fn new() -> LiteralsSection { LiteralsSection { regenerated_size: 0, @@ -56,25 +115,26 @@ impl LiteralsSection { } } + /// Given the first byte of a header, determine the size of the whole header, from 1 to 5 bytes. pub fn header_bytes_needed(&self, first_byte: u8) -> Result { - let ls_type = Self::section_type(first_byte)?; + let ls_type: LiteralsSectionType = Self::section_type(first_byte)?; let size_format = (first_byte >> 2) & 0x3; match ls_type { LiteralsSectionType::RLE | LiteralsSectionType::Raw => { match size_format { 0 | 2 => { - //size_format actually only uses one bit - //regenerated_size uses 5 bits + // size_format actually only uses one bit + // regenerated_size uses 5 bits Ok(1) } 1 => { - //size_format uses 2 bit - //regenerated_size uses 12 bits + // size_format uses 2 bit + // regenerated_size uses 12 bits Ok(2) } 3 => { - //size_format uses 2 bit - //regenerated_size uses 20 bits + // size_format uses 2 bit + // regenerated_size uses 20 bits Ok(3) } _ => panic!( @@ -85,16 +145,16 @@ impl LiteralsSection { LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => { match size_format { 0 | 1 => { - //Only differ in num_streams - //both regenerated and compressed sizes use 10 bit + // Only differ in num_streams + // both regenerated and compressed sizes use 10 
bit Ok(3) } 2 => { - //both regenerated and compressed sizes use 14 bit + // both regenerated and compressed sizes use 14 bit Ok(4) } 3 => { - //both regenerated and compressed sizes use 18 bit + // both regenerated and compressed sizes use 18 bit Ok(5) } @@ -106,10 +166,11 @@ impl LiteralsSection { } } + /// Parse the header into `self`, and returns the number of bytes read. pub fn parse_from_header(&mut self, raw: &[u8]) -> Result { - let mut br = BitReader::new(raw); - let t = br.get_bits(2)? as u8; - self.ls_type = Self::section_type(t)?; + let mut br: BitReader<'_> = BitReader::new(raw); + let block_type = br.get_bits(2)? as u8; + self.ls_type = Self::section_type(block_type)?; let size_format = br.get_bits(2)? as u8; let byte_needed = self.header_bytes_needed(raw[0])?; @@ -125,20 +186,20 @@ impl LiteralsSection { self.compressed_size = None; match size_format { 0 | 2 => { - //size_format actually only uses one bit - //regenerated_size uses 5 bits + // size_format actually only uses one bit + // regenerated_size uses 5 bits self.regenerated_size = u32::from(raw[0]) >> 3; Ok(1) } 1 => { - //size_format uses 2 bit - //regenerated_size uses 12 bits + // size_format uses 2 bit + // regenerated_size uses 12 bits self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4); Ok(2) } 3 => { - //size_format uses 2 bit - //regenerated_size uses 20 bits + // size_format uses 2 bit + // regenerated_size uses 20 bits self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4) + (u32::from(raw[2]) << 12); @@ -164,10 +225,10 @@ impl LiteralsSection { match size_format { 0 | 1 => { - //Differ in num_streams see above - //both regenerated and compressed sizes use 10 bit + // Differ in num_streams see above + // both regenerated and compressed sizes use 10 bit - //4 from the first, six from the second byte + // 4 from the first, six from the second byte self.regenerated_size = (u32::from(raw[0]) >> 4) + ((u32::from(raw[1]) & 0x3f) << 4); @@ 
-177,27 +238,27 @@ impl LiteralsSection { Ok(3) } 2 => { - //both regenerated and compressed sizes use 14 bit + // both regenerated and compressed sizes use 14 bit - //4 from first, full second, 2 from the third byte + // 4 from first, full second, 2 from the third byte self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4) + ((u32::from(raw[2]) & 0x3) << 12); - //6 from the third, full last byte + // 6 from the third, full last byte self.compressed_size = Some((u32::from(raw[2]) >> 2) + (u32::from(raw[3]) << 6)); Ok(4) } 3 => { - //both regenerated and compressed sizes use 18 bit + // both regenerated and compressed sizes use 18 bit - //4 from first, full second, six from third byte + // 4 from first, full second, six from third byte self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4) + ((u32::from(raw[2]) & 0x3F) << 12); - //2 from third, full fourth, full fifth byte + // 2 from third, full fourth, full fifth byte self.compressed_size = Some( (u32::from(raw[2]) >> 6) + (u32::from(raw[3]) << 2) @@ -214,6 +275,7 @@ impl LiteralsSection { } } + /// Given the first two bits of a header, determine the type of a header. fn section_type(raw: u8) -> Result { let t = raw & 0x3; match t { diff --git a/src/blocks/mod.rs b/src/blocks/mod.rs index d12a1866..c4787b87 100644 --- a/src/blocks/mod.rs +++ b/src/blocks/mod.rs @@ -1,3 +1,10 @@ +//! In a Zstandard frame, there's a frame header, followed by one or more *blocks*. +//! +//! A block contains data, and a header describing how that data is encoded, as well +//! as other misc metadata. +//! +//! + pub mod block; pub mod literals_section; pub mod sequence_section; diff --git a/src/blocks/sequence_section.rs b/src/blocks/sequence_section.rs index 0553163f..9822f946 100644 --- a/src/blocks/sequence_section.rs +++ b/src/blocks/sequence_section.rs @@ -1,3 +1,6 @@ +//! Utilities and representations for the second half of a block, the sequence section. +//! 
This section copies literals from the literals section into the decompressed output. + pub(crate) const MAX_LITERAL_LENGTH_CODE: u8 = 35; pub(crate) const MAX_MATCH_LENGTH_CODE: u8 = 52; pub(crate) const MAX_OFFSET_CODE: u8 = 31; @@ -7,10 +10,27 @@ pub struct SequencesHeader { pub modes: Option, } +/// A sequence represents potentially redundant data, and it can be broken up into 2 steps: +/// - A copy step, where data is copied from the literals section to the decompressed output +/// - A *match* copy step that copies data from within the previously decompressed output. +/// +/// #[derive(Clone, Copy)] pub struct Sequence { + /// Literal length, or the number of bytes to be copied from the literals section + /// in the copy step. pub ll: u32, + /// The length of the match to make during the match copy step. pub ml: u32, + /// How far back to go in the decompressed data to read from the match copy step. + /// If this value is greater than 3, then the offset is `of -3`. If `of` is from 1-3, + /// then it has special handling: + /// + /// The first 3 values define 3 different repeated offsets, with 1 referring to the most + /// recent, 2 the second recent, and so on. When the current sequence has a literal length of 0, + /// then the repeated offsets are shifted by 1. So an offset value of 1 refers to 2, 2 refers to 3, + /// and 3 refers to the most recent offset minus one. If that value is equal to zero, the data + /// is considered corrupted. pub of: u32, } @@ -20,16 +40,27 @@ impl core::fmt::Display for Sequence { } } +/// This byte defines the compression mode of each symbol type #[derive(Copy, Clone)] pub struct CompressionModes(u8); +/// The compression mode used for symbol compression pub enum ModeType { + /// A predefined FSE distribution table is used, and no distribution table + /// will be present. Predefined, + /// The table consists of a single byte, which contains the symbol's value. 
RLE, + /// Standard FSE compression, a distribution table will be present. This + /// mode should not be used when only one symbol is present. FSECompressed, + /// The table used in the previous compressed block with at least one sequence + /// will be used again. If this is the first block, the table in the dictionary will + /// be used. Repeat, } impl CompressionModes { + /// Deserialize a two bit mode value into a [ModeType] pub fn decode_mode(m: u8) -> ModeType { match m { 0 => ModeType::Predefined, @@ -39,15 +70,17 @@ impl CompressionModes { _ => panic!("This can never happen"), } } - + /// Read the compression mode of the literal lengths field. pub fn ll_mode(self) -> ModeType { Self::decode_mode(self.0 >> 6) } + /// Read the compression mode of the offset value field. pub fn of_mode(self) -> ModeType { Self::decode_mode((self.0 >> 4) & 0x3) } + /// Read the compression mode of the match lengths field. pub fn ml_mode(self) -> ModeType { Self::decode_mode((self.0 >> 2) & 0x3) } @@ -59,17 +92,31 @@ impl Default for SequencesHeader { } } -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum SequencesHeaderParseError { - #[display( - fmt = "source must have at least {need_at_least} bytes to parse header; got {got} bytes" - )] NotEnoughBytes { need_at_least: u8, got: usize }, } +#[cfg(feature = "std")] +impl std::error::Error for SequencesHeaderParseError {} + +impl core::fmt::Display for SequencesHeaderParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + SequencesHeaderParseError::NotEnoughBytes { need_at_least, got } => { + write!( + f, + "source must have at least {} bytes to parse header; got {} bytes", + need_at_least, got, + ) + } + } + } +} + impl SequencesHeader { + /// Create a new [SequencesHeader]. 
pub fn new() -> SequencesHeader { SequencesHeader { num_sequences: 0, @@ -77,6 +124,7 @@ impl SequencesHeader { } } + /// Attempt to deserialize the provided buffer into `self`, returning the number of bytes read. pub fn parse_from_header(&mut self, source: &[u8]) -> Result { let mut bytes_read = 0; if source.is_empty() { diff --git a/src/decoding/bit_reader.rs b/src/decoding/bit_reader.rs index 55fb3064..85058211 100644 --- a/src/decoding/bit_reader.rs +++ b/src/decoding/bit_reader.rs @@ -1,21 +1,50 @@ +/// Interact with a provided source at a bit level. pub struct BitReader<'s> { idx: usize, //index counts bits already read source: &'s [u8], } -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum GetBitsError { - #[display( - fmt = "Cant serve this request. The reader is limited to {limit} bits, requested {num_requested_bits} bits" - )] TooManyBits { num_requested_bits: usize, limit: u8, }, - #[display(fmt = "Can't read {requested} bits, only have {remaining} bits left")] - NotEnoughRemainingBits { requested: usize, remaining: usize }, + NotEnoughRemainingBits { + requested: usize, + remaining: usize, + }, +} + +#[cfg(feature = "std")] +impl std::error::Error for GetBitsError {} + +impl core::fmt::Display for GetBitsError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + GetBitsError::TooManyBits { + num_requested_bits, + limit, + } => { + write!( + f, + "Cant serve this request. 
The reader is limited to {} bits, requested {} bits", + limit, num_requested_bits, + ) + } + GetBitsError::NotEnoughRemainingBits { + requested, + remaining, + } => { + write!( + f, + "Can\'t read {} bits, only have {} bits left", + requested, remaining, + ) + } + } + } } impl<'s> BitReader<'s> { diff --git a/src/decoding/bit_reader_reverse.rs b/src/decoding/bit_reader_reverse.rs index c86d4343..58c42ed9 100644 --- a/src/decoding/bit_reader_reverse.rs +++ b/src/decoding/bit_reader_reverse.rs @@ -2,15 +2,21 @@ pub use super::bit_reader::GetBitsError; use byteorder::ByteOrder; use byteorder::LittleEndian; +/// Zstandard encodes some types of data in a way that the data must be read +/// back to front to decode it properly. `BitReaderReversed` provides a +/// convenient interface to do that. pub struct BitReaderReversed<'s> { idx: isize, //index counts bits already read source: &'s [u8], - + /// The reader doesn't read directly from the source, + /// it reads bits from here, and the container is + /// "refilled" as it's emptied. bit_container: u64, bits_in_container: u8, } impl<'s> BitReaderReversed<'s> { + /// How many bits are left to read by the reader. pub fn bits_remaining(&self) -> isize { self.idx + self.bits_in_container as isize } @@ -102,6 +108,8 @@ impl<'s> BitReaderReversed<'s> { (self.idx - 1) / 8 } + /// Read `n` number of bits from the source. Returns an error if the reader + /// requests more bits than remain for reading. 
#[inline(always)] pub fn get_bits(&mut self, n: u8) -> u64 { if n == 0 { @@ -151,7 +159,7 @@ impl<'s> BitReaderReversed<'s> { return (0, 0, 0); } if sum > 56 { - // try and get the values separatly + // try and get the values separately return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)); } let sum = sum as u8; diff --git a/src/decoding/block_decoder.rs b/src/decoding/block_decoder.rs index 18307b09..2d421fda 100644 --- a/src/decoding/block_decoder.rs +++ b/src/decoding/block_decoder.rs @@ -25,91 +25,246 @@ enum DecoderState { Failed, //TODO put "self.internal_state = DecoderState::Failed;" everywhere an unresolvable error occurs } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum BlockHeaderReadError { - #[display(fmt = "Error while reading the block header")] - #[from] ReadError(io::Error), - #[display(fmt = "Reserved block occured. This is considered corruption by the documentation")] FoundReservedBlock, - #[display(fmt = "Error getting block type: {_0}")] - #[from] BlockTypeError(BlockTypeError), - #[display(fmt = "Error getting block content size: {_0}")] - #[from] BlockSizeError(BlockSizeError), } -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[cfg(feature = "std")] +impl std::error::Error for BlockHeaderReadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + BlockHeaderReadError::ReadError(source) => Some(source), + BlockHeaderReadError::BlockTypeError(source) => Some(source), + BlockHeaderReadError::BlockSizeError(source) => Some(source), + BlockHeaderReadError::FoundReservedBlock => None, + } + } +} + +impl ::core::fmt::Display for BlockHeaderReadError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result { + match self { + BlockHeaderReadError::ReadError(_) => write!(f, "Error while reading the block header"), + 
BlockHeaderReadError::FoundReservedBlock => write!( + f, + "Reserved block occured. This is considered corruption by the documentation" + ), + BlockHeaderReadError::BlockTypeError(e) => write!(f, "Error getting block type: {}", e), + BlockHeaderReadError::BlockSizeError(e) => { + write!(f, "Error getting block content size: {}", e) + } + } + } +} + +impl From for BlockHeaderReadError { + fn from(val: io::Error) -> Self { + Self::ReadError(val) + } +} + +impl From for BlockHeaderReadError { + fn from(val: BlockTypeError) -> Self { + Self::BlockTypeError(val) + } +} + +impl From for BlockHeaderReadError { + fn from(val: BlockSizeError) -> Self { + Self::BlockSizeError(val) + } +} + +#[derive(Debug)] #[non_exhaustive] pub enum BlockTypeError { - #[display( - fmt = "Invalid Blocktype number. Is: {num} Should be one of: 0, 1, 2, 3 (3 is reserved though" - )] InvalidBlocktypeNumber { num: u8 }, } -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[cfg(feature = "std")] +impl std::error::Error for BlockTypeError {} + +impl core::fmt::Display for BlockTypeError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + BlockTypeError::InvalidBlocktypeNumber { num } => { + write!(f, + "Invalid Blocktype number. Is: {} Should be one of: 0, 1, 2, 3 (3 is reserved though", + num, + ) + } + } + } +} + +#[derive(Debug)] #[non_exhaustive] pub enum BlockSizeError { - #[display( - fmt = "Blocksize was bigger than the absolute maximum {ABSOLUTE_MAXIMUM_BLOCK_SIZE} (128kb). 
Is: {size}" - )] BlockSizeTooLarge { size: u32 }, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[cfg(feature = "std")] +impl std::error::Error for BlockSizeError {} + +impl core::fmt::Display for BlockSizeError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + BlockSizeError::BlockSizeTooLarge { size } => { + write!( + f, + "Blocksize was bigger than the absolute maximum {} (128kb). Is: {}", + ABSOLUTE_MAXIMUM_BLOCK_SIZE, size, + ) + } + } + } +} + +#[derive(Debug)] #[non_exhaustive] pub enum DecompressBlockError { - #[display(fmt = "Error while reading the block content: {_0}")] - #[from] BlockContentReadError(io::Error), - #[display( - fmt = "Malformed section header. Says literals would be this long: {expected_len} but there are only {remaining_bytes} bytes left" - )] MalformedSectionHeader { expected_len: usize, remaining_bytes: usize, }, - #[display(fmt = "{_0:?}")] - #[from] DecompressLiteralsError(DecompressLiteralsError), - #[display(fmt = "{_0:?}")] - #[from] LiteralsSectionParseError(LiteralsSectionParseError), - #[display(fmt = "{_0:?}")] - #[from] SequencesHeaderParseError(SequencesHeaderParseError), - #[display(fmt = "{_0:?}")] - #[from] DecodeSequenceError(DecodeSequenceError), - #[display(fmt = "{_0:?}")] - #[from] ExecuteSequencesError(ExecuteSequencesError), } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[cfg(feature = "std")] +impl std::error::Error for DecompressBlockError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DecompressBlockError::BlockContentReadError(source) => Some(source), + DecompressBlockError::DecompressLiteralsError(source) => Some(source), + DecompressBlockError::LiteralsSectionParseError(source) => Some(source), + DecompressBlockError::SequencesHeaderParseError(source) => Some(source), + 
DecompressBlockError::DecodeSequenceError(source) => Some(source), + DecompressBlockError::ExecuteSequencesError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for DecompressBlockError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + DecompressBlockError::BlockContentReadError(e) => { + write!(f, "Error while reading the block content: {}", e) + } + DecompressBlockError::MalformedSectionHeader { + expected_len, + remaining_bytes, + } => { + write!(f, + "Malformed section header. Says literals would be this long: {} but there are only {} bytes left", + expected_len, + remaining_bytes, + ) + } + DecompressBlockError::DecompressLiteralsError(e) => write!(f, "{:?}", e), + DecompressBlockError::LiteralsSectionParseError(e) => write!(f, "{:?}", e), + DecompressBlockError::SequencesHeaderParseError(e) => write!(f, "{:?}", e), + DecompressBlockError::DecodeSequenceError(e) => write!(f, "{:?}", e), + DecompressBlockError::ExecuteSequencesError(e) => write!(f, "{:?}", e), + } + } +} + +impl From for DecompressBlockError { + fn from(val: io::Error) -> Self { + Self::BlockContentReadError(val) + } +} + +impl From for DecompressBlockError { + fn from(val: DecompressLiteralsError) -> Self { + Self::DecompressLiteralsError(val) + } +} + +impl From for DecompressBlockError { + fn from(val: LiteralsSectionParseError) -> Self { + Self::LiteralsSectionParseError(val) + } +} + +impl From for DecompressBlockError { + fn from(val: SequencesHeaderParseError) -> Self { + Self::SequencesHeaderParseError(val) + } +} + +impl From for DecompressBlockError { + fn from(val: DecodeSequenceError) -> Self { + Self::DecodeSequenceError(val) + } +} + +impl From for DecompressBlockError { + fn from(val: ExecuteSequencesError) -> Self { + Self::ExecuteSequencesError(val) + } +} + +#[derive(Debug)] #[non_exhaustive] pub enum DecodeBlockContentError { - #[display(fmt = "Can't decode next block if failed along the way. 
Results will be nonsense")] DecoderStateIsFailed, - #[display( - fmt = "Cant decode next block body, while expecting to decode the header of the previous block. Results will be nonsense" - )] ExpectedHeaderOfPreviousBlock, - #[display(fmt = "Error while reading bytes for {step}: {source}")] ReadError { step: BlockType, source: io::Error }, - #[display(fmt = "{_0:?}")] - #[from] DecompressBlockError(DecompressBlockError), } +#[cfg(feature = "std")] +impl std::error::Error for DecodeBlockContentError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DecodeBlockContentError::ReadError { step: _, source } => Some(source), + DecodeBlockContentError::DecompressBlockError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for DecodeBlockContentError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + DecodeBlockContentError::DecoderStateIsFailed => { + write!( + f, + "Can't decode next block if failed along the way. Results will be nonsense", + ) + } + DecodeBlockContentError::ExpectedHeaderOfPreviousBlock => { + write!(f, + "Can't decode next block body, while expecting to decode the header of the previous block. Results will be nonsense", + ) + } + DecodeBlockContentError::ReadError { step, source } => { + write!(f, "Error while reading bytes for {}: {}", step, source,) + } + DecodeBlockContentError::DecompressBlockError(e) => write!(f, "{:?}", e), + } + } +} + +impl From for DecodeBlockContentError { + fn from(val: DecompressBlockError) -> Self { + Self::DecompressBlockError(val) + } +} + +/// Create a new [BlockDecoder]. 
pub fn new() -> BlockDecoder { BlockDecoder { internal_state: DecoderState::ReadyToDecodeNextHeader, @@ -320,14 +475,14 @@ impl BlockDecoder { let decompressed_size = match btype { BlockType::Raw => block_size, BlockType::RLE => block_size, - BlockType::Reserved => 0, //should be catched above, this is an error state + BlockType::Reserved => 0, //should be caught above, this is an error state BlockType::Compressed => 0, //unknown but will be smaller than 128kb (or window_size if that is smaller than 128kb) }; let content_size = match btype { BlockType::Raw => block_size, BlockType::Compressed => block_size, BlockType::RLE => 1, - BlockType::Reserved => 0, //should be catched above, this is an error state + BlockType::Reserved => 0, //should be caught above, this is an error state }; let last_block = self.is_last(); diff --git a/src/decoding/decodebuffer.rs b/src/decoding/decodebuffer.rs index 8fcb98c4..04a43e8b 100644 --- a/src/decoding/decodebuffer.rs +++ b/src/decoding/decodebuffer.rs @@ -5,7 +5,7 @@ use core::hash::Hasher; use super::ringbuffer::RingBuffer; -pub struct Decodebuffer { +pub struct DecodeBuffer { buffer: RingBuffer, pub dict_content: Vec, @@ -15,17 +15,34 @@ pub struct Decodebuffer { pub hash: twox_hash::XxHash64, } -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] -pub enum DecodebufferError { - #[display(fmt = "Need {need} bytes from the dictionary but it is only {got} bytes long")] +pub enum DecodeBufferError { NotEnoughBytesInDictionary { got: usize, need: usize }, - #[display(fmt = "offset: {offset} bigger than buffer: {buf_len}")] OffsetTooBig { offset: usize, buf_len: usize }, } -impl Read for Decodebuffer { +#[cfg(feature = "std")] +impl std::error::Error for DecodeBufferError {} + +impl core::fmt::Display for DecodeBufferError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + 
DecodeBufferError::NotEnoughBytesInDictionary { got, need } => { + write!( + f, + "Need {} bytes from the dictionary but it is only {} bytes long", + need, got, + ) + } + DecodeBufferError::OffsetTooBig { offset, buf_len } => { + write!(f, "offset: {} bigger than buffer: {}", offset, buf_len,) + } + } + } +} + +impl Read for DecodeBuffer { fn read(&mut self, target: &mut [u8]) -> Result { let max_amount = self.can_drain_to_window_size().unwrap_or(0); let amount = max_amount.min(target.len()); @@ -40,9 +57,9 @@ impl Read for Decodebuffer { } } -impl Decodebuffer { - pub fn new(window_size: usize) -> Decodebuffer { - Decodebuffer { +impl DecodeBuffer { + pub fn new(window_size: usize) -> DecodeBuffer { + DecodeBuffer { buffer: RingBuffer::new(), dict_content: Vec::new(), window_size, @@ -77,7 +94,7 @@ impl Decodebuffer { self.total_output_counter += data.len() as u64; } - pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodebufferError> { + pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> { if offset > self.buffer.len() { self.repeat_from_dict(offset, match_length) } else { @@ -146,13 +163,13 @@ impl Decodebuffer { &mut self, offset: usize, match_length: usize, - ) -> Result<(), DecodebufferError> { + ) -> Result<(), DecodeBufferError> { if self.total_output_counter <= self.window_size as u64 { // at least part of that repeat is from the dictionary content let bytes_from_dict = offset - self.buffer.len(); if bytes_from_dict > self.dict_content.len() { - return Err(DecodebufferError::NotEnoughBytesInDictionary { + return Err(DecodeBufferError::NotEnoughBytesInDictionary { got: self.dict_content.len(), need: bytes_from_dict, }); @@ -172,14 +189,14 @@ impl Decodebuffer { } Ok(()) } else { - Err(DecodebufferError::OffsetTooBig { + Err(DecodeBufferError::OffsetTooBig { offset, buf_len: self.buffer.len(), }) } } - // Check if and how many bytes can currently be drawn from the buffer + /// Check if and 
how many bytes can currently be drawn from the buffer pub fn can_drain_to_window_size(&self) -> Option { if self.buffer.len() > self.window_size { Some(self.buffer.len() - self.window_size) @@ -193,8 +210,8 @@ impl Decodebuffer { self.buffer.len() } - //drain as much as possible while retaining enough so that decoding si still possible with the required window_size - //At best call only if can_drain_to_window_size reports a 'high' number of bytes to reduce allocations + /// Drain as much as possible while retaining enough so that decoding si still possible with the required window_size + /// At best call only if can_drain_to_window_size reports a 'high' number of bytes to reduce allocations pub fn drain_to_window_size(&mut self) -> Option> { //TODO investigate if it is possible to return the std::vec::Drain iterator directly without collecting here match self.can_drain_to_window_size() { @@ -221,7 +238,7 @@ impl Decodebuffer { } } - //drain the buffer completely + /// drain the buffer completely pub fn drain(&mut self) -> Vec { let (slice1, slice2) = self.buffer.as_slices(); #[cfg(feature = "hash")] @@ -333,7 +350,7 @@ fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error #[cfg(test)] mod tests { - use super::Decodebuffer; + use super::DecodeBuffer; use crate::io::{Error, ErrorKind, Write}; extern crate std; @@ -368,7 +385,7 @@ mod tests { write_len: 10, }; - let mut decode_buf = Decodebuffer::new(100); + let mut decode_buf = DecodeBuffer::new(100); decode_buf.push(b"0123456789"); decode_buf.repeat(10, 90).unwrap(); let repeats = 1000; @@ -418,7 +435,7 @@ mod tests { block_every: 5, }; - let mut decode_buf = Decodebuffer::new(100); + let mut decode_buf = DecodeBuffer::new(100); decode_buf.push(b"0123456789"); decode_buf.repeat(10, 90).unwrap(); let repeats = 1000; diff --git a/src/decoding/dictionary.rs b/src/decoding/dictionary.rs index b86678e1..2d930c61 100644 --- a/src/decoding/dictionary.rs +++ b/src/decoding/dictionary.rs @@ -6,35 
+6,90 @@ use crate::decoding::scratch::HuffmanScratch; use crate::fse::FSETableError; use crate::huff0::HuffmanTableError; +/// Zstandard includes support for "raw content" dictionaries, that store bytes optionally used +/// during sequence execution. +/// +/// pub struct Dictionary { + /// A 4 byte value used by decoders to check if they can use + /// the correct dictionary. This value must not be zero. pub id: u32, + /// A dictionary can contain an entropy table, either FSE or + /// Huffman. pub fse: FSEScratch, + /// A dictionary can contain an entropy table, either FSE or + /// Huffman. pub huf: HuffmanScratch, + /// The content of a dictionary acts as a "past" in front of data + /// to compress or decompress, + /// so it can be referenced in sequence commands. + /// As long as the amount of data decoded from this frame is less than or + /// equal to Window_Size, sequence commands may specify offsets longer than + /// the total length of decoded output so far to reference back to the + /// dictionary, even parts of the dictionary with offsets larger than Window_Size. + /// After the total output has surpassed Window_Size however, + /// this is no longer allowed and the dictionary is no longer accessible pub dict_content: Vec, + /// The 3 most recent offsets are stored so that they can be used + /// during sequence execution, see + /// + /// for more. 
pub offset_hist: [u32; 3], } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum DictionaryDecodeError { - #[display( - fmt = "Bad magic_num at start of the dictionary; Got: {got:#04X?}, Expected: {MAGIC_NUM:#04x?}" - )] BadMagicNum { got: [u8; 4] }, - #[display(fmt = "{_0:?}")] - #[from] FSETableError(FSETableError), - #[display(fmt = "{_0:?}")] - #[from] HuffmanTableError(HuffmanTableError), } +#[cfg(feature = "std")] +impl std::error::Error for DictionaryDecodeError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DictionaryDecodeError::FSETableError(source) => Some(source), + DictionaryDecodeError::HuffmanTableError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for DictionaryDecodeError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + DictionaryDecodeError::BadMagicNum { got } => { + write!( + f, + "Bad magic_num at start of the dictionary; Got: {:#04X?}, Expected: {:#04x?}", + got, MAGIC_NUM, + ) + } + DictionaryDecodeError::FSETableError(e) => write!(f, "{:?}", e), + DictionaryDecodeError::HuffmanTableError(e) => write!(f, "{:?}", e), + } + } +} + +impl From for DictionaryDecodeError { + fn from(val: FSETableError) -> Self { + Self::FSETableError(val) + } +} + +impl From for DictionaryDecodeError { + fn from(val: HuffmanTableError) -> Self { + Self::HuffmanTableError(val) + } +} + +/// This 4 byte (little endian) magic number refers to the start of a dictionary pub const MAGIC_NUM: [u8; 4] = [0x37, 0xA4, 0x30, 0xEC]; impl Dictionary { - /// parses the dictionary and set the tables - /// it returns the dict_id for checking with the frame's dict_id + /// Parses the dictionary from `raw` and set the tables + /// it returns the dict_id for checking with the frame's `dict_id`` pub fn decode_dict(raw: &[u8]) -> Result { let mut new_dict = Dictionary 
{ id: 0, diff --git a/src/decoding/literals_section_decoder.rs b/src/decoding/literals_section_decoder.rs index f437a58f..63301a66 100644 --- a/src/decoding/literals_section_decoder.rs +++ b/src/decoding/literals_section_decoder.rs @@ -1,46 +1,114 @@ +//! This module contains the [decompress_literals] function, used to take a +//! parsed literals header and a source and decompress it. + use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType}; use super::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use super::scratch::HuffmanScratch; use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError}; use alloc::vec::Vec; -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum DecompressLiteralsError { - #[display( - fmt = "compressed size was none even though it must be set to something for compressed literals" - )] MissingCompressedSize, - #[display( - fmt = "num_streams was none even though it must be set to something (1 or 4) for compressed literals" - )] MissingNumStreams, - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), - #[display(fmt = "{_0:?}")] - #[from] HuffmanTableError(HuffmanTableError), - #[display(fmt = "{_0:?}")] - #[from] HuffmanDecoderError(HuffmanDecoderError), - #[display(fmt = "Tried to reuse huffman table but it was never initialized")] UninitializedHuffmanTable, - #[display(fmt = "Need 6 bytes to decode jump header, got {got} bytes")] MissingBytesForJumpHeader { got: usize }, - #[display(fmt = "Need at least {needed} bytes to decode literals. Have: {got} bytes")] MissingBytesForLiterals { got: usize, needed: usize }, - #[display( - fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. 
Probably caused by data corruption" - )] ExtraPadding { skipped_bits: i32 }, - #[display(fmt = "Bitstream was read till: {read_til}, should have been: {expected}")] BitstreamReadMismatch { read_til: isize, expected: isize }, - #[display(fmt = "Did not decode enough literals: {decoded}, Should have been: {expected}")] DecodedLiteralCountMismatch { decoded: usize, expected: usize }, } +#[cfg(feature = "std")] +impl std::error::Error for DecompressLiteralsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DecompressLiteralsError::GetBitsError(source) => Some(source), + DecompressLiteralsError::HuffmanTableError(source) => Some(source), + DecompressLiteralsError::HuffmanDecoderError(source) => Some(source), + _ => None, + } + } +} +impl core::fmt::Display for DecompressLiteralsError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + DecompressLiteralsError::MissingCompressedSize => { + write!(f, + "compressed size was none even though it must be set to something for compressed literals", + ) + } + DecompressLiteralsError::MissingNumStreams => { + write!(f, + "num_streams was none even though it must be set to something (1 or 4) for compressed literals", + ) + } + DecompressLiteralsError::GetBitsError(e) => write!(f, "{:?}", e), + DecompressLiteralsError::HuffmanTableError(e) => write!(f, "{:?}", e), + DecompressLiteralsError::HuffmanDecoderError(e) => write!(f, "{:?}", e), + DecompressLiteralsError::UninitializedHuffmanTable => { + write!( + f, + "Tried to reuse huffman table but it was never initialized", + ) + } + DecompressLiteralsError::MissingBytesForJumpHeader { got } => { + write!(f, "Need 6 bytes to decode jump header, got {} bytes", got,) + } + DecompressLiteralsError::MissingBytesForLiterals { got, needed } => { + write!( + f, + "Need at least {} bytes to decode literals. 
Have: {} bytes", + needed, got, + ) + } + DecompressLiteralsError::ExtraPadding { skipped_bits } => { + write!(f, + "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption", + skipped_bits, + ) + } + DecompressLiteralsError::BitstreamReadMismatch { read_til, expected } => { + write!( + f, + "Bitstream was read till: {}, should have been: {}", + read_til, expected, + ) + } + DecompressLiteralsError::DecodedLiteralCountMismatch { decoded, expected } => { + write!( + f, + "Did not decode enough literals: {}, Should have been: {}", + decoded, expected, + ) + } + } + } +} + +impl From for DecompressLiteralsError { + fn from(val: HuffmanDecoderError) -> Self { + Self::HuffmanDecoderError(val) + } +} + +impl From for DecompressLiteralsError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + +impl From for DecompressLiteralsError { + fn from(val: HuffmanTableError) -> Self { + Self::HuffmanTableError(val) + } +} + +/// Decode and decompress the provided literals section into `target`, returning the number of bytes read. pub fn decode_literals( section: &LiteralsSection, scratch: &mut HuffmanScratch, @@ -65,6 +133,10 @@ pub fn decode_literals( } } +/// Decompress the provided literals section and source into the provided `target`. +/// This function is used when the literals section is `Compressed` or `Treeless` +/// +/// Returns the number of bytes read. fn decompress_literals( section: &LiteralsSection, scratch: &mut HuffmanScratch, diff --git a/src/decoding/mod.rs b/src/decoding/mod.rs index b89df351..a9f9b7ae 100644 --- a/src/decoding/mod.rs +++ b/src/decoding/mod.rs @@ -1,3 +1,6 @@ +//! Structures and utilities used for reading from data, decoding that data +//! and storing the output. 
+ pub mod bit_reader; pub mod bit_reader_reverse; pub mod block_decoder; diff --git a/src/decoding/ringbuffer.rs b/src/decoding/ringbuffer.rs index 8303edf5..e364d902 100644 --- a/src/decoding/ringbuffer.rs +++ b/src/decoding/ringbuffer.rs @@ -37,16 +37,19 @@ impl RingBuffer { } } + /// Return the number of bytes in the buffer. pub fn len(&self) -> usize { let (x, y) = self.data_slice_lengths(); x + y } + /// Return the amount of available space (in bytes) of the buffer. pub fn free(&self) -> usize { let (x, y) = self.free_slice_lengths(); (x + y).saturating_sub(1) } + /// Empty the buffer and reset the head and tail. pub fn clear(&mut self) { // SAFETY: Upholds invariant 2, trivially // SAFETY: Upholds invariant 3; 0 is always valid @@ -54,10 +57,12 @@ impl RingBuffer { self.tail = 0; } + /// Whether the buffer is empty pub fn is_empty(&self) -> bool { self.head == self.tail } + /// Ensure that there's space for `amount` elements in the buffer. pub fn reserve(&mut self, amount: usize) { let free = self.free(); if free >= amount { @@ -131,6 +136,8 @@ impl RingBuffer { self.tail = (self.tail + 1) % self.cap; } + /// Fetch the byte stored at the selected index from the buffer, returning it, or + /// `None` if the index is out of bounds. #[allow(dead_code)] pub fn get(&self, idx: usize) -> Option { if idx < self.len() { @@ -142,7 +149,7 @@ impl RingBuffer { None } } - + /// Append the provided data to the end of `self`. pub fn extend(&mut self, data: &[u8]) { let len = data.len(); let ptr = data.as_ptr(); @@ -178,6 +185,8 @@ impl RingBuffer { self.tail = (self.tail + len) % self.cap; } + /// Advance head past `amount` elements, effectively removing + /// them from the buffer. 
pub fn drop_first_n(&mut self, amount: usize) { debug_assert!(amount <= self.len()); let amount = usize::min(amount, self.len()); @@ -186,6 +195,8 @@ impl RingBuffer { self.head = (self.head + amount) % self.cap; } + /// Return the size of the two contiguous occupied sections of memory used + /// by the buffer. // SAFETY: other code relies on this pointing to initialized halves of the buffer only fn data_slice_lengths(&self) -> (usize, usize) { let len_after_head; @@ -203,6 +214,7 @@ impl RingBuffer { } // SAFETY: other code relies on this pointing to initialized halves of the buffer only + /// Return pointers to the head and tail, and the length of each section. fn data_slice_parts(&self) -> ((*const u8, usize), (*const u8, usize)) { let (len_after_head, len_to_tail) = self.data_slice_lengths(); @@ -211,6 +223,8 @@ impl RingBuffer { (self.buf.as_ptr(), len_to_tail), ) } + + /// Return references to each part of the ring buffer. pub fn as_slices(&self) -> (&[u8], &[u8]) { let (s1, s2) = self.data_slice_parts(); unsafe { @@ -223,6 +237,7 @@ impl RingBuffer { // SAFETY: other code relies on this producing the lengths of free zones // at the beginning/end of the buffer. Everything else must be initialized + /// Returns the size of the two unoccupied sections of memory used by the buffer. fn free_slice_lengths(&self) -> (usize, usize) { let len_to_head; let len_after_tail; @@ -238,6 +253,8 @@ impl RingBuffer { (len_to_head, len_after_tail) } + /// Returns mutable references to the available space and the size of that available space, + /// for the two sections in the buffer. // SAFETY: Other code relies on this pointing to the free zones, data after the first and before the second must // be valid fn free_slice_parts(&self) -> ((*mut u8, usize), (*mut u8, usize)) { @@ -249,6 +266,7 @@ impl RingBuffer { ) } + /// Copies elements from the provided range to the end of the buffer. 
#[allow(dead_code)] pub fn extend_from_within(&mut self, start: usize, len: usize) { if start + len > self.len() { @@ -268,6 +286,9 @@ impl RingBuffer { unsafe { self.extend_from_within_unchecked(start, len) } } + /// Copies data from the provided range to the end of the buffer, without + /// first verifying that the unoccupied capacity is available. + /// /// SAFETY: /// For this to be safe two requirements need to hold: /// 1. start + len <= self.len() so we do not copy uninitialised memory @@ -326,6 +347,9 @@ impl RingBuffer { } #[allow(dead_code)] + /// This function is functionally the same as [RingBuffer::extend_from_within_unchecked], + /// but it does not contain any branching operations. + /// /// SAFETY: /// Needs start + len <= self.len() /// And more then len reserved space diff --git a/src/decoding/scratch.rs b/src/decoding/scratch.rs index 29f24f73..20374e37 100644 --- a/src/decoding/scratch.rs +++ b/src/decoding/scratch.rs @@ -1,5 +1,7 @@ +//! Structures that wrap around various decoders to make decoding easier. + use super::super::blocks::sequence_section::Sequence; -use super::decodebuffer::Decodebuffer; +use super::decodebuffer::DecodeBuffer; use crate::decoding::dictionary::Dictionary; use crate::fse::FSETable; use crate::huff0::HuffmanTable; @@ -9,10 +11,14 @@ use crate::blocks::sequence_section::{ MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE, }; +/// A block level decoding buffer. pub struct DecoderScratch { + /// The decoder used for Huffman blocks. pub huf: HuffmanScratch, + /// The decoder used for FSE blocks. 
pub fse: FSEScratch, - pub buffer: Decodebuffer, + + pub buffer: DecodeBuffer, pub offset_hist: [u32; 3], pub literals_buffer: Vec, @@ -34,7 +40,7 @@ impl DecoderScratch { match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE), ml_rle: None, }, - buffer: Decodebuffer::new(window_size), + buffer: DecodeBuffer::new(window_size), offset_hist: [1, 4, 8], block_content_buffer: Vec::new(), diff --git a/src/decoding/sequence_execution.rs b/src/decoding/sequence_execution.rs index cc03c6ec..1a212284 100644 --- a/src/decoding/sequence_execution.rs +++ b/src/decoding/sequence_execution.rs @@ -1,18 +1,50 @@ -use super::{decodebuffer::DecodebufferError, scratch::DecoderScratch}; +use super::{decodebuffer::DecodeBufferError, scratch::DecoderScratch}; -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum ExecuteSequencesError { - #[display(fmt = "{_0:?}")] - #[from] - DecodebufferError(DecodebufferError), - #[display(fmt = "Sequence wants to copy up to byte {wanted}. Bytes in literalsbuffer: {have}")] + DecodebufferError(DecodeBufferError), NotEnoughBytesForSequence { wanted: usize, have: usize }, - #[display(fmt = "Illegal offset: 0 found")] ZeroOffset, } +impl core::fmt::Display for ExecuteSequencesError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ExecuteSequencesError::DecodebufferError(e) => { + write!(f, "{:?}", e) + } + ExecuteSequencesError::NotEnoughBytesForSequence { wanted, have } => { + write!( + f, + "Sequence wants to copy up to byte {}. 
Bytes in literalsbuffer: {}", + wanted, have + ) + } + ExecuteSequencesError::ZeroOffset => { + write!(f, "Illegal offset: 0 found") + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ExecuteSequencesError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ExecuteSequencesError::DecodebufferError(source) => Some(source), + _ => None, + } + } +} + +impl From for ExecuteSequencesError { + fn from(val: DecodeBufferError) -> Self { + Self::DecodebufferError(val) + } +} + +/// Take the provided decoder and execute the sequences stored within pub fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), ExecuteSequencesError> { let mut literals_copy_counter = 0; let old_buffer_size = scratch.buffer.len(); @@ -64,6 +96,9 @@ pub fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), ExecuteSequ Ok(()) } +/// Update the most recently used offsets to reflect the provided offset value, and return the +/// "actual" offset needed because offsets are not stored in a raw way, some transformations are needed +/// before you get a functional number. 
fn do_offset_history(offset_value: u32, lit_len: u32, scratch: &mut [u32; 3]) -> u32 { let actual_offset = if lit_len > 0 { match offset_value { diff --git a/src/decoding/sequence_section_decoder.rs b/src/decoding/sequence_section_decoder.rs index 107d734a..e7cc1944 100644 --- a/src/decoding/sequence_section_decoder.rs +++ b/src/decoding/sequence_section_decoder.rs @@ -9,41 +9,99 @@ use crate::blocks::sequence_section::{ use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError}; use alloc::vec::Vec; -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum DecodeSequenceError { - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), - #[display(fmt = "{_0:?}")] - #[from] FSEDecoderError(FSEDecoderError), - #[display(fmt = "{_0:?}")] - #[from] FSETableError(FSETableError), - #[display( - fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption" - )] ExtraPadding { skipped_bits: i32 }, - #[display(fmt = "Do not support offsets bigger than 1<<32; got: {offset_code}")] UnsupportedOffset { offset_code: u8 }, - #[display(fmt = "Read an offset == 0. That is an illegal value for offsets")] ZeroOffset, - #[display(fmt = "Bytestream did not contain enough bytes to decode num_sequences")] NotEnoughBytesForNumSequences, - #[display(fmt = "Did not use full bitstream. 
Bits left: {bits_remaining} ({} bytes)", bits_remaining / 8)] ExtraBits { bits_remaining: isize }, - #[display(fmt = "compression modes are none but they must be set to something")] MissingCompressionMode, - #[display(fmt = "Need a byte to read for RLE ll table")] MissingByteForRleLlTable, - #[display(fmt = "Need a byte to read for RLE of table")] MissingByteForRleOfTable, - #[display(fmt = "Need a byte to read for RLE ml table")] MissingByteForRleMlTable, } +#[cfg(feature = "std")] +impl std::error::Error for DecodeSequenceError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DecodeSequenceError::GetBitsError(source) => Some(source), + DecodeSequenceError::FSEDecoderError(source) => Some(source), + DecodeSequenceError::FSETableError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for DecodeSequenceError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + DecodeSequenceError::GetBitsError(e) => write!(f, "{:?}", e), + DecodeSequenceError::FSEDecoderError(e) => write!(f, "{:?}", e), + DecodeSequenceError::FSETableError(e) => write!(f, "{:?}", e), + DecodeSequenceError::ExtraPadding { skipped_bits } => { + write!(f, + "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption", + skipped_bits, + ) + } + DecodeSequenceError::UnsupportedOffset { offset_code } => { + write!( + f, + "Do not support offsets bigger than 1<<32; got: {}", + offset_code, + ) + } + DecodeSequenceError::ZeroOffset => write!( + f, + "Read an offset == 0. 
That is an illegal value for offsets" + ), + DecodeSequenceError::NotEnoughBytesForNumSequences => write!( + f, + "Bytestream did not contain enough bytes to decode num_sequences" + ), + DecodeSequenceError::ExtraBits { bits_remaining } => write!(f, "{}", bits_remaining), + DecodeSequenceError::MissingCompressionMode => write!( + f, + "compression modes are none but they must be set to something" + ), + DecodeSequenceError::MissingByteForRleLlTable => { + write!(f, "Need a byte to read for RLE ll table") + } + DecodeSequenceError::MissingByteForRleOfTable => { + write!(f, "Need a byte to read for RLE of table") + } + DecodeSequenceError::MissingByteForRleMlTable => { + write!(f, "Need a byte to read for RLE ml table") + } + } + } +} + +impl From for DecodeSequenceError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + +impl From for DecodeSequenceError { + fn from(val: FSETableError) -> Self { + Self::FSETableError(val) + } +} + +impl From for DecodeSequenceError { + fn from(val: FSEDecoderError) -> Self { + Self::FSEDecoderError(val) + } +} + +/// Decode the provided source as a series of sequences into the supplied `target`. pub fn decode_sequences( section: &SequencesHeader, source: &[u8], @@ -253,6 +311,10 @@ fn decode_sequences_without_rle( } } +/// Look up the provided state value from a literal length table predefined +/// by the Zstandard reference document. Returns a tuple of (value, number of bits). +/// +/// fn lookup_ll_code(code: u8) -> (u32, u8) { match code { 0..=15 => (u32::from(code), 0), @@ -280,6 +342,10 @@ fn lookup_ll_code(code: u8) -> (u32, u8) { } } +/// Look up the provided state value from a match length table predefined +/// by the Zstandard reference document. Returns a tuple of (value, number of bits). 
+/// +/// fn lookup_ml_code(code: u8) -> (u32, u8) { match code { 0..=31 => (u32::from(code) + 3, 0), @@ -308,8 +374,12 @@ fn lookup_ml_code(code: u8) -> (u32, u8) { } } +// This info is buried in the symbol compression mode table +/// "The maximum allowed accuracy log for literals length and match length tables is 9" pub const LL_MAX_LOG: u8 = 9; +/// "The maximum allowed accuracy log for literals length and match length tables is 9" pub const ML_MAX_LOG: u8 = 9; +/// "The maximum accuracy log for the offset table is 8." pub const OF_MAX_LOG: u8 = 8; fn maybe_update_fse_tables( @@ -430,19 +500,34 @@ fn maybe_update_fse_tables( Ok(bytes_read) } +// The default Literal Length decoding table uses an accuracy logarithm of 6 bits. const LL_DEFAULT_ACC_LOG: u8 = 6; +/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding +/// table is generated using a predefined distribution table. +/// +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals-length const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [ 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1, ]; +// The default Match Length decoding table uses an accuracy logarithm of 6 bits. const ML_DEFAULT_ACC_LOG: u8 = 6; +/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding +/// table is generated using a predefined distribution table. +/// +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#match-length const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [ 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, ]; +// The default Match Length decoding table uses an accuracy logarithm of 5 bits. 
const OF_DEFAULT_ACC_LOG: u8 = 5; +/// If [ModeType::Predefined] is selected for a symbol type, its FSE decoding +/// table is generated using a predefined distribution table. +/// +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#match-length const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [ 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, ]; diff --git a/src/frame.rs b/src/frame.rs index eabfca5d..b7a23498 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -1,51 +1,135 @@ use crate::io::{Error, Read}; +use core::fmt; +#[cfg(feature = "std")] +use std::error::Error as StdError; + +/// This magic number is included at the start of a single Zstandard frame pub const MAGIC_NUM: u32 = 0xFD2F_B528; +/// The minimum window size is defined as 1 KB pub const MIN_WINDOW_SIZE: u64 = 1024; +/// The maximum window size is 3.75TB pub const MAX_WINDOW_SIZE: u64 = (1 << 41) + 7 * (1 << 38); +/// Zstandard compressed data is made of one or more [Frame]s. Each frame is independent and can be +/// decompressed independently of other frames. +/// +/// There are two frame formats defined by Zstandard: Zstandard frames and Skippable frames. +/// Zstandard frames contain compressed data, while skippable frames contain custom user metadata. +/// +/// This structure contains the header of the frame. +/// +/// pub struct Frame { pub header: FrameHeader, } +/// A frame header has a variable size, with a minimum of 2 bytes, and a maximum of 14 bytes. pub struct FrameHeader { pub descriptor: FrameDescriptor, + /// The `Window_Descriptor` field contains the minimum size of a memory buffer needed to + /// decompress the entire frame. + /// + /// This byte is not included in the frame header when the `Single_Segment_flag` is set. + /// + /// Bits 7-3 refer to the `Exponent`, where bits 2-0 refer to the `Mantissa`. 
+ /// + /// To determine the size of a window, the following formula can be used: + /// ```text + /// windowLog = 10 + Exponent; + /// windowBase = 1 << windowLog; + /// windowAdd = (windowBase / 8) * Mantissa; + /// Window_Size = windowBase + windowAdd; + /// ``` + /// window_descriptor: u8, + /// The `Dictionary_ID` field contains the ID of the dictionary to be used to decode the frame. + /// When this value is not present, it's up to the decoder to know which dictionary to use. dict_id: Option, + /// The size of the original/uncompressed content. frame_content_size: u64, } +/// The first byte is called the `Frame Header Descriptor`, and it describes what other fields +/// are present. pub struct FrameDescriptor(u8); -#[derive(Debug, derive_more::Display)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum FrameDescriptorError { - #[display(fmt = "Invalid Frame_Content_Size_Flag; Is: {got}, Should be one of: 0, 1, 2, 3")] InvalidFrameContentSizeFlag { got: u8 }, } +impl fmt::Display for FrameDescriptorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidFrameContentSizeFlag { got } => write!( + f, + "Invalid Frame_Content_Size_Flag; Is: {}, Should be one of: 0, 1, 2, 3", + got + ), + } + } +} + +#[cfg(feature = "std")] +impl StdError for FrameDescriptorError {} + impl FrameDescriptor { + /// Read the `Frame_Content_Size_flag` from the frame header descriptor. + /// + /// This is a 2 bit flag, specifying if the `Frame_Content_Size` field is present + /// within the header. It notates the number of bytes used by `Frame_Content_size` + /// + /// When this value is is 0, `FCS_Field_Size` depends on Single_Segment_flag. + /// If the `Single_Segment_flag` field is set in the frame header descriptor, + /// the size of the `Frame_Content_Size` field of the header is 1 byte. + /// Otherwise, `FCS_Field_Size` is 0, and the `Frame_Content_Size` is not provided. 
+ /// + /// | Flag Value (decimal) | Size of the `Frame_Content_Size` field in bytes | + /// | -- | -- | + /// | 0 | 0 or 1 (see above) | + /// | 1 | 2 | + /// | 2 | 4 | + /// | 3 | 8 | pub fn frame_content_size_flag(&self) -> u8 { self.0 >> 6 } + /// This bit is reserved for some future feature, a compliant decoder **must ensure** + /// that this value is set to zero. pub fn reserved_flag(&self) -> bool { ((self.0 >> 3) & 0x1) == 1 } + /// If this flag is set, data must be regenerated within a single continuous memory segment. + /// + /// In this case, the `Window_Descriptor` byte is skipped, but `Frame_Content_Size` is present. + /// The decoder must allocate a memory segment equal to or larger than `Frame_Content_Size`. pub fn single_segment_flag(&self) -> bool { ((self.0 >> 5) & 0x1) == 1 } + /// If this flag is set, a 32 bit `Content_Checksum` will be present at the end of the frame. pub fn content_checksum_flag(&self) -> bool { ((self.0 >> 2) & 0x1) == 1 } + /// This is a two bit flag telling if a dictionary ID is provided within the header. It also + /// specifies the size of this field + /// + /// | Value (Decimal) | `DID_Field_Size` (bytes) | + /// | -- | -- | + /// | 0 | 0 | + /// | 1 | 1 | + /// | 2 | 2 | + /// | 3 | 4 | pub fn dict_id_flag(&self) -> u8 { self.0 & 0x3 } - // Deriving info from the flags + /// Read the size of the `Frame_Content_size` field from the frame header descriptor, returning + /// the size in bytes. + /// If this value is zero, then the `Frame_Content_Size` field is not present within the header. pub fn frame_content_size_bytes(&self) -> Result { match self.frame_content_size_flag() { 0 => { @@ -62,6 +146,9 @@ impl FrameDescriptor { } } + /// Read the size of the `Dictionary_ID` field from the frame header descriptor, returning the size in bytes. + /// If this value is zero, then the dictionary id is not present within the header, + /// and "It's up to the decoder to know which dictionary to use." 
pub fn dictionary_id_bytes(&self) -> Result { match self.dict_id_flag() { 0 => Ok(0), @@ -73,34 +160,70 @@ impl FrameDescriptor { } } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum FrameHeaderError { - #[display( - fmt = "window_size bigger than allowed maximum. Is: {got}, Should be lower than: {MAX_WINDOW_SIZE}" - )] WindowTooBig { got: u64 }, - #[display( - fmt = "window_size smaller than allowed minimum. Is: {got}, Should be greater than: {MIN_WINDOW_SIZE}" - )] WindowTooSmall { got: u64 }, - #[display(fmt = "{_0:?}")] - #[from] FrameDescriptorError(FrameDescriptorError), - #[display(fmt = "Not enough bytes in dict_id. Is: {got}, Should be: {expected}")] DictIdTooSmall { got: usize, expected: usize }, - #[display( - fmt = "frame_content_size does not have the right length. Is: {got}, Should be: {expected}" - )] MismatchedFrameSize { got: usize, expected: u8 }, - #[display(fmt = "frame_content_size was zero")] FrameSizeIsZero, - #[display(fmt = "Invalid frame_content_size. Is: {got}, Should be one of 1, 2, 4, 8 bytes")] InvalidFrameSize { got: u8 }, } +impl fmt::Display for FrameHeaderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::WindowTooBig { got } => write!( + f, + "window_size bigger than allowed maximum. Is: {}, Should be lower than: {}", + got, MAX_WINDOW_SIZE + ), + Self::WindowTooSmall { got } => write!( + f, + "window_size smaller than allowed minimum. Is: {}, Should be greater than: {}", + got, MIN_WINDOW_SIZE + ), + Self::FrameDescriptorError(e) => write!(f, "{:?}", e), + Self::DictIdTooSmall { got, expected } => write!( + f, + "Not enough bytes in dict_id. Is: {}, Should be: {}", + got, expected + ), + Self::MismatchedFrameSize { got, expected } => write!( + f, + "frame_content_size does not have the right length. 
Is: {}, Should be: {}", + got, expected + ), + Self::FrameSizeIsZero => write!(f, "frame_content_size was zero"), + Self::InvalidFrameSize { got } => write!( + f, + "Invalid frame_content_size. Is: {}, Should be one of 1, 2, 4, 8 bytes", + got + ), + } + } +} + +#[cfg(feature = "std")] +impl StdError for FrameHeaderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + FrameHeaderError::FrameDescriptorError(source) => Some(source), + _ => None, + } + } +} + +impl From for FrameHeaderError { + fn from(error: FrameDescriptorError) -> Self { + Self::FrameDescriptorError(error) + } +} + impl FrameHeader { + /// Read the size of the window from the header, returning the size in bytes. pub fn window_size(&self) -> Result { if self.descriptor.single_segment_flag() { Ok(self.frame_content_size()) @@ -126,40 +249,80 @@ impl FrameHeader { } } + /// The ID (if provided) of the dictionary required to decode this frame. pub fn dictionary_id(&self) -> Option { self.dict_id } + /// Obtain the uncompressed size (in bytes) of the frame contents. 
pub fn frame_content_size(&self) -> u64 { self.frame_content_size } } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum ReadFrameHeaderError { - #[display(fmt = "Error while reading magic number: {_0}")] MagicNumberReadError(Error), - #[display(fmt = "Read wrong magic number: 0x{_0:X}")] - BadMagicNumber(#[cfg_attr(feature = "std", error(ignore))] u32), - #[display(fmt = "Error while reading frame descriptor: {_0}")] + BadMagicNumber(u32), FrameDescriptorReadError(Error), - #[display(fmt = "{_0:?}")] - #[from] InvalidFrameDescriptor(FrameDescriptorError), - #[display(fmt = "Error while reading window descriptor: {_0}")] WindowDescriptorReadError(Error), - #[display(fmt = "Error while reading dictionary id: {_0}")] DictionaryIdReadError(Error), - #[display(fmt = "Error while reading frame content size: {_0}")] FrameContentSizeReadError(Error), - #[display( - fmt = "SkippableFrame encountered with MagicNumber 0x{magic_number:X} and length {length} bytes" - )] SkipFrame { magic_number: u32, length: u32 }, } +impl fmt::Display for ReadFrameHeaderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MagicNumberReadError(e) => write!(f, "Error while reading magic number: {}", e), + Self::BadMagicNumber(e) => write!(f, "Read wrong magic number: 0x{:X}", e), + Self::FrameDescriptorReadError(e) => { + write!(f, "Error while reading frame descriptor: {}", e) + } + Self::InvalidFrameDescriptor(e) => write!(f, "{:?}", e), + Self::WindowDescriptorReadError(e) => { + write!(f, "Error while reading window descriptor: {}", e) + } + Self::DictionaryIdReadError(e) => write!(f, "Error while reading dictionary id: {}", e), + Self::FrameContentSizeReadError(e) => { + write!(f, "Error while reading frame content size: {}", e) + } + Self::SkipFrame { + magic_number, + length, + } => write!( + f, + "SkippableFrame encountered with 
MagicNumber 0x{:X} and length {} bytes", + magic_number, length + ), + } + } +} + +#[cfg(feature = "std")] +impl StdError for ReadFrameHeaderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + ReadFrameHeaderError::MagicNumberReadError(source) => Some(source), + ReadFrameHeaderError::FrameDescriptorReadError(source) => Some(source), + ReadFrameHeaderError::InvalidFrameDescriptor(source) => Some(source), + ReadFrameHeaderError::WindowDescriptorReadError(source) => Some(source), + ReadFrameHeaderError::DictionaryIdReadError(source) => Some(source), + ReadFrameHeaderError::FrameContentSizeReadError(source) => Some(source), + _ => None, + } + } +} + +impl From for ReadFrameHeaderError { + fn from(error: FrameDescriptorError) -> Self { + Self::InvalidFrameDescriptor(error) + } +} + +/// Read a single serialized frame from the reader and return a tuple containing the parsed frame and the number of bytes read. pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeaderError> { use ReadFrameHeaderError as err; let mut buf = [0u8; 4]; diff --git a/src/frame_decoder.rs b/src/frame_decoder.rs index 072ebc8a..610ced4e 100644 --- a/src/frame_decoder.rs +++ b/src/frame_decoder.rs @@ -1,3 +1,7 @@ +//! Zstandard compressed data is made of one or more [Frame]s. Each frame is independent and can be +//! decompressed independently of other frames. This module contains structures +//! and utilities that can be used to decode a frame. + use super::frame; use crate::decoding::dictionary::Dictionary; use crate::decoding::scratch::DecoderScratch; @@ -6,12 +10,14 @@ use crate::io::{Error, Read, Write}; use alloc::collections::BTreeMap; use alloc::vec::Vec; use core::convert::TryInto; +#[cfg(feature = "std")] +use std::error::Error as StdError; /// This implements a decoder for zstd frames. 
This decoder is able to decode frames only partially and gives control /// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once). /// It reads bytes as needed from a provided source and can be read from to collect partial results. /// -/// If you want to just read the whole frame with an io::Read without having to deal with manually calling decode_blocks +/// If you want to just read the whole frame with an `io::Read` without having to deal with manually calling [FrameDecoder::decode_blocks] /// you can use the provided StreamingDecoder with wraps this FrameDecoder /// /// Workflow is as follows: @@ -79,46 +85,115 @@ pub enum BlockDecodingStrategy { UptoBytes(usize), } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum FrameDecoderError { - #[display(fmt = "{_0:?}")] - #[from] ReadFrameHeaderError(frame::ReadFrameHeaderError), - #[display(fmt = "{_0:?}")] - #[from] FrameHeaderError(frame::FrameHeaderError), - #[display( - fmt = "Specified window_size is too big; Requested: {requested}, Max: {MAX_WINDOW_SIZE}" - )] WindowSizeTooBig { requested: u64 }, - #[display(fmt = "{_0:?}")] - #[from] DictionaryDecodeError(dictionary::DictionaryDecodeError), - #[display(fmt = "Failed to parse/decode block body: {_0}")] - #[from] FailedToReadBlockHeader(decoding::block_decoder::BlockHeaderReadError), - #[display(fmt = "Failed to parse block header: {_0}")] FailedToReadBlockBody(decoding::block_decoder::DecodeBlockContentError), - #[display(fmt = "Failed to read checksum: {_0}")] FailedToReadChecksum(Error), - #[display(fmt = "Decoder must initialized or reset before using it")] NotYetInitialized, - #[display(fmt = "Decoder encountered error while initializing: {_0}")] FailedToInitialize(frame::FrameHeaderError), - #[display(fmt = "Decoder encountered error while draining the decodebuffer: {_0}")] 
FailedToDrainDecodebuffer(Error), - #[display( - fmt = "Target must have at least as many bytes as the contentsize of the frame reports" - )] TargetTooSmall, - #[display( - fmt = "Frame header specified dictionary id 0x{dict_id:X} that wasnt provided by add_dict() or reset_with_dict()" - )] DictNotProvided { dict_id: u32 }, } +#[cfg(feature = "std")] +impl StdError for FrameDecoderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + FrameDecoderError::ReadFrameHeaderError(source) => Some(source), + FrameDecoderError::FrameHeaderError(source) => Some(source), + FrameDecoderError::DictionaryDecodeError(source) => Some(source), + FrameDecoderError::FailedToReadBlockHeader(source) => Some(source), + FrameDecoderError::FailedToReadBlockBody(source) => Some(source), + FrameDecoderError::FailedToReadChecksum(source) => Some(source), + FrameDecoderError::FailedToInitialize(source) => Some(source), + FrameDecoderError::FailedToDrainDecodebuffer(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for FrameDecoderError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result { + match self { + FrameDecoderError::ReadFrameHeaderError(e) => { + write!(f, "{:?}", e) + } + FrameDecoderError::FrameHeaderError(e) => { + write!(f, "{:?}", e) + } + FrameDecoderError::WindowSizeTooBig { requested } => { + write!( + f, + "Specified window_size is too big; Requested: {}, Max: {}", + requested, MAX_WINDOW_SIZE, + ) + } + FrameDecoderError::DictionaryDecodeError(e) => { + write!(f, "{:?}", e) + } + FrameDecoderError::FailedToReadBlockHeader(e) => { + write!(f, "Failed to parse/decode block body: {}", e) + } + FrameDecoderError::FailedToReadBlockBody(e) => { + write!(f, "Failed to parse block header: {}", e) + } + FrameDecoderError::FailedToReadChecksum(e) => { + write!(f, "Failed to read checksum: {}", e) + } + FrameDecoderError::NotYetInitialized => { + write!(f, "Decoder must initialized or reset before using it",) 
+ } + FrameDecoderError::FailedToInitialize(e) => { + write!(f, "Decoder encountered error while initializing: {}", e) + } + FrameDecoderError::FailedToDrainDecodebuffer(e) => { + write!( + f, + "Decoder encountered error while draining the decodebuffer: {}", + e, + ) + } + FrameDecoderError::TargetTooSmall => { + write!(f, "Target must have at least as many bytes as the contentsize of the frame reports") + } + FrameDecoderError::DictNotProvided { dict_id } => { + write!(f, "Frame header specified dictionary id 0x{:X} that wasnt provided by add_dict() or reset_with_dict()", dict_id) + } + } + } +} + +impl From for FrameDecoderError { + fn from(val: dictionary::DictionaryDecodeError) -> Self { + Self::DictionaryDecodeError(val) + } +} + +impl From for FrameDecoderError { + fn from(val: decoding::block_decoder::BlockHeaderReadError) -> Self { + Self::FailedToReadBlockHeader(val) + } +} + +impl From for FrameDecoderError { + fn from(val: frame::FrameHeaderError) -> Self { + Self::FrameHeaderError(val) + } +} + +impl From for FrameDecoderError { + fn from(val: frame::ReadFrameHeaderError) -> Self { + Self::ReadFrameHeaderError(val) + } +} + const MAX_WINDOW_SIZE: u64 = 1024 * 1024 * 100; impl FrameDecoderState { diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index b58aed25..c456fcab 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -2,61 +2,161 @@ use crate::decoding::bit_reader::BitReader; use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use alloc::vec::Vec; +/// FSE decoding involves a decoding table that describes the probabilities of +/// all literals from 0 to the highest present one +/// +/// pub struct FSETable { + /// The maximum symbol in the table (inclusive). Limits the probabilities length to max_symbol + 1. max_symbol: u8, + /// The actual table containing the decoded symbol and the compression data + /// connected to that symbol. 
pub decode: Vec, //used to decode symbols, and calculate the next state - + /// The size of the table is stored in logarithm base 2 format, + /// with the **size of the table** being equal to `(1 << accuracy_log)`. + /// This value is used so that the decoder knows how many bits to read from the bitstream. pub accuracy_log: u8, + /// In this context, probability refers to the likelihood that a symbol occurs in the given data. + /// Given this info, the encoder can assign shorter codes to symbols that appear more often, + /// and longer codes that appear less often, then the decoder can use the probability + /// to determine what code was assigned to what symbol. + /// + /// The probability of a single symbol is a value representing the proportion of times the symbol + /// would fall within the data. + /// + /// If a symbol probability is set to `-1`, it means that the probability of a symbol + /// occurring in the data is less than one. pub symbol_probabilities: Vec, //used while building the decode Vector + /// The number of times each symbol occurs (The first entry being 0x0, the second being 0x1) and so on + /// up until the highest possible symbol (255). symbol_counter: Vec, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum FSETableError { - #[display(fmt = "Acclog must be at least 1")] AccLogIsZero, - #[display(fmt = "Found FSE acc_log: {got} bigger than allowed maximum in this case: {max}")] - AccLogTooBig { got: u8, max: u8 }, - #[display(fmt = "{_0:?}")] - #[from] + AccLogTooBig { + got: u8, + max: u8, + }, GetBitsError(GetBitsError), - #[display( - fmt = "The counter ({got}) exceeded the expected sum: {expected_sum}. 
This means an error or corrupted data \n {symbol_probabilities:?}" - )] ProbabilityCounterMismatch { got: u32, expected_sum: u32, symbol_probabilities: Vec, }, - #[display(fmt = "There are too many symbols in this distribution: {got}. Max: 256")] - TooManySymbols { got: usize }, + TooManySymbols { + got: usize, + }, +} + +#[cfg(feature = "std")] +impl std::error::Error for FSETableError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + FSETableError::GetBitsError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for FSETableError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + FSETableError::AccLogIsZero => write!(f, "Acclog must be at least 1"), + FSETableError::AccLogTooBig { got, max } => { + write!( + f, + "Found FSE acc_log: {0} bigger than allowed maximum in this case: {1}", + got, max + ) + } + FSETableError::GetBitsError(e) => write!(f, "{:?}", e), + FSETableError::ProbabilityCounterMismatch { + got, + expected_sum, + symbol_probabilities, + } => { + write!(f, + "The counter ({}) exceeded the expected sum: {}. This means an error or corrupted data \n {:?}", + got, + expected_sum, + symbol_probabilities, + ) + } + FSETableError::TooManySymbols { got } => { + write!( + f, + "There are too many symbols in this distribution: {}. Max: 256", + got, + ) + } + } + } +} + +impl From for FSETableError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } } pub struct FSEDecoder<'table> { + /// An FSE state value represents an index in the FSE table. pub state: Entry, + /// A reference to the table used for decoding. 
table: &'table FSETable, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum FSEDecoderError { - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), - #[display(fmt = "Tried to use an uninitialized table!")] TableIsUninitialized, } +#[cfg(feature = "std")] +impl std::error::Error for FSEDecoderError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + FSEDecoderError::GetBitsError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for FSEDecoderError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + FSEDecoderError::GetBitsError(e) => write!(f, "{:?}", e), + FSEDecoderError::TableIsUninitialized => { + write!(f, "Tried to use an uninitialized table!") + } + } + } +} + +impl From for FSEDecoderError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + +/// A single entry in an FSE table. #[derive(Copy, Clone)] pub struct Entry { + /// This value is used as an offset value, and it is added + /// to a value read from the stream to determine the next state value. pub base_line: u32, + /// How many bits should be read from the stream when decoding this entry. pub num_bits: u8, + /// The byte that should be put in the decode output when encountering this state. pub symbol: u8, } +/// This value is added to the first 4 bits of the stream to determine the +/// `Accuracy_Log` const ACC_LOG_OFFSET: u8 = 5; fn highest_bit_set(x: u32) -> u32 { @@ -65,6 +165,7 @@ fn highest_bit_set(x: u32) -> u32 { } impl<'t> FSEDecoder<'t> { + /// Initialize a new Finite State Entropy decoder. pub fn new(table: &'t FSETable) -> FSEDecoder<'_> { FSEDecoder { state: table.decode.first().copied().unwrap_or(Entry { @@ -76,10 +177,13 @@ impl<'t> FSEDecoder<'t> { } } + /// Returns the byte associated with the symbol the internal cursor is pointing at. 
pub fn decode_symbol(&self) -> u8 { self.state.symbol } + /// Initialize internal state and prepare for decoding. After this, `decode_symbol` can be called + /// to read the first symbol and `update_state` can be called to prepare to read the next symbol. pub fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), FSEDecoderError> { if self.table.accuracy_log == 0 { return Err(FSEDecoderError::TableIsUninitialized); @@ -89,6 +193,7 @@ impl<'t> FSEDecoder<'t> { Ok(()) } + /// Advance the internal state to decode the next symbol in the bitstream. pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits(num_bits); @@ -101,6 +206,7 @@ impl<'t> FSEDecoder<'t> { } impl FSETable { + /// Initialize a new empty Finite State Entropy decoding table. pub fn new(max_symbol: u8) -> FSETable { FSETable { max_symbol, @@ -111,6 +217,7 @@ impl FSETable { } } + /// Reset `self` and update `self`'s state to mirror the provided table. pub fn reinit_from(&mut self, other: &Self) { self.reset(); self.symbol_counter.extend_from_slice(&other.symbol_counter); @@ -120,6 +227,7 @@ impl FSETable { self.accuracy_log = other.accuracy_log; } + /// Empty the table and clear all internal state. pub fn reset(&mut self) { self.symbol_counter.clear(); self.symbol_probabilities.clear(); @@ -127,7 +235,7 @@ impl FSETable { self.accuracy_log = 0; } - //returns how many BYTEs (not bits) were read while building the decoder + /// returns how many BYTEs (not bits) were read while building the decoder pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result { self.accuracy_log = 0; @@ -137,6 +245,7 @@ impl FSETable { Ok(bytes_read) } + /// Given the provided accuracy log, build a decoding table from that log. 
pub fn build_from_probabilities( &mut self, acc_log: u8, @@ -150,12 +259,15 @@ impl FSETable { self.build_decoding_table() } + /// Build the actual decoding table after probabilities have been read into the table. + /// After this function is called, the decoding process can begin. fn build_decoding_table(&mut self) -> Result<(), FSETableError> { if self.symbol_probabilities.len() > self.max_symbol as usize + 1 { return Err(FSETableError::TooManySymbols { got: self.symbol_probabilities.len(), }); } + self.decode.clear(); let table_size = 1 << self.accuracy_log; @@ -230,6 +342,8 @@ impl FSETable { Ok(()) } + /// Read the accuracy log and the probability table from the source and return the number of bytes + /// read. If the size of the table is larger than the provided `max_log`, return an error. fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result { self.symbol_probabilities.clear(); //just clear, we will fill a probability for each entry anyways. No need to force new allocs here @@ -245,11 +359,11 @@ impl FSETable { return Err(FSETableError::AccLogIsZero); } - let probablility_sum = 1 << self.accuracy_log; + let probability_sum = 1 << self.accuracy_log; let mut probability_counter = 0; - while probability_counter < probablility_sum { - let max_remaining_value = probablility_sum - probability_counter + 1; + while probability_counter < probability_sum { + let max_remaining_value = probability_sum - probability_counter + 1; let bits_to_read = highest_bit_set(max_remaining_value); let unchecked_value = br.get_bits(bits_to_read as usize)? 
as u32; @@ -293,10 +407,10 @@ impl FSETable { } } - if probability_counter != probablility_sum { + if probability_counter != probability_sum { return Err(FSETableError::ProbabilityCounterMismatch { got: probability_counter, - expected_sum: probablility_sum, + expected_sum: probability_sum, symbol_probabilities: self.symbol_probabilities.clone(), }); } @@ -317,6 +431,8 @@ impl FSETable { } //utility functions for building the decoding table from probabilities +/// Calculate the position of the next entry of the table given the current +/// position and size of the table. fn next_position(mut p: usize, table_size: usize) -> usize { p += (table_size >> 1) + (table_size >> 3) + 3; p &= table_size - 1; diff --git a/src/fse/mod.rs b/src/fse/mod.rs index ba4beb51..e25489fa 100644 --- a/src/fse/mod.rs +++ b/src/fse/mod.rs @@ -1,2 +1,16 @@ +//! FSE, short for Finite State Entropy, is an encoding technique +//! that assigns shorter codes to symbols that appear more frequently in data, +//! and longer codes to less frequent symbols. +//! +//! FSE works by mutating a state and using that state to index into a table. +//! +//! Zstandard uses two different kinds of entropy encoding: FSE, and Huffman coding. +//! Huffman is used to compress literals, +//! while FSE is used for all other symbols (literal length code, match length code, offset code). +//! +//! https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse +//! +//! + mod fse_decoder; pub use fse_decoder::*; diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index b4b514d0..129cfc84 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -1,103 +1,241 @@ +//! Utilities for decoding Huff0 encoded huffman data. 
+ use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError}; use alloc::vec::Vec; +#[cfg(feature = "std")] +use std::error::Error as StdError; pub struct HuffmanTable { decode: Vec, - + /// The weight of a symbol is the number of occurrences in a table. + /// This value is used in constructing a binary tree referred to as + /// a huffman tree. weights: Vec, + /// The maximum size in bits a prefix code in the encoded data can be. + /// This value is used so that the decoder knows how many bits + /// to read from the bitstream before checking the table. This + /// value must be 11 or lower. pub max_num_bits: u8, bits: Vec, bit_ranks: Vec, rank_indexes: Vec, - + /// In some cases, the list of weights is compressed using FSE compression. fse_table: FSETable, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum HuffmanTableError { - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), - #[display(fmt = "{_0:?}")] - #[from] FSEDecoderError(FSEDecoderError), - #[display(fmt = "{_0:?}")] - #[from] FSETableError(FSETableError), - #[display(fmt = "Source needs to have at least one byte")] SourceIsEmpty, - #[display( - fmt = "Header says there should be {expected_bytes} bytes for the weights but there are only {got_bytes} bytes in the stream" - )] NotEnoughBytesForWeights { got_bytes: usize, expected_bytes: u8, }, - #[display( - fmt = "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption" - )] - ExtraPadding { skipped_bits: i32 }, - #[display( - fmt = "More than 255 weights decoded (got {got} weights). 
Stream is probably corrupted" - )] - TooManyWeights { got: usize }, - #[display(fmt = "Can't build huffman table without any weights")] + ExtraPadding { + skipped_bits: i32, + }, + TooManyWeights { + got: usize, + }, MissingWeights, - #[display(fmt = "Leftover must be power of two but is: {got}")] - LeftoverIsNotAPowerOf2 { got: u32 }, - #[display( - fmt = "Not enough bytes in stream to decompress weights. Is: {have}, Should be: {need}" - )] - NotEnoughBytesToDecompressWeights { have: usize, need: usize }, - #[display( - fmt = "FSE table used more bytes: {used} than were meant to be used for the whole stream of huffman weights ({available_bytes})" - )] - FSETableUsedTooManyBytes { used: usize, available_bytes: u8 }, - #[display(fmt = "Source needs to have at least {need} bytes, got: {got}")] - NotEnoughBytesInSource { got: usize, need: usize }, - #[display(fmt = "Cant have weight: {got} bigger than max_num_bits: {MAX_MAX_NUM_BITS}")] - WeightBiggerThanMaxNumBits { got: u8 }, - #[display( - fmt = "max_bits derived from weights is: {got} should be lower than: {MAX_MAX_NUM_BITS}" - )] - MaxBitsTooHigh { got: u8 }, + LeftoverIsNotAPowerOf2 { + got: u32, + }, + NotEnoughBytesToDecompressWeights { + have: usize, + need: usize, + }, + FSETableUsedTooManyBytes { + used: usize, + available_bytes: u8, + }, + NotEnoughBytesInSource { + got: usize, + need: usize, + }, + WeightBiggerThanMaxNumBits { + got: u8, + }, + MaxBitsTooHigh { + got: u8, + }, } +#[cfg(feature = "std")] +impl StdError for HuffmanTableError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + HuffmanTableError::GetBitsError(source) => Some(source), + HuffmanTableError::FSEDecoderError(source) => Some(source), + HuffmanTableError::FSETableError(source) => Some(source), + _ => None, + } + } +} + +impl core::fmt::Display for HuffmanTableError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> ::core::fmt::Result { + match self { + HuffmanTableError::GetBitsError(e) => write!(f, 
"{:?}", e), + HuffmanTableError::FSEDecoderError(e) => write!(f, "{:?}", e), + HuffmanTableError::FSETableError(e) => write!(f, "{:?}", e), + HuffmanTableError::SourceIsEmpty => write!(f, "Source needs to have at least one byte"), + HuffmanTableError::NotEnoughBytesForWeights { + got_bytes, + expected_bytes, + } => { + write!(f, "Header says there should be {} bytes for the weights but there are only {} bytes in the stream", + expected_bytes, + got_bytes) + } + HuffmanTableError::ExtraPadding { skipped_bits } => { + write!(f, + "Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption", + skipped_bits, + ) + } + HuffmanTableError::TooManyWeights { got } => { + write!( + f, + "More than 255 weights decoded (got {} weights). Stream is probably corrupted", + got, + ) + } + HuffmanTableError::MissingWeights => { + write!(f, "Can\'t build huffman table without any weights") + } + HuffmanTableError::LeftoverIsNotAPowerOf2 { got } => { + write!(f, "Leftover must be power of two but is: {}", got) + } + HuffmanTableError::NotEnoughBytesToDecompressWeights { have, need } => { + write!( + f, + "Not enough bytes in stream to decompress weights. 
Is: {}, Should be: {}", + have, need, + ) + } + HuffmanTableError::FSETableUsedTooManyBytes { + used, + available_bytes, + } => { + write!(f, + "FSE table used more bytes: {} than were meant to be used for the whole stream of huffman weights ({})", + used, + available_bytes, + ) + } + HuffmanTableError::NotEnoughBytesInSource { got, need } => { + write!( + f, + "Source needs to have at least {} bytes, got: {}", + need, got, + ) + } + HuffmanTableError::WeightBiggerThanMaxNumBits { got } => { + write!( + f, + "Cant have weight: {} bigger than max_num_bits: {}", + got, MAX_MAX_NUM_BITS, + ) + } + HuffmanTableError::MaxBitsTooHigh { got } => { + write!( + f, + "max_bits derived from weights is: {} should be lower than: {}", + got, MAX_MAX_NUM_BITS, + ) + } + } + } +} + +impl From for HuffmanTableError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + +impl From for HuffmanTableError { + fn from(val: FSEDecoderError) -> Self { + Self::FSEDecoderError(val) + } +} + +impl From for HuffmanTableError { + fn from(val: FSETableError) -> Self { + Self::FSETableError(val) + } +} + +/// An interface around a huffman table used to decode data. pub struct HuffmanDecoder<'table> { table: &'table HuffmanTable, + /// State is used to index into the table. 
pub state: u64, } -#[derive(Debug, derive_more::Display, derive_more::From)] -#[cfg_attr(feature = "std", derive(derive_more::Error))] +#[derive(Debug)] #[non_exhaustive] pub enum HuffmanDecoderError { - #[display(fmt = "{_0:?}")] - #[from] GetBitsError(GetBitsError), } +impl core::fmt::Display for HuffmanDecoderError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + HuffmanDecoderError::GetBitsError(e) => write!(f, "{:?}", e), + } + } +} + +#[cfg(feature = "std")] +impl StdError for HuffmanDecoderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + HuffmanDecoderError::GetBitsError(source) => Some(source), + } + } +} + +impl From for HuffmanDecoderError { + fn from(val: GetBitsError) -> Self { + Self::GetBitsError(val) + } +} + +/// A single entry in the table contains the decoded symbol/literal and the +/// size of the prefix code. #[derive(Copy, Clone)] pub struct Entry { + /// The byte that the prefix code replaces during encoding. symbol: u8, + /// The number of bits the prefix code occupies. num_bits: u8, } +/// The Zstandard specification limits the maximum length of a code to 11 bits. const MAX_MAX_NUM_BITS: u8 = 11; +/// Asserts that the provided value is greater than zero, and returns +/// 32 minus the number of leading zeros. fn highest_bit_set(x: u32) -> u32 { assert!(x > 0); u32::BITS - x.leading_zeros() } impl<'t> HuffmanDecoder<'t> { + /// Create a new decoder with the provided table pub fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> { HuffmanDecoder { table, state: 0 } } + /// Re-initialize the decoder, using the new table if one is provided. + /// This might be used for treeless blocks, because they re-use the table from old + /// data.
pub fn reset(mut self, new_table: Option<&'t HuffmanTable>) { self.state = 0; if let Some(next_table) = new_table { @@ -105,10 +243,15 @@ impl<'t> HuffmanDecoder<'t> { } } + /// Decode the symbol the internal state (cursor) is pointed at and return the + /// decoded literal. pub fn decode_symbol(&mut self) -> u8 { self.table.decode[self.state as usize].symbol } + /// Initialize internal state and prepare to decode data. Then, `decode_symbol` can be called + /// to read the byte the internal cursor is pointing at, and `next_state` can be called to advance + /// the cursor until the max number of bits has been read. pub fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { let num_bits = self.table.max_num_bits; let new_bits = br.get_bits(num_bits); @@ -116,11 +259,18 @@ impl<'t> HuffmanDecoder<'t> { num_bits } + /// Advance the internal cursor to the next symbol. After this, you can call `decode_symbol` + /// to read from the new position. pub fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 { + // self.state stores a small section, or a window of the bit stream. The table can be indexed via this state, + // telling you how many bits identify the current symbol. let num_bits = self.table.decode[self.state as usize].num_bits; + // New bits are read from the stream let new_bits = br.get_bits(num_bits); + // Shift and mask out the bits that identify the current symbol self.state <<= num_bits; self.state &= self.table.decode.len() as u64 - 1; + // The new bits are appended at the end of the current state. self.state |= new_bits; num_bits } @@ -133,6 +283,7 @@ impl Default for HuffmanTable { } impl HuffmanTable { + /// Create a new, empty table. pub fn new() -> HuffmanTable { HuffmanTable { decode: Vec::new(), @@ -146,6 +297,8 @@ impl HuffmanTable { } } + /// Completely empty the table then repopulate as a replica + /// of `other`. 
pub fn reinit_from(&mut self, other: &Self) { self.reset(); self.decode.extend_from_slice(&other.decode); @@ -156,6 +309,7 @@ impl HuffmanTable { self.fse_table.reinit_from(&other.fse_table); } + /// Completely empty the table of all data. pub fn reset(&mut self) { self.decode.clear(); self.weights.clear(); @@ -166,6 +320,9 @@ impl HuffmanTable { self.fse_table.reset(); } + /// Read from `source` and parse it into a huffman table. + /// + /// Returns the number of bytes read. pub fn build_decoder(&mut self, source: &[u8]) -> Result { self.decode.clear(); @@ -174,6 +331,13 @@ impl HuffmanTable { Ok(bytes_used) } + /// Read weights from the provided source. + /// + /// The huffman table is represented in the encoded data as a list of weights + /// at the most basic level. After the header, weights are read, then the table + /// can be built using that list of weights. + /// + /// Returns the number of bytes read. fn read_weights(&mut self, source: &[u8]) -> Result { use HuffmanTableError as err; @@ -184,6 +348,9 @@ impl HuffmanTable { let mut bits_read = 8; match header { + // If the header byte is less than 128, the series of weights + // is compressed using two interleaved FSE streams that share + // a distribution table. 0..=127 => { let fse_stream = &source[1..]; if header as usize > fse_stream.len() { @@ -208,6 +375,9 @@ impl HuffmanTable { "Building fse table for huffman weights used: {}", bytes_used_by_fse_header ); + // Huffman headers are compressed using two interleaved + // FSE bitstreams, where the first state (decoder) handles + // even symbols, and the second handles odd symbols. let mut dec1 = FSEDecoder::new(&self.fse_table); let mut dec2 = FSEDecoder::new(&self.fse_table); @@ -245,6 +415,7 @@ impl HuffmanTable { self.weights.clear(); + // The two decoders take turns decoding a single symbol and updating their state. 
loop { let w = dec1.decode_symbol(); self.weights.push(w); @@ -273,6 +444,12 @@ impl HuffmanTable { } } } + // If the header byte is greater than or equal to 128, + // weights are directly represented, where each weight is + // encoded directly as a 4 bit field. The weights will + // always be encoded with full bytes, meaning if there's + // an odd number of weights, the last weight will still + // occupy a full byte. _ => { // weights are directly encoded let weights_raw = &source[1..]; @@ -311,6 +488,10 @@ impl HuffmanTable { Ok(bytes_read as u32) } + /// Once the weights have been read from the data, you can decode the weights + /// into a table, and use that table to decode the actual compressed data. + /// + /// This function populates the rest of the table from the series of weights. fn build_table_from_weights(&mut self) -> Result<(), HuffmanTableError> { use HuffmanTableError as err; diff --git a/src/huff0/mod.rs b/src/huff0/mod.rs index 445c7fab..3d847d65 100644 --- a/src/huff0/mod.rs +++ b/src/huff0/mod.rs @@ -1,2 +1,6 @@ +/// Huffman coding is a method of encoding where symbols are assigned a code, +/// and more commonly used symbols get shorter codes, and less commonly +/// used symbols get longer codes. Codes are prefix free, meaning no two codes +/// will start with the same sequence of bits. mod huff0_decoder; pub use huff0_decoder::*; diff --git a/src/io.rs b/src/io.rs index 6970cd13..7a90969a 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,2 +1,3 @@ +//! Re-exports of std values for when the std is available. #[cfg(feature = "std")] pub use std::io::{Error, ErrorKind, Read, Write}; diff --git a/src/io_nostd.rs b/src/io_nostd.rs index 0fc76e90..880ff499 100644 --- a/src/io_nostd.rs +++ b/src/io_nostd.rs @@ -1,3 +1,5 @@ +//! 
Manual implementations of representations for `#![no_std]` + use alloc::boxed::Box; #[non_exhaustive] diff --git a/src/streaming_decoder.rs b/src/streaming_decoder.rs index 3021e3f2..fde7a9d9 100644 --- a/src/streaming_decoder.rs +++ b/src/streaming_decoder.rs @@ -3,19 +3,40 @@ use core::borrow::BorrowMut; use crate::frame_decoder::{BlockDecodingStrategy, FrameDecoder, FrameDecoderError}; use crate::io::{Error, ErrorKind, Read}; -/// High level decoder that implements a io::Read that can be used with -/// io::Read::read_to_end / io::Read::read_exact or passing this to another library / module as a source for the decoded content +/// High level Zstandard frame decoder that can be used to decompress a given Zstandard frame. /// -/// The lower level FrameDecoder by comparison allows for finer grained control but need sto have it's decode_blocks method called continously -/// to decode the zstd-frame. +/// This decoder implements `io::Read`, so you can interact with it by calling +/// `io::Read::read_to_end` / `io::Read::read_exact` or passing this to another library / module as a source for the decoded content +/// +/// If you need more control over how decompression takes place, you can use +/// the lower level [FrameDecoder], which allows for greater control over how +/// decompression takes place but the implementor must call +/// [FrameDecoder::decode_blocks] repeatedly to decode the entire frame. /// /// ## Caveat -/// [StreamingDecoder] expects the underlying stream to only contain a single frame. +/// [StreamingDecoder] expects the underlying stream to only contain a single frame, +/// yet the specification states that a single archive may contain multiple frames. 
+/// /// To decode all the frames in a finite stream, the calling code needs to recreate -/// the instance of the decoder -/// and handle +/// the instance of the decoder and handle /// [crate::frame::ReadFrameHeaderError::SkipFrame] /// errors by skipping forward the `length` amount of bytes, see +/// +/// ```no_run +/// // `read_to_end` is not implemented by the no_std implementation. +/// #[cfg(feature = "std")] +/// { +/// use std::fs::File; +/// use std::io::Read; +/// use ruzstd::{StreamingDecoder}; +/// +/// // Read a Zstandard archive from the filesystem then decompress it into a vec. +/// let mut f: File = todo!("Read a .zstd archive from somewhere"); +/// let mut decoder = StreamingDecoder::new(f).unwrap(); +/// let mut result = Vec::new(); +/// Read::read_to_end(&mut decoder, &mut result).unwrap(); +/// } +/// ``` pub struct StreamingDecoder<READ: Read, DEC: BorrowMut<FrameDecoder>> { pub decoder: DEC, source: READ, @@ -39,8 +60,39 @@ impl StreamingDecoder { decoder.init(&mut source)?; Ok(StreamingDecoder { decoder, source }) } +} + +impl<READ: Read, DEC: BorrowMut<FrameDecoder>> StreamingDecoder<READ, DEC> { + /// Gets a reference to the underlying reader. + pub fn get_ref(&self) -> &READ { + &self.source + } + + /// Gets a mutable reference to the underlying reader. + /// + /// It is inadvisable to directly read from the underlying reader. + pub fn get_mut(&mut self) -> &mut READ { + &mut self.source + } + + /// Destructures this object into the inner reader. + pub fn into_inner(self) -> READ + where + READ: Sized, + { + self.source + } + + /// Destructures this object into both the inner reader and [FrameDecoder]. + pub fn into_parts(self) -> (READ, DEC) + where + READ: Sized, + { + (self.source, self.decoder) + } - pub fn inner(self) -> FrameDecoder { + /// Destructures this object into the inner [FrameDecoder].
+ pub fn into_frame_decoder(self) -> DEC { self.decoder } } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3cbcd239..95a0a6d7 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -331,9 +331,11 @@ fn test_streaming() { // Test resetting to a new file while keeping the old decoder let mut content = fs::File::open("./decodecorpus_files/z000068.zst").unwrap(); - let mut stream = - crate::streaming_decoder::StreamingDecoder::new_with_decoder(&mut content, stream.inner()) - .unwrap(); + let mut stream = crate::streaming_decoder::StreamingDecoder::new_with_decoder( + &mut content, + stream.into_frame_decoder(), + ) + .unwrap(); let mut result = Vec::new(); Read::read_to_end(&mut stream, &mut result).unwrap(); @@ -415,9 +417,11 @@ fn test_streaming_no_std() { let content = include_bytes!("../../decodecorpus_files/z000068.zst"); let mut content = content.as_slice(); - let mut stream = - crate::streaming_decoder::StreamingDecoder::new_with_decoder(&mut content, stream.inner()) - .unwrap(); + let mut stream = crate::streaming_decoder::StreamingDecoder::new_with_decoder( + &mut content, + stream.into_frame_decoder(), + ) + .unwrap(); let original = include_bytes!("../../decodecorpus_files/z000068"); let mut result = vec![0; original.len()];