diff --git a/README.md b/README.md index 62bb8934317d..6f94a7fa19b3 100644 --- a/README.md +++ b/README.md @@ -417,66 +417,82 @@ On Linux, it is often necessary to first update `pip` to a version that supports **These `pip` installations do not work with Windows, and may fail silently; see [above](#installation).** -### pip installation: GPU (CUDA) +### pip installation: GPU (CUDA, installed via pip, easier) -If you want to install JAX with both CPU and NVidia GPU support, you must first -install [CUDA](https://developer.nvidia.com/cuda-downloads) and -[CuDNN](https://developer.nvidia.com/CUDNN), -if they have not already been installed. Unlike some other popular deep -learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip` -package. - -JAX provides pre-built CUDA-compatible wheels for **Linux only**, -with CUDA 11.4 or newer, and CuDNN 8.2 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of -operating system, CUDA, and CuDNN are possible, but require [building from -source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). +There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN +installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend +installing CUDA and CUDNN using the pip wheels, since it is much easier! -* CUDA 11.4 or newer is *required*. - * Your CUDA installation must be new enough to support your GPU. If you have - an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU, - you must use CUDA 11.8 or newer. -* The supported cuDNN versions for the prebuilt wheels are: - * cuDNN 8.6 or newer. We recommend using the cuDNN 8.6 wheel if your cuDNN - installation is new enough, since it supports additional functionality. - * cuDNN 8.2 or newer. -* You *must* use an NVIDIA driver version that is at least as new as your - [CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). - For example, if you have CUDA 11.4 update 4 installed, you must use NVIDIA - driver 470.82.01 or newer if on Linux. This is a strict requirement that - exists because JAX relies on JIT-compiling code; older drivers may lead to - failures. - * If you need to use an newer CUDA toolkit with an older driver, for example - on a cluster where you cannot update the NVIDIA driver easily, you may be - able to use the - [CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) - that NVIDIA provides for this purpose. - -Next, run +You must first install the NVIDIA driver. We +recommend installing the newest driver available from NVIDIA, but the driver +must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux. ```bash pip install --upgrade pip -# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer. + +# CUDA 12 installation +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# CUDA 11 installation # Note: wheels only available on linux. -pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` -**These `pip` installations do not work with Windows, and may fail silently; see -[above](#installation).** +### pip installation: GPU (CUDA, installed locally, harder) + +If you prefer to use a preinstalled copy of CUDA, you must first +install [CUDA](https://developer.nvidia.com/cuda-downloads) and +[CuDNN](https://developer.nvidia.com/CUDNN). + +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other +combinations of operating system and architecture are possible, but require +[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). + +You should use an NVIDIA driver version that is at least as new as your +[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions). +If you need to use an newer CUDA toolkit with an older driver, for example +on a cluster where you cannot update the NVIDIA driver easily, you may be +able to use the +[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/) +that NVIDIA provides for this purpose. -The jaxlib version must correspond to the version of the existing CUDA -installation you want to use. You can specify a particular CUDA and CuDNN -version for jaxlib explicitly: +JAX currently ships three CUDA wheel variants: +* CUDA 12.0 and CuDNN 8.8. +* CUDA 11.8 and CuDNN 8.6. +* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued + with jax 0.4.8. + +You may use a JAX wheel provided the major version of your CUDA and CuDNN +installation matches, and the minor version is at least as new as the version +JAX expects. For example, you would be able to use the CUDA 12.0 wheel with +CUDA 12.1 and CuDNN 8.9. + +Your CUDA installation must also be new enough to support your GPU. If you have +an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU, +you must use CUDA 11.8 or newer. + + +To install, run ```bash pip install --upgrade pip -# Installs the wheel compatible with Cuda >= 11.8 and cudnn >= 8.6 -pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2 +# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer. +# Note: wheels only available on linux. +pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated). pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` +**These `pip` installations do not work with Windows, and may fail silently; see +[above](#installation).** + You can find your CUDA version with the command: ```bash