This library provides solvers for performing regularized and unregularized Optimal Transport in Rust.
Inspired by Python Optimal Transport, this library provides the following solvers:
- Network simplex algorithm for linear program / Earth Movers Distance
- Entropic regularization OT solvers, including Sinkhorn-Knopp and Greedy Sinkhorn
- Unbalanced Sinkhorn-Knopp
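To make the entropic solvers in the list above more concrete, here is a minimal sketch of the balanced Sinkhorn-Knopp iteration on plain `Vec<f64>` data, using only the standard library. It is not this crate's implementation or API: the function name `sinkhorn_knopp`, its parameters, and the fixed iteration count are assumptions chosen for illustration, and it does no convergence checking or log-domain stabilization.

```rust
/// Illustrative Sinkhorn-Knopp sketch (hypothetical helper, not part of this crate).
/// `a`: source weights (length n), `b`: target weights (length m),
/// `cost`: n x m cost matrix, `reg`: entropic regularization strength (> 0),
/// `iters`: fixed number of scaling iterations.
pub fn sinkhorn_knopp(
    a: &[f64],
    b: &[f64],
    cost: &[Vec<f64>],
    reg: f64,
    iters: usize,
) -> Vec<Vec<f64>> {
    let (n, m) = (a.len(), b.len());

    // Gibbs kernel K = exp(-C / reg)
    let k: Vec<Vec<f64>> = cost
        .iter()
        .map(|row| row.iter().map(|c| (-c / reg).exp()).collect())
        .collect();

    let mut u = vec![1.0_f64; n];
    let mut v = vec![1.0_f64; m];

    for _ in 0..iters {
        // Row scaling: u = a / (K v)
        for i in 0..n {
            let kv: f64 = (0..m).map(|j| k[i][j] * v[j]).sum();
            u[i] = a[i] / kv;
        }
        // Column scaling: v = b / (K^T u)
        for j in 0..m {
            let ktu: f64 = (0..n).map(|i| k[i][j] * u[i]).sum();
            v[j] = b[j] / ktu;
        }
    }

    // Transport plan: diag(u) K diag(v)
    (0..n)
        .map(|i| (0..m).map(|j| u[i] * k[i][j] * v[j]).collect())
        .collect()
}
```

Smaller values of `reg` bring the plan closer to the unregularized (EMD) solution but make `exp(-C / reg)` underflow more easily, which is one reason to normalize the cost matrix before solving.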
The library has been tested on macOS. It requires a C++ compiler for building the EMD solver and relies on the following Rust libraries:
- cxx 1.0
- thiserror 1.0
- ndarray 0.15
Add the following to your Cargo.toml to use rust-optimal-transport in your project:
[dependencies]
rust-optimal-transport = "0.2"
If you would like to enable the LAPACK backend (currently only OpenBLAS is supported):
[dependencies]
rust-optimal-transport = { version = "0.2", features = ["blas"] }
This will link against an installed instance of OpenBLAS on your system. For more details see the ndarray-linalg crate.
- Import the library
use rust_optimal_transport as ot;
use ot::prelude::*;
- Compute OT matrix as the Earth Mover's Distance
use ndarray::prelude::*;
use ndarray_stats::QuantileExt; // max()
// Generate data by sampling 2D Gaussian distributions
let n = 100;
let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];
let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];
let source = ot::utils::sample_2D_gauss(n, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n, &mu_target, &cov_target).unwrap();
// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
let mut target_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
// Compute the cost between the distributions
let mut cost = dist(&source, &target, SqEuclidean);
// Normalize cost matrix for numerical stability
let max_cost = cost.max().unwrap();
cost = &cost / *max_cost;
// Compute the optimal transport matrix
// (the `?` operator requires an enclosing function that returns a `Result`)
let ot_matrix = EarthMovers::new(
    &mut source_weights,
    &mut target_weights,
    &mut cost,
).solve()?;
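A transport matrix returned by a balanced solver should have row sums equal to the source weights and column sums equal to the target weights. The sketch below checks that property and computes the total transport cost on plain `Vec<f64>` data. The helper names `marginal_error` and `transport_cost` are hypothetical and not part of this crate; an `ndarray` result would first need to be converted to nested vectors or the loops adapted.

```rust
/// Largest absolute deviation between the plan's marginals and the given weights.
/// Hypothetical helper for illustration; not part of this crate's API.
pub fn marginal_error(
    plan: &[Vec<f64>],
    source_weights: &[f64],
    target_weights: &[f64],
) -> f64 {
    let mut max_err = 0.0_f64;
    // Row sums should match the source weights
    for (row, &w) in plan.iter().zip(source_weights) {
        let s: f64 = row.iter().sum();
        max_err = max_err.max((s - w).abs());
    }
    // Column sums should match the target weights
    for (j, &w) in target_weights.iter().enumerate() {
        let s: f64 = plan.iter().map(|row| row[j]).sum();
        max_err = max_err.max((s - w).abs());
    }
    max_err
}

/// Total cost of a plan: the sum over all (i, j) of plan[i][j] * cost[i][j].
pub fn transport_cost(plan: &[Vec<f64>], cost: &[Vec<f64>]) -> f64 {
    plan.iter()
        .zip(cost)
        .map(|(p_row, c_row)| p_row.iter().zip(c_row).map(|(p, c)| p * c).sum::<f64>())
        .sum()
}
```

A marginal error near zero (up to floating-point tolerance) is a quick sanity check that the solver converged on a valid plan.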
To run the tests:
cargo test
If you are using an M1 Mac and linking against Homebrew's OpenBLAS, add the following to build.rs:
println!("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib");
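An unconditional `println!` applies the Homebrew search path on every platform. One possible refinement, sketched below, is to emit it only when the build target is macOS on `aarch64`. The helper names `openblas_link_directive` and `emit_openblas_link_search` are hypothetical; the sketch relies on the `CARGO_CFG_TARGET_OS` and `CARGO_CFG_TARGET_ARCH` environment variables that Cargo sets for build scripts, and assumes Homebrew is installed under its default Apple Silicon prefix.

```rust
/// Returns the Cargo link-search directive for Homebrew's OpenBLAS on
/// Apple Silicon, or `None` for any other target.
/// Hypothetical helper for illustration.
pub fn openblas_link_directive(target_os: &str, target_arch: &str) -> Option<String> {
    if target_os == "macos" && target_arch == "aarch64" {
        Some("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib".to_string())
    } else {
        None
    }
}

/// Intended to be called from `fn main()` in build.rs.
/// Cargo sets the CARGO_CFG_* variables when it runs a build script;
/// outside of a build script they are missing and nothing is printed.
pub fn emit_openblas_link_search() {
    let os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    let arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
    if let Some(directive) = openblas_link_directive(&os, &arch) {
        println!("{}", directive);
    }
}
```

In build.rs, `fn main()` would then consist of a single call to `emit_openblas_link_search()`.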
This library is inspired by Python Optimal Transport. The original authors and contributors of that project are listed at POT.