New TPU jobs to use updated runners. #22759

Draft: wants to merge 7 commits into base: main
115 changes: 115 additions & 0 deletions .github/workflows/cloud-tpu-ci-nightly-new.yml
@@ -0,0 +1,115 @@
# Cloud TPU CI
#
# This job currently runs once per day. We use self-hosted TPU runners, so we'd
# have to add more runners to run on every commit.
#
# This job's build matrix runs over several TPU architectures using both the
# latest released jaxlib on PyPI ("pypi_latest") and the latest nightly
# jaxlib ("nightly"). It also installs a matching libtpu, either the one pinned
# to the release for "pypi_latest", or the latest nightly for "nightly". It
# always locally installs jax from GitHub head (already checked out by the
# GitHub Actions environment).

name: CI - Cloud TPU Nightly (new)
on:
schedule:
- cron: "0 14 * * *" # daily at 14:00 UTC (6am PST / 7am PDT)
workflow_dispatch: # allows triggering the workflow run manually
# TODO: remove pull_request trigger
pull_request:
branches:
- main
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
# TODO: remove concurrency for normal usage. It's here for presubmit testing.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
cloud-tpu-test:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
# {type: "v3-8", cores: "4"},
# {type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["arc-linux-x86-ct5lp-224-8tpu"]
container:
# TODO: repin
# We run on a bare Python image to best replicate end-user usage.
image: python:3.10-bookworm
timeout-minutes: 120
defaults:
run:
shell: bash -ex {0}
steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu-nightly
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --pre libtpu-nightly \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests

elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests

else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
fi

python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator tests across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
# TODO: reenable
# - name: Send chat on failure
# # Don't notify when testing the workflow from a branch.
# if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
# run: |
# curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
# --header 'Content-Type: application/json' \
# --data-raw "{
# 'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
# }"
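The "Install JAX" step above branches on `matrix.jaxlib-version`. A minimal dry-run sketch of that branching, which only prints the pip commands instead of running them, may help when reviewing the matrix locally. The function name `jax_install_cmds` and its second argument are hypothetical and not part of the workflow; the URLs and the default date are copied from the workflow.

```shell
#!/usr/bin/env bash
# Dry-run sketch: print the pip commands the "Install JAX" step would run
# for a given jaxlib-version. `jax_install_cmds` is a hypothetical helper.
LIBTPU_RELEASES="https://storage.googleapis.com/jax-releases/libtpu_releases.html"
JAX_NIGHTLY="https://storage.googleapis.com/jax-releases/jax_nightly_releases.html"

jax_install_cmds() {
  local version="$1"
  # Mirrors LIBTPU_OLDEST_VERSION_DATE from the workflow's env block.
  local oldest_date="${2:-20240228}"
  case "$version" in
    pypi_latest)
      echo "pip install .[tpu] -f $LIBTPU_RELEASES"
      ;;
    nightly)
      echo "pip install --pre . -f $JAX_NIGHTLY"
      echo "pip install --pre libtpu-nightly -f $LIBTPU_RELEASES"
      echo "pip install requests"
      ;;
    nightly+oldest_supported_libtpu)
      echo "pip install --pre . -f $JAX_NIGHTLY"
      echo "pip install --pre libtpu-nightly==0.1.dev${oldest_date} -f $LIBTPU_RELEASES"
      echo "pip install requests"
      ;;
    *)
      # Same failure mode as the workflow: unknown values are an error.
      echo "Unknown jaxlib-version: $version" >&2
      return 1
      ;;
  esac
}
```

For example, `jax_install_cmds nightly+oldest_supported_libtpu` prints three commands, the second of which pins `libtpu-nightly==0.1.dev20240228`.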
78 changes: 78 additions & 0 deletions .github/workflows/cloud-tpu-presubmit.yml
@@ -0,0 +1,78 @@
# Cloud TPU CI
name: Cloud TPU Presubmit
# Runs on pull requests labeled "optional_ci_tpu", or on workflow_dispatch.
on:
pull_request:
branches:
- main
types: [labeled, synchronize]
workflow_dispatch:
# Cancel any previous iterations if a new commit is pushed
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
cloud-tpu-test:
# TODO: confirm final naming for optional label
# if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
name: "TPU v5e x 8 Presubmit"
env:
ENABLE_PJRT_COMPATIBILITY: 1
# TODO: Needs final runs-on value
runs-on: arc-linux-x86-ct5lp-224-8tpu
container:
# TODO: Needs a newer, lightweight image
image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
timeout-minutes: 120
defaults:
run:
shell: bash -ex {0}
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
# TODO: building JAX should be done in a prior step, or we should just use bazel test
- name: Wait For Connection
uses: google-ml-infra/jax-fork/actions/ci_connection@0b98fcaa920fcae8374c61b84febcfbe5c3b472f
with:
halt-dispatch-input: "1"
- name: Build JAX
run: |
pip uninstall -y jaxlib
python3 build/build.py --use_clang
pip install -e .
ls -la dist/*.whl
pip install dist/*.whl
# Note the version it installs! Should be today's date
pip install -U --no-index --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
NUM_TESTS: 8
JAX_NUM_GENERATED_CASES: 25
run: |
# Run single-accelerator tests in parallel
mkdir results
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=$NUM_TESTS --tb=short \
--junitxml=results/singlejunit.xml --maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator tests across all chips
python3 -m pytest --tb=short --junitxml=results/multijunit.xml \
--maxfail=20 -m "multiaccelerator" tests
# - name: 'Upload Artifact'
# if: success() || failure()
# uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet:actions/upload-artifact@v4
# with:
# name: junit
# path: |
# results/singlejunit.xml
# results/multijunit.xml
# retention-days: 1
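The "Build JAX" step in the presubmit workflow carries a manual note that the installed libtpu-nightly version should be today's date. A minimal sketch of automating that check follows. It assumes the version string has the form `0.1.devYYYYMMDD` (the form used for the pinned oldest version in the nightly workflow) and that "today" means the current UTC date; both are assumptions, and `libtpu_date_matches` is a hypothetical helper name.

```shell
#!/usr/bin/env bash
# Hypothetical helper: succeeds if a libtpu-nightly version string such as
# "0.1.dev20240228" carries the expected date stamp.
# Assumption: the stamp is everything after the last "dev" in the version.
libtpu_date_matches() {
  local version="$1"
  # Default to today's date in UTC (an assumption about the nightly stamp).
  local expected="${2:-$(date -u +%Y%m%d)}"
  local stamp="${version##*dev}"   # strip everything through the last "dev"
  [ "$stamp" = "$expected" ]
}
```

In a workflow step this could be called with the version reported by `pip show libtpu-nightly`, failing the step when the stamp is stale. Whether a stale nightly should fail the job or only warn is a design choice left open here.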