
Commit

Merge branch 'rocm-main' into jax-mlgh-9948-add-gpu-ci
charleshofer authored Nov 22, 2024
2 parents 54a75cb + 3be7c1e commit caf86f6
Showing 159 changed files with 6,008 additions and 2,399 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/bazel_gpu_rbe.yml
@@ -0,0 +1,39 @@
name: CI - Bazel GPU tests (RBE)

on:
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  run_tests:
    if: github.event.repository.fork == false
    strategy:
      matrix:
        runner: ["linux-x86-n2-16"]

    runs-on: ${{ matrix.runner }}
    container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'

    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "3.12"

    steps:
      - uses: actions/checkout@v3
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel GPU Tests with RBE
        run: ./ci/run_bazel_test_gpu_rbe.sh
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
@@ -139,7 +139,7 @@ jobs:
documentation_render:
name: Documentation - render documentation
runs-on: ubuntu-latest
timeout-minutes: 10
timeout-minutes: 20
strategy:
matrix:
python-version: ['3.10']
49 changes: 26 additions & 23 deletions .github/workflows/cloud-tpu-ci-nightly.yml
@@ -13,7 +13,7 @@
name: CI - Cloud TPU (nightly)
on:
schedule:
- cron: "0 14 * * *" # daily at 7am PST
- cron: "0 */2 * * *" # Run every 2 hours
workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
{type: "v3-8", cores: "4"},
{type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
# {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
# {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240722
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
PYTHON: python${{ matrix.python-version }}
runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 120
defaults:
run:
@@ -46,52 +49,52 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
pip install -U -r build/collect-profile-requirements.txt
$PYTHON -m pip install -U -r build/test-requirements.txt
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu
$PYTHON -m pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
pip install .[tpu] \
$PYTHON -m pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --pre libtpu \
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
$PYTHON -m pip install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
$PYTHON -m pip install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
fi
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
python3 -c 'import jax; print("libtpu version:",
$PYTHON -c 'import sys; print("python version:", sys.version)'
$PYTHON -c 'import jax; print("jax version:", jax.__version__)'
$PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
$PYTHON -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
$PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
2 changes: 1 addition & 1 deletion .github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻
ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
28 changes: 28 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.36

* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels, `post_process_call`,
`new_base_main`, `custom_bind`, and so on. The change should only affect
users that use JAX internals.

If you do use JAX internals then you may need to
update your code (see
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
you need help updating your code then please file a bug.
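A minimal sketch of opting into the fallback (assuming the flag named above is settable through the usual `jax.config` mechanism, and only in releases that still carry the pre-stackless machinery):

```python
import jax

# Assumed: the workaround flag from the note above is exposed via jax.config.
jax.config.update("jax_data_dependent_tracing_fallback", True)
```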
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
@@ -43,6 +58,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
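For illustration, a sketch of the batching described above (output shapes are assumed from the description, not verified against the release):

```python
import jax.numpy as jnp
from jax.scipy.linalg import toeplitz

c = jnp.arange(6.0).reshape(2, 3)
batched = toeplitz(c)              # batched over the leading axis: assumed shape (2, 3, 3)
previous = toeplitz(jnp.ravel(c))  # recovers the old flattened behavior: shape (6, 6)
```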
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
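For illustration, the SciPy-compatible behavior at the poles looks like:

```python
from jax.scipy.special import gamma, gammasgn

print(gamma(-2.0))     # nan: negative integers are poles of the gamma function
print(gammasgn(-2.0))  # nan, matching SciPy
print(gamma(0.5))      # ~1.7724539 (sqrt(pi)); non-pole inputs are unchanged
```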
* `jax.clear_backends` was removed after being deprecated in v0.4.26.

* New Features
@@ -52,12 +70,22 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
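A hedged sketch of the inline form; the `static` metadata key and the ability to omit explicit field lists are assumptions taken from the description above, so check the function documentation for the exact spelling:

```python
import dataclasses
import jax

@dataclasses.dataclass
class Params:
  weights: jax.Array                                         # pytree data field
  name: str = dataclasses.field(metadata=dict(static=True))  # assumed inline metadata marker

jax.tree_util.register_dataclass(Params)
```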
* Added {func}`jax.numpy.put_along_axis`.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`#24663` for more details.
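A minimal usage sketch of the nonsymmetric eigendecomposition that is now available on GPU:

```python
import jax.numpy as jnp

a = jnp.array([[0.0, 1.0],
               [-2.0, -3.0]])
w, v = jnp.linalg.eig(a)        # complex eigenvalues and right eigenvectors
w_only = jnp.linalg.eigvals(a)  # eigenvalues only
```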

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`#24843` for more details.

* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.
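A short sketch of the suggested replacements:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)
assert isinstance(x, jax.Array)             # instead of jax.lib.xla_extension.ArrayImpl
runtime_error = jax.errors.JaxRuntimeError  # instead of jax.lib.xla_extension.XlaRuntimeError
```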

## jax 0.4.35 (Oct 22, 2024)

* Breaking Changes
9 changes: 5 additions & 4 deletions README.md
@@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.

Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
for more.
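A small sketch of the kind of constraint this refers to, using `jit`'s `static_argnums` escape hatch mentioned further down:

```python
import jax

def f(x, n):
  return x * n if n > 0 else -x  # Python-level branch on `n`

# jax.jit(f)(1.0, 2) would fail: under jit, `n` is a tracer with no concrete
# boolean value. Marking `n` static restores ordinary Python control flow:
f_jit = jax.jit(f, static_argnums=1)
f_jit(1.0, 2)
```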

### Auto-vectorization with `vmap`
@@ -349,7 +348,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
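For example, a minimal sketch of the functional `.at[]` form:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[1] += 5.0 would raise on an immutable jax.Array; the functional update is:
y = x.at[1].add(5.0)
```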
1. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.
@@ -369,7 +368,7 @@ Some standouts:
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
1. Some transformations, like `jit`, [constrain how you can use Python control
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
flow](https://jax.readthedocs.io/en/latest/control-flow.html).
You'll always get loud errors if something goes wrong. You might have to use
[`jit`'s `static_argnums`
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
@@ -390,6 +389,7 @@ Some standouts:
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | experimental | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | n/a | no | no |


### Instructions
@@ -401,6 +401,7 @@ Some standouts:
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |

See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
for information on alternative installation strategies. These include compiling
3 changes: 1 addition & 2 deletions benchmarks/shape_poly_benchmark.py
@@ -17,7 +17,6 @@

import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax import export

jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
while state:
for _ in range(30):
a.scope._clear_caches()
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
_ = 0 <= slice_size <= b
_ = start >= 0
_ = start + slice_size <= b
51 changes: 51 additions & 0 deletions ci/run_bazel_test_gpu_rbe.sh
@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one
# GPU apiece on RBE.
#
# -e: abort script if one command fails
# -u: error if undefined variable used
# -x: log all commands
# -o history: record shell history
# -o allexport: export all functions and variables to be available to subscripts
set -exu -o history -o allexport

# Source default JAXCI environment variables.
source ci/envs/default.env

# Clone XLA at HEAD if path to local XLA is not provided
if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
export JAXCI_CLONE_MAIN_XLA=1
fi

# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"

# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece).
echo "Running RBE GPU tests..."

bazel test --config=rbe_linux_x86_64_cuda \
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
--test_output=errors \
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
--test_tag_filters=-multiaccelerator \
--test_env=JAX_SKIP_SLOW_TESTS=true \
--action_env=JAX_ENABLE_X64=0 \
--color=yes \
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
6 changes: 3 additions & 3 deletions docs/Custom_Operation_for_GPUs.md
@@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the
gradient. (And if you implement the interface to support vmap, it will
also be on the outer primitive.)
JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic.
JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic.
XLA sharding goes in two phases: a sharding propagation phase and a partition phase.
The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph.
The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph.
For XLA to be able to shard our custom operations, it needs us to define 2 extra functions:
infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively.
The infer_sharding_from_operands() function must do what its name says: infer the output sharding from the input sharding.
The partition() function will do a few things:
- tell which input sharding will be expected. XLA will reshad if needed.
- tell which input sharding will be expected. XLA will reshard if needed.
- tell the final version of the output sharding.
- give a function that will create the new instruction from the sharded inputs.
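A schematic sketch of the two callbacks (the op is a stand-in, and the names follow the pattern used in the `custom_partitioning` documentation rather than any specific kernel):

```python
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def my_op(x):
  return jnp.sin(x)  # stand-in elementwise op

def infer_sharding_from_operands(mesh, arg_infos, result_infos):
  # Propagation phase: derive the output sharding from the (single) operand.
  return arg_infos[0].sharding

def partition(mesh, arg_infos, result_infos):
  # Partition phase: expected input shardings, output sharding, and the
  # per-shard computation; XLA reshards the inputs if needed.
  arg_shardings = (arg_infos[0].sharding,)
  out_sharding = result_infos.sharding
  def lower_fn(x):
    return jnp.sin(x)
  return mesh, lower_fn, out_sharding, arg_shardings

my_op.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)
```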