Skip to content

Add logical axis global context support for NNX #525

Add logical axis global context support for NNX

Add logical axis global context support for NNX #525

Workflow file for this run

# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Flax - Test
on:
push:
branches:
- main
- 'test_*'
pull_request:
branches:
- main
jobs:
cancel-previous:
name: Cancel Previous Runs
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.1
if: ${{github.ref != 'refs/head/main'}}
with:
access_token: ${{ github.token }}
pre-commit:
name: Test pre-commit hooks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- uses: pre-commit/action@v2.0.3
commit-count:
name: Check commit count
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
# We allow at most 5 commits in a branch to ensure our CI doesn't break.
- name: Check commit count in PR
if: always()
shell: bash
run: |
set -x
# $GITHUB_REF is in format `refs/heads/<branch_name>`. We fetch it under
# the name `commit-count` so we can refer to it below.
# Do an unshallow fetch so we retrieve all commits (this is necessary
# because ations/checkout@v2 fetches a shallow copy).
git fetch origin --unshallow $GITHUB_REF:commit-count
git fetch origin main
diff=$(git rev-list --count origin/main...commit-count)
# $GITHUB_REF adds an additional commit to the commit tree, so $diff is
# one too high when executing this as a Github Action.
if (( $diff > 6)); then
echo "ERROR! More than 5 commits in PR -- please squash your commits."
url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request
echo "See $url for help on how to resolve this."
exit 1
fi
test-import:
name: Test import standalone
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: astral-sh/setup-uv@v2
with:
uv-version: "0.3.0"
- name: Install standalone dependencies only
run: |
uv sync --extra all
- name: Test importing Flax
run: |
uv run python -c "import flax"
tests:
name: Run Tests
needs: [cancel-previous, pre-commit, commit-count, test-import]
runs-on: ubuntu-20.04-16core
strategy:
matrix:
python-version: ['3.10', '3.11']
test-type: [doctest, pytest, pytype, mypy]
jax-version: [newest]
exclude:
- test-type: pytype
python-version: '3.10'
- test-type: mypy
python-version: '3.11'
include:
- python-version: '3.10'
test-type: pytest
jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
id: setup_python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Setup uv
uses: astral-sh/setup-uv@v2
with:
version: "0.3.0"
- name: Setup Rust (flaxlib)
uses: actions-rust-lang/setup-rust-toolchain@v1
- name: Install dependencies
run: |
uv sync --extra all --extra testing --extra docs
uv pip install ./flaxlib
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
uv pip install -U jax jaxlib
else
uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Test with ${{ matrix.test-type }}
run: |
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
uv run tests/run_all_tests.sh --only-doctest
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
uv run tests/run_all_tests.sh --only-pytest
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
uv run tests/run_all_tests.sh --only-pytype
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
uv run tests/run_all_tests.sh --only-mypy
else
echo "Unknown test type: ${{ matrix.test-type }}"
exit 1
fi
- name: Upload coverage to Codecov
if: matrix.test-type == 'pytest'
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
file: ./coverage.xml
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/Build"
}'