
Partial versus static_argnums for passing functions into jitted function #14694

Answered by jakevdp
KeAWang asked this question in Q&A

Did you mean for the second call to be this?

print(f2(jax.tree_util.Partial(g), x))
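Here is a minimal sketch of why that call works. It assumes `f2` is simply `jax.jit(f)` with no static arguments. The question's own `f2` and `g` are not shown here, so both are hypothetical stand-ins. `jax.tree_util.Partial` is registered as a pytree whose wrapped function is carried as static metadata. That lets it pass through `jit` as an ordinary argument, which a bare Python function cannot do.

```python
import jax
import jax.numpy as jnp

def f(g, x):
    return g(x)

# Assumption: f2 is a plain jit of f with no static arguments.
f2 = jax.jit(f)

# Hypothetical stand-in for the question's g.
def g(v):
    return 2 * v + 1

x = jnp.arange(3.0)

# Wrapping g in Partial makes it a pytree with no array leaves;
# the function itself travels as static aux data, so jit accepts it.
result = f2(jax.tree_util.Partial(g), x)
print(result)  # [1. 3. 5.]
```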

Regardless, the general pattern people use is either to mark the function argument as static with static_argnums/static_argnames, or to close over it (for example with functools.partial). Side by side, they look like this:

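import jax
import jax.numpy as jnp
from functools import partial

def f(g, x):
    return g(x)

g, x = (lambda v: 2 * v), jnp.arange(3.0)  # illustrative inputs; the question's own g and x are not shown
ans1 = jax.jit(f, static_argnames=['g'])(g, x)  # g marked static
ans2 = jax.jit(partial(f, g))(x)  # g captured in a closure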
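A small sketch shows why one of these two patterns is needed. The helper `double` is hypothetical. Passing a plain Python function to a jitted function as a regular, non-static argument fails, because `jit` tries to treat every non-static argument as an array-like value to trace. The exact error message varies across JAX versions, but it is raised as a `TypeError`.

```python
import jax
import jax.numpy as jnp

def f(g, x):
    return g(x)

# Hypothetical function to pass in.
def double(v):
    return 2 * v

x = jnp.arange(3.0)

try:
    jax.jit(f)(double, x)  # g is neither static nor a pytree of arrays
    failed = False
except TypeError as err:
    # JAX cannot interpret a Python function as an abstract array.
    failed = True
    print("TypeError:", err)

# Marking g as static makes the same call work.
fixed = jax.jit(f, static_argnames=['g'])(double, x)
print(fixed)  # [0. 2. 4.]
```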

Both are idiomatic and don't really involve any significant tradeoffs; use whichever fits the context of your code. If you want f to be a stand-alone JIT-compiled function (i.e. not leave JIT compilation up to the caller), then static_argnames is probably the best choice.
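A minimal sketch of the stand-alone form applies `jax.jit` as a decorator through `functools.partial`, so callers just call `f(g, x)`. The example functions `double` and `square` are hypothetical. Static arguments must be hashable, and ordinary Python functions are.

```python
from functools import partial

import jax
import jax.numpy as jnp

# f is jitted at definition time. static_argnames=['g'] also covers g
# passed positionally, since jit maps names to positions via f's signature.
@partial(jax.jit, static_argnames=['g'])
def f(g, x):
    return g(x)

# Hypothetical example functions to pass as the static argument.
def double(v):
    return 2 * v

def square(v):
    return v ** 2

x = jnp.arange(3.0)
a = f(double, x)  # traces and compiles for g=double
b = f(square, x)  # a different static g means a separate trace/compile
print(a, b)
```

Because g is static, JAX keeps one compiled version per distinct g (compared by hash and equality). Repeated calls with the same function therefore reuse the cached compilation.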

Using tree_util.Partial also works; the only tradeoff might be that it is more indirect and may be confus…
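Here is a minimal sketch of the tree_util.Partial route. The names `apply` and `scale` are hypothetical. Only the wrapped function is static. Any arguments bound into the Partial become pytree leaves and are traced like ordinary inputs, so passing a different bound value with the same shape and dtype does not require a new compilation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

@jax.jit
def apply(g, x):
    # g is a Partial pytree: its function is static, its bound args are traced.
    return g(x)

def scale(a, v):
    return a * v

x = jnp.arange(3.0)
g3 = Partial(scale, 3.0)
g5 = Partial(scale, 5.0)  # same function, different leaf value

print(jax.tree_util.tree_leaves(g3))  # [3.0]
r3 = apply(g3, x)
r5 = apply(g5, x)  # same treedef and leaf types, so the cached compilation is reused
print(r3, r5)
```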

Replies: 1 comment 6 replies

Answer selected by KeAWang
Category: Q&A
Labels: None yet
Participants: 2