Skip to content

Commit

Permalink
Merge pull request #283 from hawkinsp:triton
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668594620
  • Loading branch information
The jax_triton Authors committed Aug 28, 2024
2 parents bce82cb + 54ed529 commit feb3fc3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 47 deletions.
33 changes: 3 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,40 +76,13 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl
$ pip install jax-triton
```

Make sure you have a CUDA-compatible `jaxlib` installed.
For example you could run:
```bash
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

### Installation at HEAD

JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
```bash
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
```
This should install compatible versions of JAX and Triton.
You can either use a stable release of `triton` or a nightly release.

JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
Make sure you have a CUDA-compatible `jax` installed. For example you could run:
```bash
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
$ pip install "jax[cuda12]"
```

If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
```bash
$ pip install 'jaxlib @ <link to nightly>'
```
or to install CUDA via pip automatically, you can do:
```bash
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
```


## Development

To develop `jax-triton`, you can clone the repo with:
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version_info__ = (0, 1, 4)
__version_info__ = (0, 2, 0)
__version__ = ".".join(str(v) for v in __version_info__)
21 changes: 5 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,22 @@ name = "jax-triton"
dynamic = ["version"]
description = "JAX + OpenAI Triton integration"
readme = "README.md"
requires-python = ">=3.9,<3.11"
requires-python = ">=3.10"
dependencies = [
"absl-py>=1.4.0",
"jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac",
"triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
"jax>=0.4.31",
"triton>=3.0",
"setuptools", # triton seems to need this when installing itself.
]

[project.optional-dependencies]
cuda12 = [
"jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda12_pip = [
"jaxlib[cuda12_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11 = [
"jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
cuda11_pip = [
"jaxlib[cuda11_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"
]
tests = [
"pytest"
]


[build-system]
requires = ["setuptools", "setuptools-scm", "cmake"]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[tools.setuptools]
Expand Down

0 comments on commit feb3fc3

Please sign in to comment.