Somax is a library of Second-Order Methods for stochastic optimization written in JAX. Somax is based on the JAXopt StochasticSolver API and can be used as a drop-in replacement for both JAXopt and Optax solvers.
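For example, because the solvers share JAXopt's `init_state`/`update` protocol, an existing JAXopt training loop keeps working when the solver is swapped for a Somax one. A minimal sketch, assuming a toy linear model with an MSE loss; `jaxopt.PolyakSGD` is an arbitrary baseline, and the EGN hyperparameters simply mirror the usage example below:

```python
import jax.numpy as jnp
import jaxopt
from somax import EGN

def predict(params, x):
    # toy linear model: predictions of shape (batch, 1)
    return x @ params['w']

def loss_fn(params, x, y):
    # MSE loss for the JAXopt baseline
    return jnp.mean((predict(params, x) - y) ** 2)

params = {'w': jnp.zeros((4, 1))}
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))

# a stochastic first-order JAXopt solver ...
solver = jaxopt.PolyakSGD(fun=loss_fn)
state = solver.init_state(params, x, y)
params, state = solver.update(params, state, x, y)

# ... and a Somax second-order solver, driven through the same two calls
solver = EGN(predict_fun=predict, loss_type='mse',
             learning_rate=0.1, regularizer=1.0)
state = solver.init_state(params)
params, state = solver.update(params, state, x, y)
```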
Currently supported methods:
- Diagonal Scaling: AdaHessian, Sophia;
- Hessian-free Optimization:
- Quasi-Newton:
- Gauss-Newton: EGN;
- Natural Gradient:
Future releases:
- Add support for separate "gradient batches" and "curvature batches" for all solvers;
- Add support for Optax rate schedules.
*The catfish in the logo is a nod to "сом", the Belarusian word for "catfish", pronounced "som".
```bash
pip install python-somax
```
Requires JAXopt 0.8.2+.
```python
from somax import EGN

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```
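The snippet above assumes that `model` and `params` already exist. Here is a self-contained sketch of a full run, assuming a small Flax MLP on synthetic regression data (Flax itself is an assumption here; any `predict_fun(params, x)` callable should work):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from somax import EGN

class MLP(nn.Module):
    # small regression network with a single scalar output
    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(32)(x))
        return nn.Dense(1)(x)

key = jax.random.PRNGKey(0)
model = MLP()

# synthetic regression data: the target is the sum of the features
batch_x = jax.random.normal(key, (256, 4))
batch_y = jnp.sum(batch_x, axis=1, keepdims=True)

params = model.init(key, batch_x[:1])

solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)
opt_state = solver.init_state(params)

for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
    # monitor training progress with the plain MSE
    loss = jnp.mean((model.apply(params, batch_x) - batch_y) ** 2)
    print(f'step {i}: mse={loss:.4f}')
```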
See more in the examples folder.
```bibtex
@misc{korbit2024somax,
  author = {Nick Korbit},
  title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
  year = {2024},
  url = {https://github.com/cor3bit/somax},
}
```
Optimization with JAX
- Optax: first-order gradient optimisers (SGD, Adam, ...).
- JAXopt: deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD).
Awesome Projects
- Awesome JAX: a longer list of various JAX projects.
- Awesome SOMs: a list of resources for second-order optimization methods in machine learning.
Some of the implementation ideas are based on the following repositories:
- Line Search in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/armijo_sgd.py#L48
- L-BFGS Inverse Hessian-Gradient product in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/lbfgs.py#L44
- AdaHessian (official implementation): https://github.com/amirgholami/adahessian
- AdaHessian (Nestor Demeure's implementation): https://github.com/nestordemeure/AdaHessianJax
- Sophia (official implementation): https://github.com/Liuhong99/Sophia
- Sophia (levanter implementation): https://github.com/stanford-crfm/levanter/blob/main/src/levanter/optim/sophia.py