diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7ed60026..b8c7ee7f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,11 @@ on: pull_request: branches: - main - +permissions: + contents: read # to fetch code +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true jobs: lint: runs-on: ubuntu-latest @@ -17,3 +21,27 @@ jobs: with: python-version: '3.10' - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1 + test: + runs-on: linux-x86-g2-48-l4-4gpu + container: + image: index.docker.io/library/ubuntu@sha256:0e5e4a57c2499249aafc3b40fcd541e9a456aab7296681a3994d631587203f97 # ratchet:ubuntu:22.04 + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + with: + python-version: '3.10' + - name: Setup Released JAX and Torch + run: | + pip install torch + pip install -U "jax[cuda12]" + pip install pytest + - name: Test JAX Triton + run: | + echo "Running JAX Triton GPU Tests" + nvidia-smi + pip install . + # Need newer ml-dtypes because we install newer numpy + pip install --upgrade ml-dtypes + pytest -v --tb=short tests/ +