Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose option to switch between sparse matrix representations #22

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,24 @@ def solve(
| ConjugateGradientConfig = "conjugate_gradient",
trust_region: TrustRegionConfig | None = TrustRegionConfig(),
termination: TerminationConfig = TerminationConfig(),
sparse_mode: Literal["blockrow", "coo", "csr"] = "blockrow",
verbose: bool = True,
) -> VarValues:
"""Solve the nonlinear least squares problem using either Gauss-Newton
or Levenberg-Marquardt."""
or Levenberg-Marquardt.

Args:
initial_vals: Initial values for the variables. If None, default values will be used.
linear_solver: The linear solver to use.
trust_region: Configuration for Levenberg-Marquardt trust region.
termination: Configuration for termination criteria.
sparse_mode: The representation to use for sparse matrix
multiplication. Can be "blockrow", "coo", or "csr".
verbose: Whether to print verbose output during optimization.

Returns:
Optimized variable values.
"""
if initial_vals is None:
initial_vals = VarValues.make(
var_type(ids) for var_type, ids in self.sorted_ids_from_var_type.items()
Expand All @@ -79,7 +93,12 @@ def solve(
linear_solver = "conjugate_gradient"

solver = NonlinearSolver(
linear_solver, trust_region, termination, conjugate_gradient_config, verbose
linear_solver,
trust_region,
termination,
conjugate_gradient_config,
sparse_mode,
verbose,
)
return solver.solve(graph=self, initial_vals=initial_vals)

Expand Down
32 changes: 26 additions & 6 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
make_point_jacobi_precoditioner,
)

from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix
from ._sparse_matrices import BlockRowSparseMatrix, SparseCooMatrix, SparseCsrMatrix
from ._variables import VarTypeOrdering, VarValues
from .utils import jax_log

Expand Down Expand Up @@ -191,6 +191,7 @@ class NonlinearSolver:
trust_region: TrustRegionConfig | None
termination: TerminationConfig
conjugate_gradient_config: ConjugateGradientConfig | None
sparse_mode: jdc.Static[Literal["blockrow", "coo", "csr"]]
verbose: jdc.Static[bool]

@jdc.jit
Expand Down Expand Up @@ -254,11 +255,30 @@ def step(
)

# linear_transpose() will return a tuple, with one element per primal.
A_multiply = A_blocksparse.multiply
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
if self.sparse_mode == "blockrow":
A_multiply = A_blocksparse.multiply
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
elif self.sparse_mode == "coo":
A_coo = SparseCooMatrix(
values=jac_values, coords=graph.jac_coords_coo
).as_jax_bcoo()
AT_coo = A_coo.transpose()
A_multiply = lambda vec: A_coo @ vec
AT_multiply = lambda vec: AT_coo @ vec
elif self.sparse_mode == "csr":
A_csr = SparseCsrMatrix(
values=jac_values, coords=graph.jac_coords_csr
).as_jax_bcsr()
A_multiply = lambda vec: A_csr @ vec
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]
else:
assert_never(self.sparse_mode)

# Compute right-hand side of normal equation.
ATb = -AT_multiply(state.residual_vector)
Expand Down
8 changes: 8 additions & 0 deletions src/jaxls/_sparse_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class SparseCsrMatrix:
coords: SparseCsrCoordinates
"""Indices describing non-zero entries."""

def as_jax_bcsr(self) -> jax.experimental.sparse.BCSR:
    """Wrap this matrix's buffers in a `jax.experimental.sparse.BCSR` array.

    The BCSR object is built from `self.values` together with the
    `indices`, `indptr`, and `shape` attributes of `self.coords`.

    Returns:
        A BCSR array constructed with `indices_sorted=True` and
        `unique_indices=True`.
    """
    csr_coords = self.coords
    # BCSR expects its buffers in (data, indices, indptr) order.
    buffers = (self.values, csr_coords.indices, csr_coords.indptr)
    # NOTE(review): both flags are asserted here, not checked -- presumably the
    # stored coordinates are already sorted and duplicate-free; confirm
    # against wherever the CSR coordinates are constructed.
    return jax.experimental.sparse.BCSR(
        args=buffers,
        shape=csr_coords.shape,
        indices_sorted=True,
        unique_indices=True,
    )


@jdc.pytree_dataclass
class SparseCooCoordinates:
Expand Down
Loading