-
Yes, you can linearize or manually implement it. In these lines of code:

```python
_, JtV = jax.jvp(lambda th: E_jacob(x, th, rec, params), (theta_vec,), (theta_vec,))
_, e_vjp = jax.vjp(lambda th: E_jacob(x, th, rec, params), theta_vec)
Gv = e_vjp(HyJtV)[0]
```

change to:

```python
f_val, f_jvp = jax.linearize(lambda th: E_jacob(x, th, rec, params), theta_vec)
JtV = f_jvp(gradient)
# Transpose the linear JVP to obtain a true VJP (a JVP is not its own
# transpose unless the Jacobian is symmetric)
f_vjp = jax.linear_transpose(f_jvp, theta_vec)
Gv = f_vjp(HyJtV)[0]
```

With this, the Jacobian computed by jax.linearize is reused, saving computation and speeding up the code.
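To make this concrete, here is a minimal self-contained sketch of the single-linearization GGN-vector product; the model `f`, the output-Hessian apply `H_out`, and the toy data are illustrative placeholders, not the code from this thread:

```python
import jax
import jax.numpy as jnp

def ggn_vector_product_single_linearize(f, theta, v, H_out):
    """Compute (J^T H J) v while linearizing f only once."""
    # One linearization: f_jvp(u) computes J @ u for the Jacobian J at theta.
    _, f_jvp = jax.linearize(f, theta)
    # Transpose the linear map to get a VJP without linearizing again.
    f_vjp = jax.linear_transpose(f_jvp, theta)
    Jv = f_jvp(v)              # J v
    (Gv,) = f_vjp(H_out(Jv))   # J^T H (J v)
    return Gv

# Toy usage with an identity output Hessian (Gauss-Newton for least squares).
W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
f = lambda th: jnp.tanh(W @ th)
theta = jnp.array([0.1, -0.2])
v = jnp.array([1.0, 0.5])
print(ggn_vector_product_single_linearize(f, theta, v, lambda y: y))
```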
-
Thank you
judith valenzuela
-
Dear All,
I would like to ask for advice on how to accelerate the computation of GGN-vector products. I wrote what I believe should be a reasonably efficient implementation of a GGN-vector product using one `jvp` and one `vjp` (a rough sketch of the pattern is below). My understanding is that my code linearizes the function twice, once for the `jvp` and then again for the `vjp` (although the whole code is JIT compiled, so maybe XLA is able to reuse all the Jacobians under the hood).
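Roughly, my implementation follows this pattern (the model and output Hessian below are toy stand-ins, not my actual code):

```python
import jax
import jax.numpy as jnp

def ggn_vector_product(f, theta, v, H_out):
    """Compute the GGN-vector product (J^T H J) v, where J is the
    Jacobian of f at theta and H_out applies the output Hessian H."""
    _, Jv = jax.jvp(f, (theta,), (v,))   # first linearization: J v
    _, f_vjp = jax.vjp(f, theta)         # second linearization, for J^T
    (Gv,) = f_vjp(H_out(Jv))             # J^T H (J v)
    return Gv

# Toy stand-ins: a tiny model and an identity output Hessian.
W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
f = lambda th: jnp.tanh(W @ th)
theta = jnp.array([0.1, -0.2])
v = jnp.array([1.0, 0.5])
print(jax.jit(ggn_vector_product, static_argnums=(0, 3))(f, theta, v, lambda y: y))
```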
However, the autodiff cookbook suggests that the GGN-vector product can be accelerated by linearizing the function just once (reusing the Jacobian from the `jvp` computation to calculate the `vjp`). As far as I understand, I can use `jax.linearize` instead of `jax.jvp`, so I get an `f_jvp` function instead of a single `jvp`. So, I guess my question is whether there is a simple way to transpose the `f_jvp` output of `jax.linearize` into an `f_vjp` function equivalent to the output of `jax.vjp`.
Any advice would be super welcome!
Best,
Jan