-
Can you provide a small reproducible example of the jitted training loop? How long does it take to compile?
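For the compile-time question, one rough way to separate compilation from execution is JAX's ahead-of-time `lower(...).compile()` path. The sketch below is only illustrative: `train_step`, the least-squares loss, and the array shapes are hypothetical stand-ins, not the code from this thread.

```python
import time

import jax
import jax.numpy as jnp


def train_step(params, batch):
    # Hypothetical stand-in for a real training step:
    # one SGD update on a least-squares loss.
    def loss_fn(p):
        pred = batch["x"] @ p
        return jnp.mean((pred - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grads, loss


# Dummy inputs (assumed shapes, chosen only for illustration).
params = jnp.zeros((4,))
batch = {"x": jnp.ones((8, 4)), "y": jnp.ones((8,))}

# Time tracing, lowering, and XLA compilation without running the step.
start = time.perf_counter()
compiled = jax.jit(train_step).lower(params, batch).compile()
compile_seconds = time.perf_counter() - start

# Time one execution of the already-compiled step.
start = time.perf_counter()
new_params, loss = compiled(params, batch)
new_params.block_until_ready()  # wait for asynchronous dispatch to finish
run_seconds = time.perf_counter() - start

print(f"compile: {compile_seconds:.3f}s, first run: {run_seconds:.3f}s")
```

If the compile time dominates, the problem is more likely in tracing or XLA compilation than in the step itself.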
-
Hi, thanks for your response.
To reproduce, here is the repo: https://github.com/uta-smile/lab_challenge_23fall
Install it with:
Then install all dependencies with:
After that, there are two situations.
Will generate the following errors:
It will not generate any error.
-
I'm not quite sure what I need to present here.
The JAX version is 0.4.16 and the Flax version is 0.7.4. I'm running on an H100 with CUDA 12, and CUDA_VISIBLE_DEVICES is set so that only 1 GPU is visible.
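In case it helps with reporting, here is a minimal sketch of how these details could be collected in one place using only the Python standard library. `collect_env_info` is a hypothetical helper name, not part of JAX or Flax, and the package list is an assumption.

```python
import os
import platform
from importlib import metadata


def collect_env_info(packages=("jax", "jaxlib", "flax")):
    """Return package versions and GPU-visibility settings for a bug report."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # package is not installed in this environment
    # None here means the variable is unset, so every GPU may be visible.
    info["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES")
    return info


for key, value in collect_env_info().items():
    print(f"{key}: {value}")
```

Note that CUDA_VISIBLE_DEVICES generally needs to be set before the process first initializes the GPU backend for it to take effect.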
Here is the model; it's a very simple UNet.
Here is the log