Skip to content

Commit

Permalink
Merge pull request #58 from bamler-lab:models-cleanup
Browse files Browse the repository at this point in the history
Clean up `stream::models` (breaking changes)
  • Loading branch information
robamler authored Aug 29, 2024
2 parents cd5613d + e3260b9 commit 54e8581
Show file tree
Hide file tree
Showing 35 changed files with 7,533 additions and 4,155 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,23 @@ jobs:
RUSTFLAGS: "-D warnings"
run: cargo clippy --all-targets --all-features

# The environment variable `RUSTFLAGS` doesn't actually seem to have any effect on rustdoc,
# but we set it here to the same value as in all other cargo runs as changing it would
# cause unnecessary recompilation of some dependencies.
- name: Check for broken doc links
env:
RUSTFLAGS: "-D warnings"
run: cargo rustdoc -- -D rustdoc::broken-intra-doc-links

- name: Test in development mode
env:
RUSTFLAGS: "-D warnings"
run: cargo test
run: cargo test --all-targets

- name: Test in release mode
env:
RUSTFLAGS: "-D warnings"
run: cargo test --release
run: cargo test --release --all-targets

miri-test:
name: no_std and miri
Expand Down
12 changes: 10 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,23 @@ jobs:
RUSTFLAGS: "-D warnings"
run: cargo clippy --all-targets --all-features

# The environment variable `RUSTFLAGS` doesn't actually seem to have any effect on rustdoc,
# but we set it here to the same value as in all other cargo runs as changing it would
# cause unnecessary recompilation of some dependencies.
- name: Check for broken doc links
env:
RUSTFLAGS: "-D warnings"
run: cargo rustdoc -- -D rustdoc::broken-intra-doc-links

- name: Test in development mode
env:
RUSTFLAGS: "-D warnings"
run: cargo test
run: cargo test --all-targets

- name: Test in release mode
env:
RUSTFLAGS: "-D warnings"
run: cargo test --release
run: cargo test --release --all-targets

miri-test:
name: no_std and miri
Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ name = "constriction"
readme = "README-rust.md"
repository = "https://github.com/bamler-lab/constriction/"
version = "0.3.5"
rust-version = "1.75" # for feature `return_position_impl_trait_in_trait`

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -24,7 +25,7 @@ std = []

# Use feature `pybindings` to compile the python extension module that provides
# access to this library from python. This feature is turned off by default
# because it causes problems with `cargo test`. To turn it on, run:
# because it causes problems with `cargo test` on Mac OS. To turn it on, run:
# cargo build --release --features pybindings
pybindings = ["ndarray", "numpy", "pyo3"]

Expand Down
6 changes: 4 additions & 2 deletions README-python.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ message = np.array([6, 10, -4, 2, 5, 2, 1, 0, 2], dtype=np.int32)
means = np.array([2.3, 6.1, -8.5, 4.1, 1.3], dtype=np.float64)
stds = np.array([6.2, 5.3, 3.8, 3.2, 4.7], dtype=np.float64)
entropy_model1 = constriction.stream.model.QuantizedGaussian(-50, 50)
entropy_model2 = constriction.stream.model.Categorical(np.array(
[0.2, 0.5, 0.3], dtype=np.float64)) # Probabilities of the symbols 0,1,2.
entropy_model2 = constriction.stream.model.Categorical(
np.array([0.2, 0.5, 0.3], dtype=np.float32), # Probabilities of the symbols 0,1,2.
perfect=False
)

# Simply encode both parts in sequence with their respective models:
encoder = constriction.stream.queue.RangeEncoder()
Expand Down
6 changes: 3 additions & 3 deletions benches/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::any::type_name;

use constriction::{
stream::{
model::{LookupDecoderModel, NonContiguousCategoricalEncoderModel},
model::{NonContiguousCategoricalEncoderModel, NonContiguousLookupDecoderModel},
queue::RangeEncoder,
stack::AnsCoder,
Code, Decode, Encode,
Expand Down Expand Up @@ -106,7 +106,7 @@ where
.unwrap();

let decoder_model =
LookupDecoderModel::<u16,Probability,_,_,PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
NonContiguousLookupDecoderModel::<u16, Probability, _, _, PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
symbols,probabilities,false
)
.unwrap();
Expand Down Expand Up @@ -203,7 +203,7 @@ where
.unwrap();

let decoder_model =
LookupDecoderModel::<u16,Probability,_,_,PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
NonContiguousLookupDecoderModel::<u16, Probability, _, _, PRECISION>::from_symbols_and_nonzero_fixed_point_probabilities(
symbols,probabilities,false
)
.unwrap();
Expand Down
6 changes: 3 additions & 3 deletions ensure_no_std/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ fn alloc_error_handler(layout: core::alloc::Layout) -> ! {
pub extern "C" fn _start() -> ! {
use constriction::stream::{Decode, Encode};

let model = constriction::stream::model::UniformModel::<u32, 24>::new(10);
let model = constriction::stream::model::DefaultUniformModel::new(10);

let mut encoder = constriction::stream::stack::DefaultAnsCoder::new();
encoder.encode_symbol(3u32, model).unwrap();
encoder.encode_symbol(5u32, model).unwrap();
encoder.encode_symbol(3usize, model).unwrap();
encoder.encode_symbol(5usize, model).unwrap();
let compressed = core::hint::black_box(encoder.into_compressed().unwrap());

let mut decoder =
Expand Down
2 changes: 1 addition & 1 deletion src/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ impl<B: Seek> Seek for Reverse<B> {
/// );
/// // Encoding *a few* more symbols works ...
/// cursor_coder.encode_iid_symbols_reverse(65..75, &model).unwrap();
/// // ... but at some point we'll run out of buffer space.
/// // ... but at some point we'll run out of buffer space:
/// assert_eq!(
/// cursor_coder.encode_iid_symbols_reverse(50..65, &model),
/// Err(CoderError::Backend(constriction::backends::BoundedWriteError::OutOfSpace))
Expand Down
82 changes: 57 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ impl<FrontendError, BackendError> From<BackendError> for CoderError<FrontendErro

impl<FrontendError> CoderError<FrontendError, Infallible> {
fn into_frontend_error(self) -> FrontendError {
#[allow(unreachable_patterns)]
match self {
CoderError::Frontend(frontend_error) => frontend_error,
CoderError::Backend(infallible) => match infallible {},
Expand Down Expand Up @@ -459,7 +460,7 @@ pub trait Pos: PosSeek {
/// let mut ans = DefaultAnsCoder::new();
/// let probabilities = vec![0.03, 0.07, 0.1, 0.1, 0.2, 0.2, 0.1, 0.15, 0.05];
/// let entropy_model = DefaultContiguousCategoricalEntropyModel
/// ::from_floating_point_probabilities(&probabilities).unwrap();
/// ::from_floating_point_probabilities_fast(&probabilities, None).unwrap();
///
/// // Encode some symbols in two chunks and take a snapshot after each chunk.
/// let symbols1 = vec![8, 2, 0, 7];
Expand Down Expand Up @@ -640,15 +641,6 @@ pub unsafe trait BitArray:
}
}

/// Computes `2^exponent` in `T`, wrapping to zero when the power does not fit.
///
/// For `exponent < T::BITS` this is `T::one() << exponent`. For any larger
/// `exponent` the shift is skipped entirely and `T::zero()` is returned.
#[inline(always)]
fn wrapping_pow2<T: BitArray>(exponent: usize) -> T {
    let fits_in_t = exponent < T::BITS;
    if fits_in_t {
        T::one() << exponent
    } else {
        T::zero()
    }
}

/// A trait for bit strings like [`BitArray`] but with guaranteed nonzero values
///
/// # Safety
Expand All @@ -668,21 +660,6 @@ pub unsafe trait NonZeroBitArray: Copy + Display + Debug + Eq + Hash + 'static {
fn get(self) -> Self::Base;
}

/// Splits `data` into `Chunk`-sized pieces and yields them from the most significant
/// piece to the least significant one, leaving out all leading pieces that are zero.
///
/// Pieces are aligned to the least significant bit of `data`, so the last item is
/// always `data` truncated to `Chunk`. If `data` is zero, the iterator is empty.
fn bit_array_to_chunks_truncated<Data, Chunk>(
    data: Data,
) -> impl ExactSizeIterator<Item = Chunk> + DoubleEndedIterator
where
    Data: BitArray + AsPrimitive<Chunk>,
    Chunk: BitArray,
{
    // Position (counted from the least significant bit) just above the highest set bit.
    let significant_bits = Data::BITS - data.leading_zeros() as usize;
    // Number of chunks needed to cover all significant bits, rounding up.
    let num_chunks = significant_bits.div_ceil(Chunk::BITS);

    (0..num_chunks)
        .rev()
        .map(move |index| (data >> (index * Chunk::BITS)).as_())
}

macro_rules! unsafe_impl_bit_array {
($(($base:ty, $non_zero:ty)),+ $(,)?) => {
$(
Expand Down Expand Up @@ -737,13 +714,38 @@ unsafe_impl_bit_array!(
#[cfg(feature = "std")]
unsafe_impl_bit_array!((u128, core::num::NonZeroU128),);

/// Splits `data` into `Chunk`-sized pieces and yields them from the most significant
/// piece to the least significant one, leaving out all leading pieces that are zero.
///
/// Pieces are aligned to the least significant bit of `data`, so the last item is
/// always `data` truncated to `Chunk`. If `data` is zero, the iterator is empty.
fn bit_array_to_chunks_truncated<Data, Chunk>(
    data: Data,
) -> impl ExactSizeIterator<Item = Chunk> + DoubleEndedIterator
where
    Data: BitArray + AsPrimitive<Chunk>,
    Chunk: BitArray,
{
    // Position (counted from the least significant bit) just above the highest set bit.
    let significant_bits = Data::BITS - data.leading_zeros() as usize;
    // Number of chunks needed to cover all significant bits, rounding up.
    let num_chunks = significant_bits.div_ceil(Chunk::BITS);

    (0..num_chunks)
        .rev()
        .map(move |index| (data >> (index * Chunk::BITS)).as_())
}

/// Computes `2^exponent` in `T`, wrapping to zero when the power does not fit.
///
/// For `exponent < T::BITS` this is `T::one() << exponent`. For any larger
/// `exponent` the shift is skipped entirely and `T::zero()` is returned.
#[inline(always)]
fn wrapping_pow2<T: BitArray>(exponent: usize) -> T {
    let fits_in_t = exponent < T::BITS;
    if fits_in_t {
        T::one() << exponent
    } else {
        T::zero()
    }
}

/// Conversion of a value into a `T` through an operation that is declared to never fail.
///
/// The implementations that accompany this trait cover `Result`s whose error type
/// is built from [`Infallible`], so the `Err` case can be eliminated without panicking.
pub trait UnwrapInfallible<T> {
/// Consumes `self` and returns the contained value of type `T`.
fn unwrap_infallible(self) -> T;
}

impl<T> UnwrapInfallible<T> for Result<T, Infallible> {
#[inline(always)]
fn unwrap_infallible(self) -> T {
#[allow(unreachable_patterns)]
match self {
Ok(x) => x,
Err(infallible) => match infallible {},
Expand All @@ -753,12 +755,42 @@ impl<T> UnwrapInfallible<T> for Result<T, Infallible> {

/// Unwraps a `Result` whose error is a [`CoderError`] in which both the frontend and
/// the backend error type are [`Infallible`].
impl<T> UnwrapInfallible<T> for Result<T, CoderError<Infallible, Infallible>> {
/// Returns the `Ok` value. Both `CoderError` variants wrap an `Infallible`, which is
/// eliminated with an empty `match`, so this method contains no panicking branch.
fn unwrap_infallible(self) -> T {
// NOTE(review): presumably needed because newer toolchains report patterns over
// uninhabited types as unreachable; confirm against the compiler versions used in CI.
#[allow(unreachable_patterns)]
match self {
Ok(x) => x,
// NOTE(review): repeats the lint suppression directly on the `Err` arm; it is unclear
// whether the attribute on the enclosing `match` alone suffices on all toolchains.
#[allow(unreachable_patterns)]
Err(infallible) => match infallible {
// An empty `match` on an `Infallible` value has type `!` and therefore coerces to `T`.
CoderError::Backend(infallible) => match infallible {},
CoderError::Frontend(infallible) => match infallible {},
},
}
}
}

/// Helper macro to express assertions that are tested at compile time
/// despite using properties of generic parameters of an outer function.
///
/// See discussion at <https://morestina.net/blog/1940>.
///
/// # Syntax
///
/// The first argument is a parenthesized list of the generic parameters that the
/// assertions refer to: optional lifetimes (each followed by a comma), then optional
/// type parameters (each with at most one trait bound), then optionally a semicolon
/// followed by `const` parameters. After a separating semicolon follow one or more
/// `LABEL: boolean_expression` entries, separated by semicolons (a trailing semicolon
/// is accepted). Illustrative shape of a call (names are placeholders):
///
/// ```ignore
/// generic_static_asserts!(
///     (Word: BitArray; const PRECISION: usize);
///     PRECISION_FITS: PRECISION <= Word::BITS;
/// );
/// ```
macro_rules! generic_static_asserts {
// Main entry point: declares a helper type carrying the generic parameters and turns
// each labeled test into an associated constant of that type.
(($($l:lifetime,)* $($($t:ident$(: $bound:path)?),+)? $(; $(const $c:ident:$ct:ty),+)?); $($label:ident: $test:expr);+$(;)?) => {
// The expansion of the `@nested` rule below consists of bare path statements without
// any effect, which would otherwise be reported by these lints.
#[allow(path_statements, clippy::no_effect)]
{
{
// Local helper type that re-declares the listed generic parameters; the tuple field
// mentions every type parameter.
struct Check<$($l,)* $($($t,)+)? $($(const $c:$ct,)+)?>($($($t,)+)?);
impl<$($l,)* $($($t$(:$bound)?,)+)? $($(const $c:$ct,)+)?> Check<$($l,)* $($($t,)+)? $($($c,)+)?> {
$(
// One associated constant per label; its initializer is the assertion itself.
const $label: () = assert!($test);
)+
}
// Mention every constant once (see the `@nested` rule) so that the assertions are
// actually evaluated for the concrete generic arguments.
generic_static_asserts!{@nested Check::<$($l,)* $($($t,)+)? $($($c,)+)?>, $($label: $test;)+}
}
}
};
// Internal rule: emits one path statement per label that names the associated constant
// on the fully instantiated helper type `$t`.
(@nested $t:ty, $($label:ident: $test:expr;)+) => {
$(
<$t>::$label;
)+
}
}

pub(crate) use generic_static_asserts;
6 changes: 4 additions & 2 deletions src/pybindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ use pyo3::{prelude::*, wrap_pymodule};
/// means = np.array([2.3, 6.1, -8.5, 4.1, 1.3], dtype=np.float64)
/// stds = np.array([6.2, 5.3, 3.8, 3.2, 4.7], dtype=np.float64)
/// entropy_model1 = constriction.stream.model.QuantizedGaussian(-50, 50)
/// entropy_model2 = constriction.stream.model.Categorical(np.array(
/// [0.2, 0.5, 0.3], dtype=np.float64)) # Probabilities of the symbols 0,1,2.
/// entropy_model2 = constriction.stream.model.Categorical(
/// np.array([0.2, 0.5, 0.3], dtype=np.float32), # Probabilities of the symbols 0,1,2.
/// perfect=False
/// )
///
/// # Simply encode both parts in sequence with their respective models:
/// encoder = constriction.stream.queue.RangeEncoder()
Expand Down
4 changes: 2 additions & 2 deletions src/pybindings/stream/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl ChainCoder {
) -> PyResult<()> {
if let Ok(symbol) = symbols.extract::<i32>() {
if !params.is_empty() {
return Err(pyo3::exceptions::PyAttributeError::new_err(
return Err(pyo3::exceptions::PyValueError::new_err(
"To encode a single symbol, use a concrete model, i.e., pass the\n\
model parameters directly to the constructor of the model and not to the\n\
`encode` method of the entropy coder. Delaying the specification of model\n\
Expand Down Expand Up @@ -166,7 +166,7 @@ impl ChainCoder {
})?;
} else {
if symbols.len() != model.0.len(&params[0])? {
return Err(pyo3::exceptions::PyAttributeError::new_err(
return Err(pyo3::exceptions::PyValueError::new_err(
"`symbols` argument has wrong length.",
));
}
Expand Down
Loading

0 comments on commit 54e8581

Please sign in to comment.