Skip to content

Commit

Permalink
Update the CUDA installation instructions.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Mar 29, 2023
1 parent 8c4fed6 commit 775f404
Showing 1 changed file with 59 additions and 43 deletions.
102 changes: 59 additions & 43 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,66 +417,82 @@ On Linux, it is often necessary to first update `pip` to a version that supports
**These `pip` installations do not work with Windows, and may fail silently; see
[above](#installation).**

### pip installation: GPU (CUDA)
### pip installation: GPU (CUDA, installed via pip, easier)

If you want to install JAX with both CPU and NVidia GPU support, you must first
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN),
if they have not already been installed. Unlike some other popular deep
learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
package.

JAX provides pre-built CUDA-compatible wheels for **Linux only**,
with CUDA 11.4 or newer, and CuDNN 8.2 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
operating system, CUDA, and CuDNN are possible, but require [building from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
installing CUDA and CUDNN using the pip wheels, since it is much easier!

* CUDA 11.4 or newer is *required*.
* Your CUDA installation must be new enough to support your GPU. If you have
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
you must use CUDA 11.8 or newer.
* The supported cuDNN versions for the prebuilt wheels are:
* cuDNN 8.6 or newer. We recommend using the cuDNN 8.6 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.2 or newer.
* You *must* use an NVIDIA driver version that is at least as new as your
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
For example, if you have CUDA 11.4 update 4 installed, you must use NVIDIA
driver 470.82.01 or newer if on Linux. This is a strict requirement that
exists because JAX relies on JIT-compiling code; older drivers may lead to
failures.
* If you need to use an newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

Next, run
You must first install the NVIDIA driver. We
recommend installing the newest driver available from NVIDIA, but the driver
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.

```bash
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**These `pip` installations do not work with Windows, and may fail silently; see
[above](#installation).**
### pip installation: GPU (CUDA, installed locally, harder)

If you prefer to use a preinstalled copy of CUDA, you must first
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN).

JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
combinations of operating system and architecture are possible, but require
[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

You should use an NVIDIA driver version that is at least as new as your
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
If you need to use an newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

The jaxlib version must correspond to the version of the existing CUDA
installation you want to use. You can specify a particular CUDA and CuDNN
version for jaxlib explicitly:
JAX currently ships three CUDA wheel variants:
* CUDA 12.0 and CuDNN 8.8.
* CUDA 11.8 and CuDNN 8.6.
* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued
with jax 0.4.8.

You may use a JAX wheel provided the major version of your CUDA and CuDNN
installation matches, and the minor version is at least as new as the version
JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
CUDA 12.1 and CuDNN 8.9.

Your CUDA installation must also be new enough to support your GPU. If you have
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
you must use CUDA 11.8 or newer.


To install, run

```bash
pip install --upgrade pip

# Installs the wheel compatible with Cuda >= 11.8 and cudnn >= 8.6
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated).
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**These `pip` installations do not work with Windows, and may fail silently; see
[above](#installation).**

You can find your CUDA version with the command:

```bash
Expand Down

0 comments on commit 775f404

Please sign in to comment.