Skip to content

brentyi/jaxls

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

jaxls

pyright

status: working! see limitations here

jaxls is a library for nonlinear least squares in JAX.

We provide a factor graph interface for specifying and solving least squares problems. We accelerate optimization by analyzing the structure of graphs: repeated factor and variable types are vectorized, and the sparsity of adjacency in the graph is translated into sparse matrix operations.

Features:

  • Automatic sparse Jacobians.
  • Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations included.
  • Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
  • Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and iterative (Jacobi-preconditioned Conjugate Gradient).

Use cases are primarily in least squares problems that are inherently (1) sparse and (2) inefficient to solve with gradient-based methods. In robotics, these are ubiquitous across classical approaches to perception, planning, and control.

For the first iteration of this library, written for IROS 2021, see jaxfg. jaxls is a rewrite that aims to be faster and easier to use. For additional references, see inspirations like GTSAM, Ceres Solver, minisam, SwiftFusion, g2o.

Installation

jaxls supports python>=3.12.

For Cholesky factorization via CHOLMOD, scikit-sparse requires SuiteSparse:

# Via conda.
conda install conda-forge::suitesparse
# Via apt.
sudo apt update
sudo apt install -y libsuitesparse-dev
# Via brew.
brew install suite-sparse

Then, from your environment of choice:

git clone https://github.com/brentyi/jaxls.git
cd jaxls
pip install -e .

Pose graph example

import jaxls
import jaxlie

Defining variables. Each variable is given an integer ID. They don't need to be contiguous.

pose_vars = [jaxls.SE2Var(0), jaxls.SE2Var(1)]

Defining factors. Factors are defined using a callable cost function and a set of arguments.

# Factors take two arguments:
# - A callable with signature `(jaxls.VarValues, *Args) -> jax.Array`.
# - A tuple of arguments: the type should be `tuple[*Args]`.
#
# All arguments should be PyTree structures. Variable types within the PyTree
# will be automatically detected.
factors = [
    # Cost on pose 0.
    jaxls.Factor.make(
        lambda vals, var, init: (vals[var] @ init.inverse()).log(),
        (pose_vars[0], jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
    ),
    # Cost on pose 1.
    jaxls.Factor.make(
        lambda vals, var, init: (vals[var] @ init.inverse()).log(),
        (pose_vars[1], jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
    ),
    # Cost between poses.
    jaxls.Factor.make(
        lambda vals, var0, var1, delta: (
            (vals[var0].inverse() @ vals[var1]) @ delta.inverse()
        ).log(),
        (pose_vars[0], pose_vars[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
    ),
]

Factors with similar structure, like the first two in this example, will be vectorized under-the-hood.

Solving optimization problems. We can set up the optimization problem, solve it, and print the solutions:

graph = jaxls.FactorGraph.make(factors, pose_vars)
solution = graph.solve()
print("All solutions", solution)
print("Pose 0", solution[pose_vars[0]])
print("Pose 1", solution[pose_vars[1]])

Limitations

There are many practical features that we don't currently support:

  • GPU accelerated Cholesky factorization. (for CHOLMOD we wrap scikit-sparse, which runs on CPU only)
  • Covariance estimation / marginalization.
  • Incremental solves.
  • Analytical Jacobians.