Skip to content

Commit

Permalink
Forcing a modern version of JAX.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Oct 2, 2024
1 parent 951f857 commit e2aa76f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
]
dependencies = [
"dace>=0.16",
"jax[cpu]>=0.4.24",
"jax[cpu]>=0.4.33",
"numpy>=1.26.0",
]
description = "JAX jit using DaCe (Data Centric Parallel Programming)"
Expand Down
6 changes: 6 additions & 0 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@

from __future__ import annotations

import jax

import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry.

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit


if jax.version._version_as_tuple(jax.__version__) < (0, 4, 33):
raise ImportError(f"Require at least JAX version '0.4.33', but found '{jax.__version__}'.")


__all__ = [
"__author__",
"__copyright__",
Expand Down

0 comments on commit e2aa76f

Please sign in to comment.