Skip to content

Commit

Permalink
Update jax-triton metadata to depend on the Triton 3.0.0 release.
Browse files Browse the repository at this point in the history
This seems to work just fine! And it will make the pip install considerably easier.
  • Loading branch information
hawkinsp committed Aug 28, 2024
1 parent bce82cb commit 54ed529
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 47 deletions.
33 changes: 3 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,40 +76,13 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
$ pip install jax-triton
```

Make sure you have a CUDA-compatible `jaxlib` installed.
For example you could run:
```bash
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

### Installation at HEAD

JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
```bash
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
```
This should install compatible versions of JAX and Triton.
You can either use a stable release of `triton` or a nightly release.

JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
Make sure you have a CUDA-compatible `jax` installed. For example you could run:
```bash
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
$ pip install "jax[cuda12]"
```

If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
```bash
$ pip install 'jaxlib @ <link to nightly>'
```
or to install CUDA via pip automatically, you can do:
```bash
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
```


## Development

To develop `jax-triton`, you can clone the repo with:
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version_info__ = (0, 1, 4)
__version_info__ = (0, 2, 0)
__version__ = ".".join(str(v) for v in __version_info__)
21 changes: 5 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,22 @@ name = "jax-triton"
dynamic = ["version"]
description = "JAX + OpenAI Triton integration"
readme = "README.md"
requires-python = ">=3.9,<3.11"
requires-python = ">=3.10"
dependencies = [
"absl-py>=1.4.0",
"jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac",
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
"jax>=0.4.31",
"triton>=3.0",
"setuptools", # triton seems to need this when installing itself.
]

[project.optional-dependencies]
cuda12 = [
"jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda12_pip = [
"jaxlib[cuda12_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11 = [
"jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11_pip = [
"jaxlib[cuda11_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
tests = [
"pytest"
]


[build-system]
requires = ["setuptools", "setuptools-scm", "cmake"]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[tools.setuptools]
Expand Down

0 comments on commit 54ed529

Please sign in to comment.