[CI status badge table: Build and Test status badges for each component's upstream and rosetta containers.]
This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers several JAX libraries, including T5x, PAXML, Transformer Engine, and Pallas, with more to come.
We currently enable training and evaluation for the following models:
Model Name | Pretraining | Fine-tuning | Evaluation |
---|---|---|---|
GPT-3 (paxml) | ✔️ | ✔️ | |
LLaMA2 (paxml) | ✔️ | | |
t5 (t5x) | ✔️ | ✔️ | ✔️ |
ViT | ✔️ | ✔️ | ✔️ |
Imagen | ✔️ | ✔️ | |
We will update this table as new models become available, so stay tuned.
The JAX image is embedded with the following flags and environment variables for performance tuning:
XLA Flags | Value | Explanation |
---|---|---|
`--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
`--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL AllGather kernels on a separate CUDA stream to allow overlap with compute kernels |
`--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL ReduceScatter kernels on a separate CUDA stream to allow overlap with compute kernels |
`--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |
Environment Variable | Value | Explanation |
---|---|---|
`CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
`NCCL_NVLS_ENABLE` | `0` | disables NVLink SHARP (1); future releases will re-enable this feature |
`CUDA_MODULE_LOADING` | `EAGER` | disables lazy-loading (1), which uses slightly more GPU memory |
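These defaults are baked into the image, but they can be overridden at runtime. The snippet below is a minimal, illustrative sketch rather than the container's own startup mechanism: it simply re-applies the table defaults from Python, which works because `XLA_FLAGS` and the CUDA/NCCL variables are only consulted when the GPU backend (and NCCL) initialize.

```python
import os

# Mirror the defaults from the tables above; adjust any value as needed.
# They must be set before the first JAX operation creates the GPU backend,
# otherwise they are ignored.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_async_all_gather=true",
    "--xla_gpu_enable_async_reduce_scatter=true",
    "--xla_gpu_enable_triton_gemm=false",
])
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["CUDA_MODULE_LOADING"] = "EAGER"

import jax

print(jax.devices())  # backend creation happens here and picks up the settings
```

The same variables can of course be exported in the shell or passed to `docker run` with `-e` instead of being set from Python.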
See this page for more information about how to profile JAX programs on GPU.
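As a quick start, here is a minimal sketch using JAX's built-in profiler; the output directory and the matmul workload are arbitrary placeholders, not anything prescribed by the image. The resulting trace can be opened in TensorBoard (with the profiler plugin) or Perfetto.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return (x @ x).sum()

x = jnp.ones((4096, 4096))
step(x).block_until_ready()  # warm-up run so compilation is not in the trace

# Capture a trace of the steady-state step and write it to disk.
with jax.profiler.trace("/tmp/jax-trace"):
    step(x).block_until_ready()
```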
`bus error` when running JAX in a Docker container

Solution:

    docker run -it --shm-size=1g ...

Explanation: The `bus error` might occur due to the size limitation of `/dev/shm`. You can address this by increasing the shared memory size using the `--shm-size` option when launching your container.
enroot/pyxis reports error code 404 when importing multi-arch images
Problem description:
slurmstepd: error: pyxis: [INFO] Authentication succeeded
slurmstepd: error: pyxis: [INFO] Fetching image manifest list
slurmstepd: error: pyxis: [INFO] Fetching image manifest
slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/<TAG> returned error code: 404 Not Found
Solution: Upgrade enroot or apply a single-file patch as mentioned in the enroot v3.4.0 release note.
Explanation: Docker traditionally used Docker Schema V2.2 for multi-arch manifest lists but switched to the Open Container Initiative (OCI) format in version 20.10. Enroot added support for the OCI format in version 3.4.0.
- AWS
- GCP
- Azure
- OCI