kachark/rust-optimal-transport

Rust Optimal Transport

This library provides solvers for performing regularized and unregularized Optimal Transport in Rust.

Inspired by Python Optimal Transport, this library provides the following solvers:

  • Network simplex algorithm for the linear program / Earth Mover's Distance (EMD)
  • Entropic regularization OT solvers, including Sinkhorn-Knopp and Greedy Sinkhorn
  • Unbalanced Sinkhorn-Knopp
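
To give a feel for what an entropic-regularized solver does, here is a minimal standalone sketch of the Sinkhorn-Knopp scaling iteration using only the standard library. It is not this library's implementation or API: the function name `sinkhorn_knopp`, its parameters, and the fixed iteration count are assumptions made for illustration.

```rust
/// Illustrative Sinkhorn-Knopp iteration (hypothetical helper, not part of this crate).
/// `a` and `b` are source/target weights that each sum to 1, `cost` is an
/// a.len() x b.len() matrix, `reg` is the entropic regularization strength.
fn sinkhorn_knopp(
    a: &[f64],
    b: &[f64],
    cost: &[Vec<f64>],
    reg: f64,
    iters: usize,
) -> Vec<Vec<f64>> {
    let (n, m) = (a.len(), b.len());

    // Gibbs kernel K = exp(-C / reg)
    let k: Vec<Vec<f64>> = cost
        .iter()
        .map(|row| row.iter().map(|c| (-c / reg).exp()).collect())
        .collect();

    let mut u = vec![1.0; n];
    let mut v = vec![1.0; m];

    for _ in 0..iters {
        // Row scaling: u = a / (K v)
        for i in 0..n {
            let kv: f64 = (0..m).map(|j| k[i][j] * v[j]).sum();
            u[i] = a[i] / kv;
        }
        // Column scaling: v = b / (K^T u)
        for j in 0..m {
            let ktu: f64 = (0..n).map(|i| k[i][j] * u[i]).sum();
            v[j] = b[j] / ktu;
        }
    }

    // Transport plan = diag(u) K diag(v)
    (0..n)
        .map(|i| (0..m).map(|j| u[i] * k[i][j] * v[j]).collect())
        .collect()
}
```

The rows of the returned plan should sum approximately to `a` and the columns to `b`. Smaller values of `reg` bring the plan closer to the unregularized solution but can underflow the kernel, which is one reason to normalize the cost matrix first.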

Installation

The library has been tested on macOS. It requires a C++ compiler for building the EMD solver and relies on the following Rust libraries:

  • cxx 1.0
  • thiserror 1.0
  • ndarray 0.15

Cargo

Add the following to your Cargo.toml to use rust-optimal-transport in your project:

[dependencies]
rust-optimal-transport = "0.2"

Features

To enable the LAPACK backend (currently only OpenBLAS is supported):

[dependencies]
rust-optimal-transport = { version = "0.2", features = ["blas"] }

This links against the OpenBLAS installation already present on your system. For more details, see the ndarray-linalg crate.

Examples

Short examples

  • Import the library
use rust_optimal_transport as ot;
use ot::prelude::*;
  • Compute OT matrix as the Earth Mover's Distance
use ndarray::prelude::*;
use ndarray_stats::QuantileExt; // max()

// Generate data by sampling a 2D Gaussian distribution
let n = 100;

let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];

let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];

let source = ot::utils::sample_2D_gauss(n, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n, &mu_target, &cov_target).unwrap();

// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
let mut target_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));

// Compute the cost between the distributions
let mut cost = dist(&source, &target, SqEuclidean);

// Normalize cost matrix for numerical stability
let max_cost = cost.max().unwrap();
cost = &cost / *max_cost;

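// Compute the optimal transport matrix.
// Note: `?` propagates the error, so this snippet must run inside a function that returns a Result.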
let ot_matrix = EarthMovers::new(
    &mut source_weights,
    &mut target_weights,
    &mut cost
).solve()?;
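
For readers who want to see the cost and normalization steps spelled out, the following standalone sketch computes pairwise squared Euclidean costs for 2D points and rescales them by the maximum entry. It assumes that `dist(..., SqEuclidean)` produces this kind of matrix. The helper names `sq_euclidean_cost` and `normalize_by_max` are hypothetical and are not part of the library.

```rust
/// Pairwise squared Euclidean distances between 2D points
/// (hypothetical helper for illustration only).
fn sq_euclidean_cost(xs: &[[f64; 2]], ys: &[[f64; 2]]) -> Vec<Vec<f64>> {
    xs.iter()
        .map(|x| {
            ys.iter()
                .map(|y| (x[0] - y[0]).powi(2) + (x[1] - y[1]).powi(2))
                .collect()
        })
        .collect()
}

/// Divide every entry by the largest entry so that all costs lie in [0, 1].
/// Leaves the matrix untouched if it is empty or its maximum is not positive.
fn normalize_by_max(cost: &mut [Vec<f64>]) {
    let max = cost
        .iter()
        .flatten()
        .cloned()
        .fold(f64::NEG_INFINITY, f64::max);
    if max > 0.0 {
        for row in cost.iter_mut() {
            for c in row.iter_mut() {
                *c /= max;
            }
        }
    }
}
```

Rescaling by a positive constant does not change which transport plan is optimal for the unregularized problem; it only rescales the resulting total cost.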

Testing

cargo test

If you are using an M1 Mac and linking against Homebrew's OpenBLAS, add the following to build.rs:

println!("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib");
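
If the same build.rs also has to work on other platforms, one option is to emit the directive only when the build target is Apple Silicon macOS. The sketch below is an assumption about how this could be arranged, not something taken from this repository. The helper names are hypothetical, while `CARGO_CFG_TARGET_OS` and `CARGO_CFG_TARGET_ARCH` are environment variables that Cargo sets for build scripts.

```rust
/// Returns the link-search directive for Homebrew's OpenBLAS on Apple Silicon,
/// or None for any other target (hypothetical helper).
fn openblas_link_search(target_os: &str, target_arch: &str) -> Option<String> {
    if target_os == "macos" && target_arch == "aarch64" {
        Some("cargo:rustc-link-search=/opt/homebrew/opt/openblas/lib".to_string())
    } else {
        None
    }
}

/// Intended to be called from `fn main()` in build.rs.
fn emit_openblas_link_search() {
    // Cargo exposes the target configuration to build scripts through these variables.
    let os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    let arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
    if let Some(directive) = openblas_link_search(&os, &arch) {
        println!("{directive}");
    }
}
```

The path is the same Homebrew location used in the line above; adjust it if OpenBLAS is installed elsewhere.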

Acknowledgements

This library is inspired by Python Optimal Transport. The original authors and contributors of that project are listed at POT.