How to most efficiently calculate a JVP for two tangents at the same primal? #14127
adam-hartshorne asked this question in Q&A; answered by soraros.
If I have a primal x, a variable z (which we don't want to compute tangents with respect to), and two tangents t1 and t2, I want to calculate primal_out and the tangent_out for both t1 and t2 at x. I am aware of the jax.linearize approach (linearize once to get f_jvp, then call f_jvp once per tangent).
Is this really the most efficient way, or, with such a small number of tangents, is there a better way that doesn't require creating f_jvp and calling it twice?
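For comparison, here is a minimal sketch (not from the thread) that puts the linearize approach next to the naive baseline of two separate jax.jvp calls. The baseline re-evaluates the primal for each tangent, while jax.linearize evaluates it once and returns a linear function that is then applied per tangent, at the cost of keeping the linearization's residuals in memory. The function f and all values are hypothetical, and z is held constant by closing over it.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function; z is treated as a constant (no tangent).
def f(x, z):
    return jnp.sin(x) * z

x = jnp.arange(3.0)
z = 2.0
t1 = jnp.array([1.0, 0.0, 0.0])
t2 = jnp.array([0.0, 1.0, 0.0])

# Close over z so that only x is differentiated.
def g(x):
    return f(x, z)

# Baseline: two independent forward-mode passes; the primal is recomputed each time.
y_a, out1 = jax.jvp(g, (x,), (t1,))
y_b, out2 = jax.jvp(g, (x,), (t2,))

# linearize: evaluate the primal once, then apply the linear map to each tangent.
y, g_jvp = jax.linearize(g, x)
lin1 = g_jvp(t1)
lin2 = g_jvp(t2)

print(jnp.allclose(out1, lin1), jnp.allclose(out2, lin2))
```

Both give the same tangents; which one is faster in practice depends on the function and on whether the code is jitted (XLA may deduplicate the repeated primal work), so timing on the real f is the reliable way to decide.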
Answered by soraros on Jan 24, 2023
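You could use a vmap-ed f_jvp. Notably, this trick is also used to implement jacfwd and jacrev.

from functools import partial

import jax.numpy as jnp
from jax import jvp, linearize, vmap

# Primals (x, z) and a batch of two tangents per argument;
# z gets zero tangents so it does not contribute to tangent_out.
ps = (jnp.arange(3.), 1.)
t1 = jnp.array([0, 0, 1.])
t2 = jnp.array([0, 1., 0])
ts = (jnp.stack([t1, t2]), jnp.stack([0., 0.]))

def f(x, z):
    return x + jnp.sin(z)

def f_jvps_manual(ps, ts):
    # Linearize once, then apply the linear map to each (t_x, t_z) pair.
    _, f_jvp = linearize(f, *ps)
    return tuple(f_jvp(*t) for t in zip(*ts))

def f_jvps_vmap(ps, ts):
    # Batch jvp over the leading axis of every tangent array.
    _, ts_out = vmap(partial(jvp, f, ps))(ts)
    return tuple(ts_out)

print(f_jvps_manual(ps, ts))
print(f_jvps_vmap(ps, ts))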
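The answer's wording suggests vmapping the linearized f_jvp itself. A minimal sketch of that variant, assuming the same f as in the answer: it also returns primal_out (which the question asked for, but both functions above discard) and holds z fixed by closing over it instead of passing zero tangents. The helper name primal_and_tangents is hypothetical.

```python
import jax
import jax.numpy as jnp

def f(x, z):
    return x + jnp.sin(z)

x = jnp.arange(3.0)
z = 1.0
# Two tangents for x, stacked along a new leading batch axis: shape (2, 3).
ts = jnp.stack([jnp.array([0.0, 0.0, 1.0]), jnp.array([0.0, 1.0, 0.0])])

def primal_and_tangents(x, z, ts):
    # Close over z so it is held constant (no tangent needed for it).
    y, f_jvp = jax.linearize(lambda x: f(x, z), x)
    # Apply the linear map to every row of ts in one batched call.
    ts_out = jax.vmap(f_jvp)(ts)
    return y, ts_out

y, ts_out = primal_and_tangents(x, z, ts)
print(y)       # primal output, shape (3,)
print(ts_out)  # one tangent output per row of ts, shape (2, 3)
```

Here vmap traces f_jvp once with a batch dimension, so both tangents go through a single batched evaluation of the linear map rather than two Python-level calls.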
Answer selected by adam-hartshorne.
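The remark that the same trick underlies jacfwd and jacrev can be checked with a small sketch: vmapping jvp over the standard basis of the input space yields the Jacobian's columns, and vmapping the vjp function over the standard basis of the output space yields its rows. The function g is a hypothetical example, and this mirrors the idea rather than JAX's exact internal code.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical R^3 -> R^2 function, used only for illustration.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: push each input basis vector e_i through one batched jvp.
# Row i of the result is J @ e_i (column i of J), so transpose at the end.
cols = jax.vmap(lambda t: jax.jvp(g, (x,), (t,))[1])(jnp.eye(3))
jac_fwd_manual = cols.T  # shape (2, 3)

# Reverse mode: pull each output basis vector e_j back through one batched vjp.
# Row j of the result is e_j^T @ J (row j of J).
_, g_vjp = jax.vjp(g, x)
(jac_rev_manual,) = jax.vmap(g_vjp)(jnp.eye(2))  # shape (2, 3)

print(jnp.allclose(jac_fwd_manual, jax.jacfwd(g)(x)))  # True
print(jnp.allclose(jac_rev_manual, jax.jacrev(g)(x)))  # True
```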