Jitting a function that returns a function #7053
-
I'm writing a factory function that returns a function:

```python
from functools import partial

from jax import jit, linearize, vjp

@partial(jit, static_argnums=0)
def factory(fn, primals):
    _, jvp_fn = linearize(fn, primals)
    _, vjp_fn = vjp(fn, primals)
    def mat_vec(v):
        return vjp_fn(jvp_fn(v))
    return mat_vec
```

I get an error, which I take to mean that functions cannot be returned from a jitted function. My use case requires this matrix-vector product with many vectors (not known in advance) for the same Jacobian (so I want to return a function), but I also use several different …
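For reference, `mat_vec` as written computes J^T J v, where J is the Jacobian of `fn` at `primals`: `jvp_fn` maps v to J v and `vjp_fn` maps u to J^T u. A minimal sketch, assuming a hypothetical example `fn` (not from the post), checks this without `jit` by comparing against a dense Jacobian from `jax.jacfwd`. Note that `vjp_fn` returns a tuple with one cotangent per primal argument, which the sketch unpacks.

```python
import jax
import jax.numpy as jnp

# Hypothetical example map R^3 -> R^2.
def fn(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

primals = jnp.array([1.0, 2.0, 0.5])

# Same construction as in the question, but without jit.
_, jvp_fn = jax.linearize(fn, primals)  # v -> J v
_, vjp_fn = jax.vjp(fn, primals)        # u -> (J^T u,)

def mat_vec(v):
    # vjp_fn returns a tuple with one entry per primal argument.
    (out,) = vjp_fn(jvp_fn(v))
    return out

# Dense reference: materialize J and form J^T (J v) explicitly.
J = jax.jacfwd(fn)(primals)             # shape (2, 3)
v = jnp.array([0.3, -1.0, 2.0])
reference = J.T @ (J @ v)
matches = bool(jnp.allclose(mat_vec(v), reference, atol=1e-5))
print(matches)
```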
-
Thanks for the question! A jitted function cannot return a function, but an un-jitted function can return a jitted function:

```python
from jax import jit, linearize, vjp

def factory(fn, primals):
    _, jvp_fn = linearize(fn, primals)
    _, vjp_fn = vjp(fn, primals)

    @jit
    def mat_vec(v):
        return vjp_fn(jvp_fn(v))

    return mat_vec
```

Would that be sufficient for your use case?
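One way to see how this behaves with many vectors is to count traces with a Python side effect, which runs only when JAX traces the function. In the sketch below (the example `fn` and the `trace_count` counter are hypothetical), the closure returned by one `factory` call is traced once and then reused for every vector of the same shape and dtype, while each new `factory` call creates a new jitted function and is traced again. The `linearize`/`vjp` residuals are closed over, so they are embedded as constants in the compiled computation.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces mat_vec

def factory(fn, primals):
    # Not jitted: linearize/vjp run eagerly here.
    _, jvp_fn = jax.linearize(fn, primals)
    _, vjp_fn = jax.vjp(fn, primals)

    @jax.jit
    def mat_vec(v):
        global trace_count
        trace_count += 1  # Python side effect: runs at trace time only
        (out,) = vjp_fn(jvp_fn(v))
        return out

    return mat_vec

def fn(x):  # hypothetical example function R^3 -> R^3
    return jnp.tanh(x) * jnp.sum(x ** 2)

mv = factory(fn, jnp.arange(1.0, 4.0))
results = [mv(v) for v in jnp.eye(3)]  # three vectors, same shape and dtype
count_after_first = trace_count        # traced once, reused for all three

mv2 = factory(fn, jnp.ones(3))         # new closure, so a new jitted function
_ = mv2(jnp.ones(3))
print(count_after_first, trace_count)
```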
-
With #7101 you might be able to do something like:

```python
import jax
from functools import partial

@partial(jax.jit, static_argnums=0)
def _fwd(fn, primals):
    # With #7101, jvp_fn and vjp_fn can be returned from a jitted function.
    _, jvp_fn = jax.linearize(fn, primals)
    _, vjp_fn = jax.vjp(fn, primals)
    return jvp_fn, vjp_fn

@jax.jit
def _mat_vec(jvp_fn, vjp_fn, v):
    # Compiled once per shape/dtype of v; jvp_fn and vjp_fn are passed as arguments.
    return vjp_fn(jvp_fn(v))

def factory(fn, primals):
    jvp_fn, vjp_fn = _fwd(fn, primals)
    return partial(_mat_vec, jvp_fn, vjp_fn)
```
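A function can cross a `jit` boundary when it is wrapped as a pytree. One JAX mechanism for this is `jax.tree_util.Partial`, a `functools.partial`-like wrapper whose bound arguments are pytree leaves and whose underlying function is static metadata; whether #7101 relies on exactly this is an assumption here. The sketch below uses it directly and swaps in a different technique from the thread: it materializes the Jacobian with `jax.jacfwd` instead of keeping `linearize`/`vjp` closures, which is only sensible for small inputs. The names `fn`, `_gn_matvec`, `dense_factory` and `apply` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def fn(x):
    # Hypothetical example function R^3 -> R^3.
    return jnp.tanh(x) * jnp.sum(x ** 2)

def _gn_matvec(J, v):
    # J^T (J v): the same product as vjp_fn(jvp_fn(v)), with J stored densely.
    return J.T @ (J @ v)

@partial(jax.jit, static_argnums=0)
def dense_factory(fn, primals):
    J = jax.jacfwd(fn)(primals)  # dense Jacobian (swapped-in technique)
    # Partial is a pytree: J is a leaf and _gn_matvec is static metadata,
    # so the wrapped function can be returned from a jitted function.
    return Partial(_gn_matvec, J)

mat_vec = dense_factory(fn, jnp.arange(1.0, 4.0))

@jax.jit
def apply(f, v):
    # A Partial can also be passed into a jitted function as an argument.
    return f(v)

ys = [apply(mat_vec, v) for v in jnp.eye(3)]
```

Because the underlying function is part of the pytree structure rather than a traced value, `apply` is compiled once for this `Partial` and reused for every `v` of the same shape.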