Skip to content

Commit

Permalink
Merge branch 'main' into stackless
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Aug 28, 2024
2 parents 6b9b191 + ef33cf5 commit f4200ba
Show file tree
Hide file tree
Showing 439 changed files with 29,874 additions and 8,469 deletions.
52 changes: 15 additions & 37 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ build --config=short_logs

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.

# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,
# these values are overridden.
build --@xla//xla/python:enable_gpu=false

###########################################################################

build:posix --copt=-fvisibility=hidden
Expand All @@ -65,34 +61,21 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
build:cuda --define=xla_python_enable_gpu=true

# Build with nvcc for CUDA and clang for host
build:nvcc_clang --config=cuda
# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang
build:nvcc_clang --action_env=TF_CUDA_CLANG="1"
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Later Bazel flag values override earlier values.
# TODO(jieying): remove enable_gpu and xla_python_enable_gpu from build:cuda
# after the pluin is released.
build:cuda_plugin --@xla//xla/python:enable_gpu=false
build:cuda_plugin --define=xla_python_enable_gpu=false

build:rocm_plugin --@xla//xla/python:enable_gpu=false
build:rocm_plugin --define=xla_python_enable_gpu=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand All @@ -109,7 +92,6 @@ build:cuda --linkopt=-Wl,--disable-new-dtags

build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
build:cuda_clang --action_env=TF_CUDA_CLANG="1"
# Disable clang extention that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
Expand All @@ -119,10 +101,14 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extention that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments

# Build with nvcc for CUDA and clang for host
build:nvcc_clang --config=cuda
build:nvcc_clang --config=cuda_clang
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@xla//xla/python:enable_gpu=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"

Expand Down Expand Up @@ -215,7 +201,6 @@ build:rbe_linux --host_linkopt=-lm
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --config=cuda_clang
build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1"
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
Expand All @@ -235,22 +220,15 @@ build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9
build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12"
build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl"
# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility
build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
documentation_render:
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28
ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -38,8 +38,8 @@ jobs:
python -m pip install -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.experimental.array_api
ARRAY_API_TESTS_MODULE: jax.numpy
JAX_ENABLE_X64: 'true'
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
2 changes: 1 addition & 1 deletion .github/workflows/upstream-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'google/jax'
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
path: ${{ github.workspace }}\dist\*.whl
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
--bazel_options=--color=yes `
--bazel_options=--config=win_clang
- uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet: actions/upload-artifact@v4
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: wheels
path: ${{ github.workspace }}\jax\dist\*.whl
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0
hooks:
- id: check-ast
- id: check-merge-conflict
Expand All @@ -26,21 +26,21 @@ repos:
files: \.py$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
hooks:
- id: ruff

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.10.0'
rev: 'd4911cfb7f1010759fde68da196036feeb25b99d' # frozen: v1.11.2
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.30, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.31, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
rev: v1.16.1
rev: 8ed836db64ad5d304f2315e6bfd9049c9142e190 # frozen: v1.16.4
hooks:
- id: jupytext
files: docs/
Expand Down
59 changes: 57 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,71 @@ see {ref}`pallas-changelog`.

<!--
Remember to align the itemized text with the first line of an item within a list.
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.31
## jax 0.4.32

* Changes
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
modified to no longer support `complex dtypes`.

* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.

* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.

## jaxlib 0.4.32

* Breaking changes
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
distributions, and then use CUDA libraries and tools as dependencies in
various Bazel targets. This enables more reproducible builds for JAX and its
supported CUDA versions.

* Changes
* SparseCore profiling is added.
* JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer).

## jax 0.4.31 (July 29, 2024)

* Deletion
* xmap has been deleted. Please use {func}`shard_map` as the replacement.

* Changes
* The minimum CuDNN version is v9.1. This was true in previous releases also,
but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
supported version until January 2025.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
Expand All @@ -41,8 +94,10 @@ Remember to align the itemized text with the first line of an item within a list
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.

## jaxlib 0.4.31
## jaxlib 0.4.31 (July 29, 2024)

* Bug fixes
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
Expand Down
47 changes: 47 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,50 @@ xla_workspace0()

load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)

cuda_json_init_repository()

load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)

cuda_redist_init_repositories(
cuda_redistributions = CUDA_REDISTRIBUTIONS,
)

cudnn_redist_init_repository(
cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)

cuda_configure(name = "local_config_cuda")

load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)

nccl_configure(name = "local_config_nccl")
Loading

0 comments on commit f4200ba

Please sign in to comment.