# JAX Toolbox

| Image | Build | Tests |
| --- | --- | --- |
| container-badge-base | build-badge-base | n/a |
| **Frameworks** | | |
| container-badge-jax | build-badge-jax | test-badge-jax-V100 test-badge-jax-A100 |
| container-badge-t5x | build-badge-t5x | test-badge-t5x |
| container-badge-pax | build-badge-pax | test-badge-pax |
| container-badge-te | build-badge-te | unit-test-badge-te integration-test-badge-te |
| **Rosetta** | | |
| container-badge-rosetta-t5x | build-badge-rosetta-t5x | test-badge-rosetta-t5x |
| container-badge-rosetta-pax | build-badge-rosetta-pax | test-badge-rosetta-pax |

> **Note**
>
> This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries, such as T5X, Paxml, and Transformer Engine, with more to come soon.

## Supported Models

We currently enable training and evaluation for the following models:

| Model Name | Pretraining | Fine-tuning | Evaluation |
| --- | --- | --- | --- |
| GPT-3 (Paxml) | ✔️ | ✔️ | ✔️ |
| T5 (T5X) | ✔️ | ✔️ | ✔️ |
| ViT | ✔️ | ✔️ | ✔️ |

We will update this table as new models become available, so stay tuned.

## Environment Variables

The JAX image ships with the following XLA flags and environment variables set for performance tuning:

| XLA Flag | Value | Explanation |
| --- | --- | --- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
| `--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL AllGather kernels on a separate CUDA stream to allow overlap with compute kernels |
| `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL ReduceScatter kernels on a separate CUDA stream to allow overlap with compute kernels |
| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |

| Environment Variable | Value | Explanation |
| --- | --- | --- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| `NCCL_IB_SL` | `1` | defines the InfiniBand Service Level (1) |
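If you run JAX outside these containers and want the same tuning, the flags above can be reproduced by hand. XLA flags are passed through the `XLA_FLAGS` environment variable, while the NCCL and CUDA settings are ordinary environment variables. This is an illustrative sketch, not part of the image itself; the flag and variable names are taken from the tables above:

```shell
# Sketch: reproduce the container's tuning settings in a plain shell.
# XLA reads its flags from the XLA_FLAGS environment variable.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_enable_async_all_gather=true \
--xla_gpu_enable_async_reduce_scatter=true \
--xla_gpu_enable_triton_gemm=false"

# Single work queue per GPU; safe because XLA already orders launches.
export CUDA_DEVICE_MAX_CONNECTIONS=1

# InfiniBand Service Level for NCCL traffic.
export NCCL_IB_SL=1
```

Export these before launching your training script so that JAX, XLA, and NCCL pick them up at initialization.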