From e2aa76fea7064edf2e0c510f26078d54cda022aa Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 2 Oct 2024 14:28:09 +0200 Subject: [PATCH] Forcing a modern version of JAX. --- pyproject.toml | 2 +- src/jace/__init__.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 393ce01..3149170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "dace>=0.16", - "jax[cpu]>=0.4.24", + "jax[cpu]>=0.4.33", "numpy>=1.26.0", ] description = "JAX jit using DaCe (Data Centric Parallel Programming)" diff --git a/src/jace/__init__.py b/src/jace/__init__.py index ca7e0f5..a40ecd6 100644 --- a/src/jace/__init__.py +++ b/src/jace/__init__.py @@ -9,12 +9,18 @@ from __future__ import annotations +import jax + import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry. from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ from .api import grad, jacfwd, jacrev, jit +if jax.version._version_as_tuple(jax.__version__) < (0, 4, 33): + raise ImportError(f"Require at least JAX version '0.4.33', but found '{jax.__version__}'.") + + __all__ = [ "__author__", "__copyright__",