Try setting
Does that help? (It might be a bug in the traceback filtering, since we should only filter JAX-internal frames.)
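Since the reply points at traceback filtering, here is a minimal sketch of switching it off, assuming the setting in question is JAX's `jax_traceback_filtering` option (environment variable `JAX_TRACEBACK_FILTERING`). That is a guess from context, not a quote from the reply, and `disable_traceback_filtering` is a hypothetical helper name used only for illustration.

```python
import os
import sys


def disable_traceback_filtering():
    """Hypothetical helper: ask JAX to show unfiltered tracebacks."""
    # JAX reads this environment variable when it is first imported,
    # so setting it here covers the case where JAX is not loaded yet.
    os.environ["JAX_TRACEBACK_FILTERING"] = "off"

    # If JAX has already been imported, also update the live config.
    jax = sys.modules.get("jax")
    if jax is not None:
        jax.config.update("jax_traceback_filtering", "off")


# Call this as early as possible, ideally before `import jax`.
disable_traceback_filtering()
```

With filtering off, the traceback should also include JAX-internal frames, which makes it longer but may reveal the user-code frames that were dropped. The same effect can be had from the shell by exporting `JAX_TRACEBACK_FILTERING=off` before launching the script.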
Hello! 😃
I just changed the input to a complex neural network training loop. I'm getting a strange error on the second call of a (complicated) jitted function:
Unfortunately, the traceback doesn't say where in my code this happens. Usually JAX does a pretty good job of giving the full traceback (even with traced arrays), but here I get nothing but the very last frame, which doesn't help me at all. In particular, I don't know where to put a jax.debug.print or a breakpoint, since I don't know where the error occurs :(
When I print self in the last frame of the traceback, I get:
There is no error when the function is not jitted.
Any help or pointers would be great.
Thanks!
I'm using JAX 0.4.2 and cuDNN 8.2.