Is jit(f).lower() supposed to be callable inside jit? #14267
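I am not sure whether this behavior is a bug or whether you are simply not supposed to call `jax.jit(f).lower()` inside a jitted function. When you call `.lower()` inside `jax.jit` on a function wrapped in `jax.tree_util.Partial` (with `jax_check_tracer_leaks` enabled), you get an error about a leaked tracer. This error does not appear when one either:

- calls the jitted `Partial` directly instead of lowering it (`jax.jit(f1_partial)(a)`), or
- lowers a function that is not wrapped in `Partial` (`jax.jit(f1).lower(a, b)`).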
As far as I understand, each tracer is annotated with a jit-level in order to detect escaping tracers. I suspect that in this case the tracer level does not get incremented when …

Code to reproduce:

```python
import jax

jax.config.update('jax_check_tracer_leaks', True)


def f1(a, b):
    return a + b


@jax.jit
def f2(a, b):
    # works
    jax.jit(f1)(a, b)
    # works
    jax.jit(f1).lower(a, b)

    f1_partial = jax.tree_util.Partial(f1, b=b)
    # works
    jax.jit(f1_partial)(a)
    # crashes with an error about a leaked tracer
    jax.jit(f1_partial).lower(a)


f2(0, 1)
```

In case you wonder why I want to do this: I currently use …
Thanks for the question!

Producing such a bad error message is certainly a bug, but I believe the answer is no: `.lower` should only be used at "top level" (i.e. not underneath `jax.jit` or any other JAX transformation or staging mechanism). @froystig, can you correct me if I'm wrong, and also think about how we might improve the error message?
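To make the top-level pattern concrete, here is a minimal sketch (not from the thread, and assuming a reasonably recent JAX release). `f1` is redefined to return `a + b`. The arguments are described abstractly with `jax.ShapeDtypeStruct`, so no tracers are involved, and both lowering and compilation happen outside any `jax.jit`.

```python
import jax
import jax.numpy as jnp


def f1(a, b):
    return a + b


# Describe the arguments by shape and dtype only; lowering at top level
# does not need concrete values or tracers from an enclosing jit.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)

# Lower outside of any JAX transformation.
lowered = jax.jit(f1).lower(spec, spec)
hlo_text = lowered.as_text()  # text of the lowered (StableHLO) module

# Compile the lowered computation, then call it with arrays matching the spec.
compiled = lowered.compile()
x = jnp.arange(3, dtype=jnp.float32)
y = jnp.ones(3, dtype=jnp.float32)
result = compiled(x, y)
print(result)
```

If you only need the lowered form of an inner function, lowering it this way at top level avoids calling `.lower` underneath another transformation.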