Simplify pybindings #63

Merged
merged 5 commits into from
Oct 25, 2024
151 changes: 9 additions & 142 deletions src/pybindings/mod.rs
@@ -4,7 +4,7 @@ pub mod symbol;
use std::prelude::v1::*;

use alloc::borrow::Cow;
use numpy::{ndarray, PyArrayMethods, PyReadonlyArray, PyUntypedArrayMethods};
use numpy::{ndarray, PyArrayMethods, PyReadonlyArray, PyReadonlyArray1, PyUntypedArrayMethods};
use pyo3::{prelude::*, wrap_pymodule};

use crate::NanError;
@@ -177,150 +177,12 @@ use crate::NanError;
/// [entropy models](stream/model.html).
#[pymodule]
#[pyo3(name = "constriction")]
fn init_module(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_wrapped(wrap_pymodule!(init_stream))?;
    module.add_wrapped(wrap_pymodule!(init_symbol))?;
fn init_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_wrapped(wrap_pymodule!(stream::init_module))?;
    module.add_wrapped(wrap_pymodule!(symbol::init_module))?;
    Ok(())
}

/// Stream codes, i.e., entropy codes that amortize compressed bits over several symbols.
///
/// We provide two main stream codes:
///
/// - **Range Coding** [1, 2] (in submodule [`queue`](stream/queue.html)), which is a computationally
/// more efficient variant of Arithmetic Coding [3], and which has "queue" semantics ("first in
/// first out", i.e., symbols get decoded in the same order in which they were encoded); and
/// - **Asymmetric Numeral Systems (ANS)** [4] (in submodule [`stack`](stream/stack.html)), which has
/// "stack" semantics ("last in first out", i.e., symbols get decoded in *reverse* order compared
/// to the order in which they got encoded).
///
/// In addition, the submodule [`model`](stream/model.html) provides common entropy models and
/// wrappers for defining your own entropy models (or for using models from the popular `scipy`
/// package in `constriction`). These entropy models can be used with both of the above stream
/// codes.
///
/// We further provide an experimental new "Chain Coder" (in submodule [`chain`](stream/chain.html)),
/// which is intended for special new compression methods.
///
/// ## Examples
///
/// See the top of the documentation of both submodules [`queue`](stream/queue.html) and
/// [`stack`](stream/stack.html).
///
/// ## References
///
/// [1] Pasco, Richard Clark. Source coding algorithms for fast data compression. Diss.
/// Stanford University, 1976.
///
/// [2] Martin, G. Nigel N. "Range encoding: an algorithm for removing redundancy from a
/// digitised message." Proc. Institution of Electronic and Radio Engineers International
/// Conference on Video and Data Recording. 1979.
///
/// [3] Rissanen, Jorma, and Glen G. Langdon. "Arithmetic coding." IBM Journal of research
/// and development 23.2 (1979): 149-162.
///
/// [4] Duda, Jarek, et al. "The use of asymmetric numeral systems as an accurate
/// replacement for Huffman coding." 2015 Picture Coding Symposium (PCS). IEEE, 2015.
#[pymodule]
#[pyo3(name = "stream")]
fn init_stream(py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
    stream::init_module(py, module)
}
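
The "queue" versus "stack" semantics described in the doc comment above can be illustrated without any entropy coding. The following is a minimal sketch that uses only the Rust standard library; `round_trip_queue` and `round_trip_stack` are hypothetical helper names for illustration and are not part of `constriction`. Plain containers stand in for the coders here, so the sketch shows only the ordering behavior, not compression.

```rust
use std::collections::VecDeque;

/// Pushes all symbols into a FIFO queue and pops them again.
/// Symbols come out in the same order in which they went in.
fn round_trip_queue(message: &[u32]) -> Vec<u32> {
    let mut queue: VecDeque<u32> = message.iter().copied().collect();
    let mut decoded = Vec::with_capacity(message.len());
    while let Some(symbol) = queue.pop_front() {
        decoded.push(symbol);
    }
    decoded
}

/// Pushes all symbols onto a LIFO stack and pops them again.
/// Symbols come out in reverse order.
fn round_trip_stack(message: &[u32]) -> Vec<u32> {
    let mut stack: Vec<u32> = message.to_vec();
    let mut decoded = Vec::with_capacity(message.len());
    while let Some(symbol) = stack.pop() {
        decoded.push(symbol);
    }
    decoded
}

fn main() {
    let message: [u32; 5] = [1, 3, 2, 3, 0];

    // Queue semantics: the order is preserved.
    assert_eq!(round_trip_queue(&message), vec![1, 3, 2, 3, 0]);

    // Stack semantics: the order is reversed ...
    assert_eq!(round_trip_stack(&message), vec![0, 3, 2, 3, 1]);

    // ... so pushing the message back to front restores the original order.
    let reversed: Vec<u32> = message.iter().rev().copied().collect();
    assert_eq!(round_trip_stack(&reversed), message.to_vec());
}
```

This is why examples that use a stack-based coder typically encode the message in reverse and then decode it front to back.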

/// Symbol codes. Mainly provided for teaching purposes. You'll probably want to use a
/// [stream code](stream.html) instead.
///
/// Symbol codes encode messages (= sequences of symbols) by compressing each symbol
/// individually and then concatenating the compressed representations of the symbols
/// (called code words). This has the advantage of being conceptually simple, but the
/// disadvantage of incurring an overhead of up to almost 1 bit *per symbol*. The overhead
/// is most significant in the regime of low entropy per symbol (often the regime of novel
/// machine-learning based compression methods), since each (nontrivial) symbol in a symbol
/// code contributes at least 1 bit to the total bitrate of the message, even if its
/// information content is only marginally above zero.
///
/// This module defines the types `QueueEncoder`, `QueueDecoder`, and `StackCoder`, which
/// are essentially just efficient containers of growable bit strings. The interesting
/// part happens in the so-called code book, which defines how symbols map to code words. We
/// provide the Huffman Coding algorithm in the submodule [`huffman`](symbol/huffman.html),
/// which calculates an optimal code book for a given probability distribution.
///
/// ## Examples
///
/// The following two examples both encode and then decode a sequence of symbols using the
/// same entropy model for each symbol. We could as well use a different entropy model for
/// each symbol, but it would make the examples more lengthy.
///
/// - Example 1: Encoding and decoding with "queue" semantics ("first in first out", i.e.,
/// we'll decode symbols in the same order in which we encode them):
///
/// ```python
/// import constriction
/// import numpy as np
///
/// # Define an entropy model over the (implied) alphabet {0, 1, 2, 3}:
/// probabils = np.array([0.3, 0.2, 0.4, 0.1], dtype=np.float32)
///
/// # Encode some example message, using the same model for each symbol here:
/// message = [1, 3, 2, 3, 0, 1, 3, 0, 2, 1, 1, 3, 3, 1, 2, 0, 1, 3, 1]
/// encoder = constriction.symbol.QueueEncoder()
/// encoder_codebook = constriction.symbol.huffman.EncoderHuffmanTree(probabils)
/// for symbol in message:
///     encoder.encode_symbol(symbol, encoder_codebook)
///
/// # Obtain the compressed representation and the bitrate:
/// compressed, bitrate = encoder.get_compressed()
/// print(compressed, bitrate) # (prints: [3756389791, 61358], 48)
/// print(f"(in binary: {[bin(word) for word in compressed]})")
///
/// # Decode the message
/// decoder = constriction.symbol.QueueDecoder(compressed)
/// decoded = []
/// decoder_codebook = constriction.symbol.huffman.DecoderHuffmanTree(probabils)
/// for symbol in range(19):
///     decoded.append(decoder.decode_symbol(decoder_codebook))
///
/// assert decoded == message # (verifies correctness)
/// ```
///
/// - Example 2: Encoding and decoding with "stack" semantics ("last in first out", i.e.,
/// we'll encode symbols from back to front and then decode them from front to back):
///
/// ```python
/// import constriction
/// import numpy as np
///
/// # Define an entropy model over the (implied) alphabet {0, 1, 2, 3}:
/// probabils = np.array([0.3, 0.2, 0.4, 0.1], dtype=np.float32)
///
/// # Encode some example message, using the same model for each symbol here:
/// message = [1, 3, 2, 3, 0, 1, 3, 0, 2, 1, 1, 3, 3, 1, 2, 0, 1, 3, 1]
/// coder = constriction.symbol.StackCoder()
/// encoder_codebook = constriction.symbol.huffman.EncoderHuffmanTree(probabils)
/// for symbol in reversed(message): # Note: reversed
///     coder.encode_symbol(symbol, encoder_codebook)
///
/// # Obtain the compressed representation and the bitrate:
/// compressed, bitrate = coder.get_compressed()
/// print(compressed, bitrate) # (prints: [2818274807, 129455], 48)
/// print(f"(in binary: {[bin(word) for word in compressed]})")
///
/// # Decode the message (we could explicitly construct a decoder:
/// # `decoder = constriction.symbol.StackCoder(compressed)`
/// # but we can also reuse our existing `coder` for decoding):
/// decoded = []
/// decoder_codebook = constriction.symbol.huffman.DecoderHuffmanTree(probabils)
/// for symbol in range(19):
///     decoded.append(coder.decode_symbol(decoder_codebook))
///
/// assert decoded == message # (verifies correctness)
/// ```
#[pymodule]
#[pyo3(name = "symbol")]
fn init_symbol(py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
    symbol::init_module(py, module)
}
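
The per-symbol overhead of a symbol code mentioned in the doc comment above can be made concrete for the distribution `[0.3, 0.2, 0.4, 0.1]` used in the Python examples. The following is a minimal sketch using only the Rust standard library. It computes Huffman code word lengths with a simple quadratic-time merge loop; `huffman_code_lengths` is a hypothetical helper written for illustration, not the implementation in `constriction`. Since it works with lengths only, it makes no claim about the actual code words that `constriction` assigns.

```rust
/// Returns the Huffman code word length (in bits) for each symbol.
/// Illustrative only: sorts on every iteration instead of using a heap.
fn huffman_code_lengths(probabilities: &[f64]) -> Vec<u32> {
    // Each node stores its total probability and the symbols in its subtree.
    let mut nodes: Vec<(f64, Vec<usize>)> = probabilities
        .iter()
        .enumerate()
        .map(|(symbol, &p)| (p, vec![symbol]))
        .collect();
    let mut lengths = vec![0u32; probabilities.len()];

    while nodes.len() > 1 {
        // Sort in descending order so the two least probable nodes are last.
        nodes.sort_by(|a, b| b.0.total_cmp(&a.0));
        let (p1, leaves1) = nodes.pop().unwrap();
        let (p2, mut leaves2) = nodes.pop().unwrap();
        leaves2.extend(leaves1);
        // Merging two subtrees moves all of their symbols one level deeper.
        for &symbol in &leaves2 {
            lengths[symbol] += 1;
        }
        nodes.push((p1 + p2, leaves2));
    }
    lengths
}

fn main() {
    let probabilities: [f64; 4] = [0.3, 0.2, 0.4, 0.1];
    let message: [usize; 19] = [1, 3, 2, 3, 0, 1, 3, 0, 2, 1, 1, 3, 3, 1, 2, 0, 1, 3, 1];

    let lengths = huffman_code_lengths(&probabilities);
    let total_bits: u32 = message.iter().map(|&symbol| lengths[symbol]).sum();

    // Entropy is the lower bound on the expected number of bits per symbol.
    let entropy: f64 = probabilities.iter().map(|&p| -p * p.log2()).sum();
    let expected_length: f64 = probabilities
        .iter()
        .zip(&lengths)
        .map(|(&p, &len)| p * f64::from(len))
        .sum();

    println!("code lengths: {lengths:?}, total bits: {total_bits}");
    println!("entropy: {entropy:.3} bits, expected length: {expected_length:.3} bits per symbol");
}
```

For this distribution the code word lengths are `[2, 3, 1, 3]`, and the 19-symbol example message costs 48 bits, which matches the bitrate reported in the Python examples. The gap between the expected code length and the entropy is the overhead that stream codes avoid.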

#[derive(Debug, Clone)]
pub enum PyReadonlyFloatArray<'py, D: ndarray::Dimension> {
F32(PyReadonlyArray<'py, f32, D>),
@@ -371,6 +233,11 @@ impl<'py, D: ndarray::Dimension> PyReadonlyFloatArray<'py, D> {
}
}

fn array1_to_vec<T: numpy::Element + Clone>(x: PyReadonlyArray1<'_, T>) -> Vec<T> {
    x.to_vec()
        .unwrap_or_else(|_| x.as_array().iter().cloned().collect())
}
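
The new `array1_to_vec` helper appears to follow a "fast path with fallback" pattern: try a bulk copy first, and fall back to element-wise iteration if that fails (presumably when the array is not contiguous in memory). The following is a minimal sketch of the same pattern using only the Rust standard library. `StridedView`, `NotContiguous`, and `view_to_vec` are hypothetical stand-ins invented for this illustration; they are not types or functions from the `numpy` crate.

```rust
/// Hypothetical stand-in for a 1-D array view that reads every `stride`-th
/// element of `data`. `stride` is assumed to be at least 1.
struct StridedView<'a, T> {
    data: &'a [T],
    stride: usize,
}

/// Hypothetical error type signalling that a bulk copy is not possible.
#[derive(Debug)]
struct NotContiguous;

impl<'a, T: Clone> StridedView<'a, T> {
    /// Fast path: succeeds only if the elements are adjacent in memory.
    fn to_vec(&self) -> Result<Vec<T>, NotContiguous> {
        if self.stride == 1 {
            Ok(self.data.to_vec())
        } else {
            Err(NotContiguous)
        }
    }
}

/// Tries the fast path first and falls back to element-wise cloning.
fn view_to_vec<T: Clone>(view: &StridedView<'_, T>) -> Vec<T> {
    view.to_vec()
        .unwrap_or_else(|_| view.data.iter().step_by(view.stride).cloned().collect())
}

fn main() {
    let data = [1, 2, 3, 4, 5];

    // Contiguous view: the fast path succeeds.
    let contiguous = StridedView { data: &data[..], stride: 1 };
    assert_eq!(view_to_vec(&contiguous), vec![1, 2, 3, 4, 5]);

    // Strided view: the fast path fails and the iterator fallback is used.
    let strided = StridedView { data: &data[..], stride: 2 };
    assert_eq!(view_to_vec(&strided), vec![1, 3, 5]);
}
```

Using `unwrap_or_else` rather than `unwrap_or` keeps the fallback lazy, so the element-wise copy runs only when the bulk copy fails.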

impl From<NanError> for PyErr {
fn from(_err: NanError) -> Self {
pyo3::exceptions::PyFloatingPointError::new_err(