Replies: 2 comments
-
I'm stuck on the same problem.
-
JAX preallocates a fraction of the GPU memory when it starts. If you need more memory than the default of 75%, increase XLA_PYTHON_CLIENT_MEM_FRACTION to something like 0.99, and JAX will preallocate more. It can be useful to turn off preallocation while debugging so that you can see how full the memory actually is. The setting that turns off preallocation is described in the reference: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
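As a minimal sketch of the two settings mentioned above: the variable names XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE come from the linked JAX documentation, while the helper configure_jax_gpu_memory is a hypothetical name used only for illustration. The variables are read when the GPU backend initializes, so this assumes they are set before JAX is first imported.

```python
import os


def configure_jax_gpu_memory(mem_fraction=None, preallocate=True):
    """Hypothetical helper: set the JAX GPU memory environment variables.

    Call this before the first `import jax`; changing the variables after
    the GPU backend has started is assumed to have no effect.
    """
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in the interval (0, 1]")
        # Fraction of total GPU memory that JAX preallocates.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)

    if preallocate:
        # Leave the variable unset so that the default behaviour applies.
        os.environ.pop("XLA_PYTHON_CLIENT_PREALLOCATE", None)
    else:
        # Documented value for disabling preallocation (handy when debugging).
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


# Raise the preallocated fraction, as suggested in the reply above.
configure_jax_gpu_memory(mem_fraction=0.99)

# import jax  # import JAX only after the variables are set
```

For a debugging run, calling configure_jax_gpu_memory(preallocate=False) instead lets tools such as nvidia-smi show how much memory is really in use.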
-
Hello, I would like to understand how GPU memory is used in JAX.
Since XLA_PYTHON_CLIENT_MEM_FRACTION defaults to 0.9, about 33 GB of VRAM is preallocated on each GPU when the script starts on 2x A100-40G. When I load the LLaMA 30B checkpoint (61 GB in float16 precision), an OOM error occurs, although there appears to be enough capacity to load the checkpoint.
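The capacity argument can be checked with rough arithmetic. This is only a sketch under stated assumptions: a nominal 40 GB per GPU, a checkpoint sharded evenly across both GPUs, and no temporary buffers during loading. The nominal pool at 0.9 (36 GB) also differs from the roughly 33 GB reported above, so the real headroom is smaller than these figures suggest.

```python
GPU_COUNT = 2
GPU_CAPACITY_GB = 40.0   # nominal A100-40G capacity (assumption)
CHECKPOINT_GB = 61.0     # float16 checkpoint size quoted in the post


def pool_per_gpu(fraction, capacity_gb=GPU_CAPACITY_GB):
    """Memory JAX would preallocate on one GPU for a given fraction."""
    return fraction * capacity_gb


# Per-GPU share of the checkpoint, assuming perfectly even sharding.
needed_per_gpu = CHECKPOINT_GB / GPU_COUNT

for fraction in (0.75, 0.9, 0.99):
    pool = pool_per_gpu(fraction)
    headroom = pool - needed_per_gpu
    print(
        f"fraction={fraction}: pool {pool:.1f} GB/GPU, "
        f"needed {needed_per_gpu:.1f} GB/GPU, headroom {headroom:+.1f} GB"
    )
```

Any uneven sharding or transient copy made while loading eats into the headroom, which is one way an OOM can occur even when the totals appear to fit.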
So I changed XLA_PYTHON_CLIENT_MEM_FRACTION from its default to 0.99, and then I could load the checkpoint successfully. There are two hypotheses that could explain this phenomenon.
XLA_PYTHON_CLIENT_MEM_FRACTION.
I think the latter is true, but I am asking the community for clarification.
Thank you