I've been looking at traces of model-parallel training I've implemented in JAX using pjit and noticed a curious thing: every call to the pjit-ted function hits a function called `cache_miss` and does quite a lot of computation in Python. I'm wondering whether this is expected, or whether I've set something up wrong and JAX is re-doing work it's supposed to do only once. This has practical significance for me because the overhead from these Python-side pjit activities can sometimes be large enough to make the GPU pipeline wait.

The training step function I pjit looks something like this:
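(A simplified sketch of its shape; the real model and loss are much bigger, and the names here are just placeholders:)

```python
import jax
import jax.numpy as jnp

# Placeholder step: the real one runs a much larger model, but the structure
# is the same -- compute the loss, take grads, apply an SGD-style update.
def train_step(params, batch):
    def loss_fn(p):
        preds = batch["inputs"] @ p["w"] + p["b"]
        return jnp.mean((preds - batch["targets"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, loss
```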
Here's how I compile it:
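(Roughly like this; the mesh shape, axis names and partition specs are stand-ins for the real ones, and older JAX versions spell the sharding arguments `in_axis_resources`/`out_axis_resources`:)

```python
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# Stand-in 2D mesh over all local devices; the real mesh and axis names differ.
devices = np.asarray(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "model"))

p_train_step = pjit(
    train_step,
    # Prefix specs: params replicated, batch sharded along the "data" axis;
    # both outputs (params and the scalar loss) come back replicated.
    in_shardings=(P(), P("data")),
    out_shardings=(P(), P()),
)
```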
On every training step I just call it like this:
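(Just a plain Python loop under the mesh context; `data_iter` stands in for my input pipeline:)

```python
# Every iteration goes back through pjit's Python dispatch, which is where
# `cache_miss` shows up in my trace.
with mesh:
    for batch in data_iter:
        params, loss = p_train_step(params, batch)
```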
Replies: 1 comment 12 replies

Any updates on this? I'm running into a similar issue.