From 759e9e5b2d7c2952fed5aef90d40884458495c39 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 12 Aug 2024 16:57:41 +0100 Subject: [PATCH] Pinning JAX to <0.4.31 until sharding parameter bug is solved. --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9f9cd6..9016657 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,11 +23,11 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "chex >= 0.1.6", - "jax >= 0.3.16", - "jaxlib >= 0.3.15", + "chex>=0.1.6", + "jax>=0.3.16,<0.4.31", + "jaxlib>=0.3.15", "ml_dtypes", - "numpy >= 1.22.4" + "numpy>=1.22.4" ] dynamic = ["version"]