JAX and Flax implementations of the AdaHessian optimizer, a second-order optimizer for neural networks.
You can install this library with:
pip install git+https://github.com/nestordemeure/AdaHessianJax.git
For JAX, the implementation provides both a fast way to evaluate the diagonal of the Hessian of a program and an optimizer API that stays close to jax.experimental.optimizers, in the adahessianJax.jaxOptimizer namespace:
# gets the jax.experimental.optimizers version of the optimizer
import jax
from adahessianJax import grad_and_hessian
from adahessianJax.jaxOptimizer import adahessian
# builds an optimizer triplet, no need to pass a learning rate
opt_init, opt_update, get_params = adahessian()
# initialize the optimizer with the initial value of the parameters to optimize
opt_state = opt_init(init_params)
rng = jax.random.PRNGKey(0)
# uses the optimizer inside the training loop (i is the step index); note that we pass the gradient AND the hessian diagonal
params = get_params(opt_state)
rng, rng_step = jax.random.split(rng)
gradient, hessian = grad_and_hessian(loss, (params, batch), rng_step)
opt_state = opt_update(i, gradient, hessian, opt_state)
The example folder contains JAX's MNIST classification example, adapted to run with either Adam or AdaHessian so that the two optimizers can be compared.
For Flax, the implementation provides the same Hessian diagonal evaluation together with an optimizer API that stays close to Flax optimizers, in the adahessianJax.flaxOptimizer namespace:
# gets the flax version of the optimizer
import jax
from adahessianJax import grad_and_hessian
from adahessianJax.flaxOptimizer import Adahessian
# defines the optimizer, no need to pass a learning rate
optimizer_def = Adahessian()
# initialize the optimizer with the initial value of the parameters to optimize
optimizer = optimizer_def.create(init_params)
rng = jax.random.PRNGKey(0)
# uses the optimizer, note that we pass the gradient AND a hessian
params = optimizer.target
rng, rng_step = jax.random.split(rng)
gradient, hessian = grad_and_hessian(loss, (params, batch), rng_step)
optimizer = optimizer.apply_gradient(gradient, hessian)

Argument | Description |
---|---|
step_size (float, optional) | learning rate (default: 1e-3) |
b1 (float, optional) | the exponential decay rate for the first moment estimates (default: 0.9) |
b2 (float, optional) | the exponential decay rate for the squared hessian estimates (default: 0.999) |
eps (float, optional) | term added to the denominator to improve numerical stability (default: 1e-8) |
weight_decay (float, optional) | weight decay (L2 penalty) (default: 0.0) |
hessian_power (float, optional) | hessian power (default: 1.0) |

`adahessian` returns an (init_fun, update_fun, get_params) triple of functions modeling the optimizer, similarly to the jax.experimental.optimizers API. The only difference is that update_fun takes both a gradient and a Hessian parameter.
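
To make the role of the Hessian argument concrete, here is a minimal sketch of a single AdaHessian-style update on one array. It follows the update rule of the AdaHessian paper as commonly described (Adam with the squared gradient replaced by the squared Hessian diagonal) and omits the paper's spatial averaging. The helper `adahessian_step` and its state layout are assumptions made for illustration; they are not this library's internals.

```python
import jax.numpy as jnp

def adahessian_step(step, param, m, v, gradient, hessian_diag,
                    step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8,
                    weight_decay=0.0, hessian_power=1.0):
    """One illustrative AdaHessian-style update (hypothetical helper)."""
    # exponential moving averages of the gradient and of the squared Hessian diagonal
    m = b1 * m + (1.0 - b1) * gradient
    v = b2 * v + (1.0 - b2) * jnp.square(hessian_diag)
    # bias correction as in Adam; step counts from 0
    m_hat = m / (1.0 - b1 ** (step + 1))
    v_hat = v / (1.0 - b2 ** (step + 1))
    # the Hessian term takes the place of Adam's squared gradient in the denominator
    denom = jnp.sqrt(v_hat) ** hessian_power + eps
    param = param - step_size * (m_hat / denom + weight_decay * param)
    return param, m, v

# first step on a single parameter, starting from zero moment estimates
param = jnp.array([1.0])
m = jnp.zeros_like(param)
v = jnp.zeros_like(param)
param, m, v = adahessian_step(0, param, m, v,
                              gradient=jnp.array([0.5]),
                              hessian_diag=jnp.array([2.0]),
                              step_size=0.1)
```

With `hessian_power=1.0` the step is scaled by the inverse of the Hessian diagonal magnitude; smaller values weaken that preconditioning.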

Argument | Description |
---|---|
learning_rate (float, optional) | learning rate (default: 1e-3) |
beta1 (float, optional) | the exponential decay rate for the first moment estimates (default: 0.9) |
beta2 (float, optional) | the exponential decay rate for the squared hessian estimates (default: 0.999) |
eps (float, optional) | term added to the denominator to improve numerical stability (default: 1e-8) |
weight_decay (float, optional) | weight decay (L2 penalty) (default: 0.0) |
hessian_power (float, optional) | hessian power (default: 1.0) |

`Adahessian` returns an optimizer definition, similarly to the Flax optimizers API. The only difference is that apply_gradient takes both a gradient and a Hessian parameter.

Argument | Description |
---|---|
fun (Callable) | function to be differentiated |
fun_input (Tuple) | value at which the gradient and hessian of fun should be evaluated |
rng (ndarray) | a PRNGKey used as the random key |
argnum (int, optional) | specifies which positional argument to differentiate with respect to (default: 0) |

`grad_and_hessian` returns a pair (gradient, hessian) where the first element is the gradient of fun evaluated at fun_input and the second element is the diagonal of its Hessian. This function expects fun_input to be a tuple and will fail otherwise. One can pass (fun_input,) if fun has a single input that is not already a tuple.
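
For intuition on why a random key is needed, the sketch below shows one common way to obtain a gradient and a Hessian diagonal estimate together: Hutchinson's estimator, which probes the Hessian with a random ±1 vector z and uses z * (H z). This is an assumption about the general technique, not a copy of this library's code; `hutchinson_grad_and_hessian` is a hypothetical name, and the sketch only handles a single array argument rather than a full parameter tree.

```python
import jax
import jax.numpy as jnp

def hutchinson_grad_and_hessian(fun, fun_input, rng, argnum=0):
    """Gradient and one-sample Hessian diagonal estimate (illustrative sketch)."""
    x = fun_input[argnum]

    # gradient with respect to the chosen argument, the other inputs held fixed;
    # the tuple slicing is why fun_input has to be a tuple here
    def grad_fn(arg):
        inputs = fun_input[:argnum] + (arg,) + fun_input[argnum + 1:]
        return jax.grad(fun, argnums=argnum)(*inputs)

    # Rademacher probe vector: every entry is +1 or -1
    z = jax.random.rademacher(rng, x.shape, dtype=x.dtype)
    # forward-over-reverse differentiation returns the gradient and the product H z
    gradient, hz = jax.jvp(grad_fn, (x,), (z,))
    # z * (H z) equals diag(H) in expectation over z
    return gradient, z * hz

# toy loss whose Hessian is diagonal, so a single sample is already exact
def loss(w, batch):
    return jnp.sum(batch * w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
batch = jnp.array([0.5, 1.0, 2.0])
gradient, hessian = hutchinson_grad_and_hessian(loss, (w, batch), jax.random.PRNGKey(0))
```

When the Hessian has off-diagonal entries, a single sample is noisy and only correct on average, which is why a fresh key is split off at every training step in the usage examples.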

Argument | Description |
---|---|
fun (Callable) | function to be differentiated |
fun_input (Tuple) | value at which the gradient and hessian of fun should be evaluated |
rng (ndarray) | a PRNGKey used as the random key |
argnum (int, optional) | specifies which positional argument to differentiate with respect to (default: 0) |

Returns a triplet (value, gradient, hessian) where the first element is the value of fun evaluated at fun_input, the second element is its gradient and the third element is the diagonal of its Hessian. This function expects fun_input to be a tuple and will fail otherwise. One can pass (fun_input,) if fun has a single input that is not already a tuple.
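
As a rough illustration of how the value can come along at no extra cost, the sketch below pushes a probe vector through `jax.value_and_grad` with `jax.jvp`, and wraps a single non-tuple input as `(x,)`. The function name `value_grad_hessian_sketch` is hypothetical, it always differentiates with respect to the first input, and it should be read as one possible approach rather than this library's implementation.

```python
import jax
import jax.numpy as jnp

def value_grad_hessian_sketch(fun, fun_input, rng):
    """Value, gradient and one-sample Hessian diagonal estimate (illustrative sketch)."""
    x, rest = fun_input[0], fun_input[1:]
    # random +1/-1 probe vector with the same shape as the differentiated input
    z = jax.random.rademacher(rng, x.shape, dtype=x.dtype)

    def value_and_grad(arg):
        return jax.value_and_grad(fun)(arg, *rest)

    # jvp returns the primal outputs (value, gradient) and their derivatives along z;
    # the derivative of the gradient along z is the product H z
    (value, gradient), (_, hz) = jax.jvp(value_and_grad, (x,), (z,))
    return value, gradient, z * hz

# single-input function: the input is wrapped in a one-element tuple
f = lambda x: jnp.sum(x ** 3)
x = jnp.array([1.0, 2.0])
value, gradient, hessian = value_grad_hessian_sketch(f, (x,), jax.random.PRNGKey(0))
```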