Fix flash attention on NVIDIA #678
Comments
The first two are totally normal and fine.
For 3:
Do you have a repro for 3? Do you have the issue with today's containers?
#795 should fix this (thanks @jennifgcrl!)
Hi again :) - I am getting the following JAX notifications while running a training job, and I was wondering if you could provide some clarity.
INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
I run the training job on an instance with 8 A100 GPUs. Does this notification indicate that each GPU is training a different model independently rather than all of them training a single model together, or does it just mean that I have only one node/instance, so there is no need for jax.distributed? I did not find anything in the Levanter documentation about the need to provide a "distributed config", but maybe I missed something? On the wandb log I can see all 8 GPUs being utilized under the experiment.
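A quick way to check what JAX itself sees (a hypothetical sanity-check snippet, run from the same environment as the training job): on a single multi-GPU host, JAX exposes every local accelerator without jax.distributed, which is only needed for multi-host jobs.

```python
import jax

# jax.distributed.initialize() only matters when a job spans multiple
# hosts; on one node, a single process drives all local devices.
print(f"processes: {jax.process_count()}")            # 1 on a single node
print(f"local devices: {jax.local_device_count()}")   # should be 8 on an 8xA100 box
```

If `local_device_count()` reports 8, a single model is being trained across all GPUs, which matches the utilization seen on wandb.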
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
It is not clear to me whether this causes any issues in the training process.
/levanter/src/levanter/models/attention.py:266: UserWarning: transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention.. Falling back to the reference implementation.
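For context, a warning like this usually comes from a guarded import that falls back when the optional dependency is missing. A minimal sketch of the pattern (illustrative only; the names `_HAS_TE` and `fused_attention_available` are hypothetical, not Levanter's actual code):

```python
import warnings

# Probe for the optional dependency once at import time.
try:
    import transformer_engine  # noqa: F401
    _HAS_TE = True
except ImportError:
    _HAS_TE = False

def fused_attention_available() -> bool:
    """Return True if TE's fused attention can be used; warn otherwise."""
    if not _HAS_TE:
        warnings.warn(
            "transformer_engine is not installed. Please install it to use "
            "NVIDIA's optimized fused attention. Falling back to the "
            "reference implementation."
        )
    return _HAS_TE
```

The key point is that the check happens in the Python environment that runs training, so the package must be importable there, not merely pip-installed in some other layer.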
I used
ghcr.io/nvidia/jax:levanter-2024-07-24
as my base container and created a new version where I installed transformer_engine like this:
RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
as explained in NVIDIA's documentation. It doesn't seem to do the job, though, since I am still seeing the warning. Do you have any suggestions on this? Thanks a lot in advance!!
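One way to catch this at build time rather than through the runtime warning is to verify the import inside the image itself. A hypothetical Dockerfile sketch (assuming the same base image and install command as above):

```dockerfile
FROM ghcr.io/nvidia/jax:levanter-2024-07-24
RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
# Fail the build early if the package is not importable, instead of
# discovering it later via the fallback warning during training:
RUN python -c "import transformer_engine"
```

If that final `RUN` fails, the pip install silently did not produce a usable package (e.g. it installed into a different Python environment than the one the training job uses), which would explain the persistent warning.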