Suppose we want to be able to pass a closure function (possibly jitted) into another jitted function. I think generally, we do one of the two things shown below:

```python
import jax
from functools import partial

@partial(jax.jit, static_argnums=0)
def f1(g, x):
    return g(x)

@jax.jit
def f2(g, x):
    return g(x)

y = 0

# @jax.jit  # maybe we want to jit this too
def g(x):
    return x + y

x = 0
print(f1(g, x))
print(f2(jax.tree_util.Partial, x))
```

Which one is preferable? And what are the tradeoffs for each?
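There is one detail to keep in mind with a closure like `g`. The global `y` it reads is captured when the jitted function is traced, not each time it is called. The minimal sketch below reuses `g` and `f1` from the question, with `x` passed as `1`. After `y` is rebound, the second call reuses the cached compilation because `g` is the same static object, so the old value of `y` is still used.

```python
from functools import partial

import jax

y = 0  # global read by the closure below

def g(x):
    # Reads the global `y`; under jit its value is baked in at trace time.
    return x + y

@partial(jax.jit, static_argnums=0)
def f1(g, x):
    return g(x)

first = f1(g, 1)   # traces f1 while y == 0
y = 10             # rebind the global after compilation
second = f1(g, 1)  # same static `g` and same type of x: cache hit, no retrace
print(int(first), int(second))  # 1 1
```

If `y` needs to change between calls, passing it explicitly should avoid the stale value. You can pass it as an ordinary argument or bind it with `jax.tree_util.Partial` so it becomes a traced leaf.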
Did you mean for the second call to be this?

```python
print(f2(jax.tree_util.Partial(g), x))
```

Regardless, the general pattern that people use is either to use `static_argnums`/`static_argnames` or to use a closure. Side by side, they'd look like this:

```python
def f(g, x):
    return g(x)

ans1 = jax.jit(f, static_argnames=['g'])(g, x)
ans2 = jax.jit(partial(f, g))(x)
```

Both are idiomatic and don't really have any tradeoffs; you can use whichever fits the context of your code. If you want `f` to be a stand-alone JIT-compiled function (i.e. not leave JITting up to the caller), then `static_argnames` is probably best.

Using `tree_util.Partial` also works; the only tradeoff might be that it is more indirect and may be confus…
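The sketch below shows one concrete difference between a static argument and `jax.tree_util.Partial`. The bound value changes between calls, and the sketch counts how often each jitted function is traced. The helper `add` and the `trace_count` dictionary are hypothetical and only for illustration. The counter relies on the fact that the Python body of a jitted function runs only while JAX traces it.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

# Incremented only while tracing, since the Python body is not re-run on cache hits.
trace_count = {"static": 0, "pytree": 0}

def add(y, x):  # hypothetical helper; binding y first gives g(x) == x + y
    return x + y

@functools.partial(jax.jit, static_argnames=["g"])
def apply_static(g, x):
    trace_count["static"] += 1
    return g(x)

@jax.jit
def apply_pytree(g, x):
    trace_count["pytree"] += 1
    return g(x)

x = jnp.float32(1.0)
static_results, pytree_results = [], []
for y in [1.0, 2.0, 3.0]:
    # functools.partial compares and hashes by identity, so each new object
    # is a new static argument and forces a fresh trace.
    static_results.append(apply_static(functools.partial(add, y), x))
    # Partial is a pytree: `add` is static metadata while `y` is a traced leaf,
    # so the same compiled function is reused for every value of y.
    pytree_results.append(apply_pytree(Partial(add, y), x))

print(trace_count)  # {'static': 3, 'pytree': 1}
```

In this sketch both approaches compute the same values. However, the static version recompiles whenever it receives a new function object. Marking `g` as static therefore works best when `g` is a fixed, long-lived function. `tree_util.Partial` is a better fit when the values bound into `g` vary from call to call.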