What is the best way to get both the Jacobian and Hessian of a function? #14140
Answered by jakevdp
newalexander asked this question in Q&A
I have a function `f`. We can calculate its Jacobian and Hessian in JAX:

```python
from jax import jacfwd, jacrev

def f(x): ...

x = ...
jacobian = jacfwd(f)(x)
hessian = jacfwd(jacrev(f))(x)
```

However, this repeats some derivative calculations. Is there a way to share these calculations across the Jacobian and Hessian calls? For example, we can do this if we want to calculate both a function value and a Jacobian:

```python
from jax import numpy as jnp, vmap, jvp

def f(x): ...

x = ...

def get_jacfwd_custom(f):
    def jacfun(x):
        # Push each standard basis vector through f in forward mode.
        _jvp = lambda s: jvp(f, (x,), (s,))
        f_val, j_val = vmap(_jvp, in_axes=0)(jnp.eye(len(x)))
        return (f_val[0, :],  # all rows of `f_val` are the same
                j_val.T)
    return jacfun

f_val, j_val = get_jacfwd_custom(f)(x)
```

Thank you!
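The value-plus-Jacobian helper above can also be sketched with built-in transforms only: `jax.jacfwd` accepts `has_aux=True` and then returns a `(jacobian, aux)` pair, so returning `f(x)` as auxiliary data yields both from one call. The function `f` and the helper name `value_and_jacfwd` below are hypothetical, chosen for illustration; they are not from the original post.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function R^3 -> R^2 (an assumption for illustration).
def f(x):
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

def value_and_jacfwd(fun):
    # Return fun(x) twice: the first copy is differentiated by jacfwd,
    # the second is passed through unchanged as auxiliary data.
    def wrapped(x):
        y = fun(x)
        return y, y
    # With has_aux=True, jacfwd returns (jacobian, aux).
    return jax.jacfwd(wrapped, has_aux=True)

x = jnp.array([0.5, 2.0, -1.0])
j_val, f_val = value_and_jacfwd(f)(x)  # j_val: (2, 3), f_val: (2,)
```

Because the auxiliary output is part of the primal computation, `f` does not need a separate call to produce the value.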
Answered by jakevdp on Jan 25, 2023:
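Sometimes the simplest approach is the best:

```python
import jax
from jax import jit

# `f` is the function from the question.
@jit
def jac_and_hess(x):
    return jax.jacfwd(f)(x), jax.hessian(f)(x)
```

In most cases, the XLA compiler can de-duplicate repeated operations within a JIT-compiled function, so the work shared by the two calls is not actually computed twice.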
Answer selected by newalexander