
How to most efficiently calculate jvp when you want to calculate two tangents at the same primal? #14127

Answered by soraros
adam-hartshorne asked this question in Q&A

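You could use a `vmap`-ped JVP: stack the tangents along a leading axis and map `jvp` (or the `f_jvp` returned by `linearize`) over that axis in one call. Notably, the same trick is used to implement `jacfwd` (which vmaps `jvp` over a standard basis of input tangents) and `jacrev` (which vmaps the VJP over a standard basis of output cotangents).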

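```python
from functools import partial

import jax.numpy as jnp
from jax import jvp, linearize, vmap

# Primal point: x has shape (3,), z is a scalar.
ps = (jnp.arange(3.), 1.)

# Two tangent directions for x; the tangent for z is zero in both.
# The leading axis of each stacked array indexes the tangent direction.
t1 = jnp.array([0, 0, 1.])
t2 = jnp.array([0, 1., 0])
ts = (jnp.stack([t1, t2]), jnp.stack([0., 0.]))

def f(x, z):
  return x + jnp.sin(z)

# Linearize once at the primal point, then apply the resulting linear map
# to each (x_tangent, z_tangent) pair in a Python loop.
def f_jvps_manual(ps, ts):
  _, f_jvp = linearize(f, *ps)
  return tuple(f_jvp(*x) for x in zip(*ts))

# Map jvp over the leading axis of the stacked tangents in a single
# vectorized call; ts_out has shape (2, 3).
def f_jvps_vmap(ps, ts):
  _, ts_out = vmap(partial(jvp, f, ps))(ts)
  return tuple(ts_out)

# Both print the same two output tangents.
print(f_jvps_manual(ps, ts))
print(f_jvps_vmap(ps, ts))
```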
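The vmapped-`f_jvp` idea can also be applied literally to the function returned by `linearize`, and feeding it standard-basis tangents shows the connection to `jacfwd` and `jacrev`. This is a minimal sketch under stated assumptions: it reuses the answer's `f` and primal values, the names `x0`, `z0`, `tx`, `tz`, `jac_fwd`, `ct_x` and `ct_z` are illustrative only, and the comparisons merely check that the vmap construction agrees with `jax.jacfwd` / `jax.jacrev` for this `f`; they are not the library's actual implementation.

```python
import jax
import jax.numpy as jnp

def f(x, z):
    return x + jnp.sin(z)

# Primal point (same values as in the answer): x has shape (3,), z is a scalar.
x0 = jnp.arange(3.)
z0 = jnp.array(1.)

# Linearize once: y is f(x0, z0) and f_jvp is the linear map (tx, tz) -> J @ (tx, tz).
y, f_jvp = jax.linearize(f, x0, z0)

# A batch of two tangent directions, stacked along a leading axis.
tx = jnp.stack([jnp.array([0., 0., 1.]), jnp.array([0., 1., 0.])])  # shape (2, 3)
tz = jnp.zeros(2)                                                    # shape (2,)

# vmap the linearized JVP over the leading axis of both tangent arrays.
tangents_out = jax.vmap(f_jvp)(tx, tz)  # shape (2, 3): one output tangent per row

# Standard-basis tangents for x (with a fixed, unbatched zero tangent for z)
# give the Jacobian with respect to x; jax.jacfwd follows roughly this idea.
jac_cols = jax.vmap(f_jvp, in_axes=(0, None))(jnp.eye(3), jnp.array(0.))
jac_fwd = jac_cols.T  # row i of jac_cols is J @ e_i, i.e. column i of J

# Reverse-mode analogue: vmap the VJP over the standard basis of the output;
# jax.jacrev follows roughly this pattern.
_, f_vjp = jax.vjp(f, x0, z0)
ct_x, ct_z = jax.vmap(f_vjp)(jnp.eye(y.shape[0]))  # row i is e_i^T J

print(tangents_out)
print(jnp.allclose(jac_fwd, jax.jacfwd(f, argnums=0)(x0, z0)))
print(jnp.allclose(ct_x, jax.jacrev(f, argnums=0)(x0, z0)))
print(jnp.allclose(ct_z, jax.jacrev(f, argnums=1)(x0, z0)))
```

Linearizing first means the primal evaluation happens once, up front. Whether that is measurably faster than vmapping `jvp` directly likely depends on the function and on whether the code runs under `jit`, so it is worth timing both variants.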

Replies: 1 comment, 3 replies (from @mattjj, @adam-hartshorne, and @soraros)
Answer selected by adam-hartshorne
Category: Q&A
Labels: none yet
3 participants