*(A table of build and test status badges for the Frameworks and Rosetta container images appears here.)*
This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers several JAX libraries, such as T5x, PAXML, and Transformer Engine, with more to come soon.
We currently enable training and evaluation for the following models:
Model Name | Pretraining | Fine-tuning | Evaluation |
---|---|---|---|
GPT-3 (PAXML) | ✔️ | ✔️ | ✔️ |
T5 (T5x) | ✔️ | ✔️ | ✔️ |
ViT | ✔️ | ✔️ | ✔️ |
We will update this table as new models become available, so stay tuned.
The JAX image ships with the following XLA flags and environment variables set for performance tuning:
XLA Flags | Value | Explanation |
---|---|---|
`--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
`--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL AllGather kernels on a separate CUDA stream to allow overlap with compute kernels |
`--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL ReduceScatter kernels on a separate CUDA stream to allow overlap with compute kernels |
`--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |
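For reference, the same flags can also be passed through the `XLA_FLAGS` environment variable before JAX initializes its GPU backend. The sketch below is a minimal, hypothetical example (not taken from this repo) that mirrors the values in the table above:

```python
# Minimal sketch: set the XLA flags from the table above via XLA_FLAGS.
# The variable must be in the environment before XLA's GPU backend is
# initialized, so set it before importing jax.
import os

os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_async_all_gather=true",
    "--xla_gpu_enable_async_reduce_scatter=true",
    "--xla_gpu_enable_triton_gemm=false",
])

import jax  # imported after XLA_FLAGS is set

print(jax.devices())
```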
Environment Variable | Value | Explanation |
---|---|---|
`CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
`NCCL_IB_SL` | `1` | defines the InfiniBand Service Level (1) |
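Inside the container these variables are already exported; the short sketch below (again hypothetical, not from this repo) shows how they could be set in a training script, keeping in mind that they only take effect if set before the CUDA context and NCCL communicators are created:

```python
# Minimal sketch: export the tuning variables from the table above.
# Set them at the very top of the script (or in the launch environment),
# before CUDA/NCCL initialization.
import os

os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"  # single GPU work queue
os.environ["NCCL_IB_SL"] = "1"                   # InfiniBand Service Level

import jax  # initialize JAX only after the environment is configured
```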