# New TPU jobs to use updated runners. #4
# Workflow file for this run
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Cloud TPU CI
name: Cloud TPU Presubmit

# Triggers: pull requests targeting main (on "labeled" and "synchronize"
# events) and manual workflow dispatch.
# NOTE: the "optional_ci_tpu" label gate is intended but not yet enforced --
# the job-level `if:` that checks for the label is still commented out.
on:
  pull_request:
    branches:
      - main
    types: [labeled, synchronize]
  workflow_dispatch:

# Cancel any previous iterations if a new commit is pushed
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

# Least privilege: the workflow token only gets read access to repo contents.
permissions:
  contents: read
jobs:
  # Builds JAX from source inside a container on a TPU runner, then runs the
  # pytest suite (single-accelerator tests in parallel, then multi-accelerator).
  cloud-tpu-test:
    # TODO: confirm final naming for optional label
    # if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
    name: "TPU v5e x 8 Presubmit"
    env:
      # Quoted: environment variables are always strings.
      ENABLE_PJRT_COMPATIBILITY: "1"
    # TODO: Needs final runs-on value
    runs-on: arc-linux-x86-ct5lp-224-8tpu
    container:
      # TODO: Needs newer, light weight image
      image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2  # ratchet:tensorflow/build:latest-python3.11
    timeout-minutes: 120
    defaults:
      run:
        # -e: abort the step on the first failing command; -x: trace commands.
        shell: bash -ex {0}
    steps:
      - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11  # ratchet:actions/checkout@v4
      - name: Install JAX test requirements
        run: |
          pip install -U -r build/test-requirements.txt
      # TODO: build jax should be done on a step prior or we should just bazel test
      - name: Wait For Connection
        uses: google-ml-infra/jax-fork/actions/ci_connection@61e7d8d6c273b102e4a6271c1e84bd0a4febc8cb
        with:
          halt-dispatch-input: "1"
      - name: Build JAX
        run: |
          pip uninstall -y jaxlib
          python3 build/build.py --use_clang
          pip install -e .
          ls -la dist/*.whl
          pip install dist/*.whl
          # Note the version it installs! Should be today's date
          pip install -U --no-index --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          # Print the versions that actually ended up installed.
          python3 -c 'import sys; print("python version:", sys.version)'
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          PY_COLORS: "1"
          # Worker count passed to pytest via -n below.
          NUM_TESTS: "8"
          JAX_NUM_GENERATED_CASES: "25"
        run: |
          # Run single-accelerator tests in parallel
          # -p: do not fail (and abort under `bash -e`) if the directory exists.
          mkdir -p results
          JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n="$NUM_TESTS" --tb=short \
            --junitxml=results/singlejunit.xml --maxfail=20 -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python3 -m pytest --tb=short --junitxml=results/multijunit.xml \
            --maxfail=20 -m "multiaccelerator" tests
      # - name: 'Upload Artifact'
      #   if: success() || failure()
      #   uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # ratchet:actions/upload-artifact@v4
      #   with:
      #     name: junit
      #     path: |
      #       results/singlejunit.xml
      #       results/multijunit.xml
      #     retention-days: 1