diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7ed60026..0268acf4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,11 @@ on: pull_request: branches: - main - +permissions: + contents: read # to fetch code +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true jobs: lint: runs-on: ubuntu-latest @@ -17,3 +21,18 @@ jobs: with: python-version: '3.10' - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1 + test: + runs-on: linux-x86-g2-48-l4-4gpu + container: + # TODO: change image based on what is needed for these tests + image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11 + steps: + - name: Install Released JAX + run: | + pip install "jax[cuda12]" + - name: Test JAX Triton + run: | + echo "Running JAX Triton GPU Tests" + nvidia-smi + pytest tests/ +