Jitting a function that returns a function #7053

Answered by inailuig
attila-i-szabo asked this question in Q&A
With #7101 you might be able to do something like:

import jax
from functools import partial


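# Build the forward- and reverse-mode linear maps of fn at primals.
# Returning these functions from a jitted function assumes the change in #7101.
@partial(jax.jit, static_argnums=0)
def _fwd(fn, primals):
    _, jvp_fn = jax.linearize(fn, primals)  # jvp_fn: v -> J v
    _, vjp_fn = jax.vjp(fn, primals)        # vjp_fn: w -> (J^T w,)
    return jvp_fn, vjp_fn

# Apply J^T J to v using the precomputed linear maps.
@jax.jit
def _mat_vec(jvp_fn, vjp_fn, v):
    # The function returned by jax.vjp yields a tuple with one cotangent per primal.
    (out,) = vjp_fn(jvp_fn(v))
    return out

def factory(fn, primals):
    jvp_fn, vjp_fn = _fwd(fn, primals)
    return partial(_mat_vec, jvp_fn, vjp_fn)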
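
If passing the linearized functions across a `jit` boundary is not available, a simpler fallback is to jit the whole matrix-vector product and close over `fn` and `primals` with `partial`. The following is only a sketch of that alternative, not the approach from #7101: it recomputes the linearization inside each call instead of reusing it, and the test function `f` and the inputs `x`, `v` are made up for illustration.

```python
import jax
import jax.numpy as jnp
from functools import partial


@partial(jax.jit, static_argnums=0)
def mat_vec(fn, primals, v):
    # Forward mode: J v
    _, jv = jax.jvp(fn, (primals,), (v,))
    # Reverse mode: J^T (J v); vjp_fn returns a 1-tuple for a single primal
    _, vjp_fn = jax.vjp(fn, primals)
    (out,) = vjp_fn(jv)
    return out


def factory(fn, primals):
    # fn is a static argument, so it must be hashable (a plain function is)
    return partial(mat_vec, fn, primals)


# Hypothetical example: f is elementwise, so J = diag(2 cos(x))
def f(x):
    return 2.0 * jnp.sin(x)


x = jnp.arange(3.0)
v = jnp.ones(3)
jtj_v = factory(f, x)(v)  # equals 4 cos(x)^2 * v up to float error
```

Because `fn` is static, calling `factory` with a new function object triggers a fresh compilation, while repeated calls with the same `fn` and differently valued `primals` or `v` of the same shape reuse the compiled code.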

Answer selected by attila-i-szabo