Skip to content

Commit

Permalink
Introduce GPU testing to JAX Triton
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelHudgins committed Sep 23, 2024
1 parent fa66b2d commit 53623e9
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ on:
pull_request:
branches:
- main

permissions:
contents: read # to fetch code
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
lint:
runs-on: ubuntu-latest
Expand All @@ -17,3 +21,18 @@ jobs:
with:
python-version: '3.10'
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
test:
runs-on: linux-x86-g2-48-l4-4gpu
container:
# TODO: change image based on what is needed for these tests
image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
steps:
- name: Install Released JAX
run: |
pip install "jax[cuda12]"
- name: Test JAX Triton
run: |
echo "Running JAX Triton GPU Tests"
nvidia-smi
pytest tests/

0 comments on commit 53623e9

Please sign in to comment.