From 54ed52912dc68fac4545b88274fc31834d465f24 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 28 Aug 2024 11:25:41 -0700 Subject: [PATCH] Update jax-triton metadata to depend on the Triton 3.0.0 release. This seems to work just fine! And it will make the pip install considerably easier. --- README.md | 33 +++------------------------------ jax_triton/version.py | 2 +- pyproject.toml | 21 +++++---------------- 3 files changed, 9 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 1e53c688..179a4de3 100644 --- a/README.md +++ b/README.md @@ -76,40 +76,13 @@ ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Fl $ pip install jax-triton ``` -Make sure you have a CUDA-compatible `jaxlib` installed. -For example you could run: -```bash -$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -``` - -### Installation at HEAD - -JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run: -```bash -$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git' -``` -This should install compatible versions of JAX and Triton. +You can either use a stable release of `triton` or a nightly release. -JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release: +Make sure you have a CUDA-compatible `jax` installed. For example you could run: ```bash -$ pip install jaxlib[cuda11_pip] -$ # or -$ pip install jaxlib[cuda12_pip] +$ pip install "jax[cuda12]" ``` -If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. -To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via: -```bash -$ pip install 'jaxlib @ ' -``` -or to install CUDA via pip automatically, you can do: -```bash -$ pip install 'jaxlib[cuda11_pip] @ ' -$ # or -$ pip install 'jaxlib[cuda12_pip] @ ' -``` - - ## Development To develop `jax-triton`, you can clone the repo with: diff --git a/jax_triton/version.py b/jax_triton/version.py index 49a09620..38a4bf2d 100644 --- a/jax_triton/version.py +++ b/jax_triton/version.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version_info__ = (0, 1, 4) +__version_info__ = (0, 2, 0) __version__ = ".".join(str(v) for v in __version_info__) diff --git a/pyproject.toml b/pyproject.toml index bb7fc130..0b5f6ca2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,33 +3,22 @@ name = "jax-triton" dynamic = ["version"] description = "JAX + OpenAI Triton integration" readme = "README.md" -requires-python = ">=3.9,<3.11" +requires-python = ">=3.10" dependencies = [ "absl-py>=1.4.0", - "jax @ git+https://github.com/google/jax@a0c1265bbae2c3ec644d6181f23264b4794e9eac", - "triton-nightly @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/2.1.dev20230714011643/triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "jax>=0.4.31", + "triton>=3.0", + "setuptools", # triton seems to need this when installing itself. ] [project.optional-dependencies] -cuda12 = [ - "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl" -] -cuda12_pip = [ - "jaxlib[cuda12_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda12/jaxlib-0.4.14.dev20230727+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl" -] -cuda11 = [ - "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl" -] -cuda11_pip = [ - "jaxlib[cuda11_pip] @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230727+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl" -] tests = [ "pytest" ] [build-system] -requires = ["setuptools", "setuptools-scm", "cmake"] +requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" [tools.setuptools]