diff --git a/pyproject.toml b/pyproject.toml index 5bc0f3c2..45162fd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ license = { file = "LICENSE" } name = "flowjax" readme = "README.md" requires-python = ">=3.10" -version = "13.1.0" +version = "13.1.1" [project.urls] repository = "https://github.com/danielward27/flowjax" diff --git a/tests/test_bijections/test_bijection_utils.py b/tests/test_bijections/test_bijection_utils.py index 53ba8b86..5081c410 100644 --- a/tests/test_bijections/test_bijection_utils.py +++ b/tests/test_bijections/test_bijection_utils.py @@ -1,6 +1,6 @@ -import jax import jax.numpy as jnp import pytest +from equinox import EquinoxRuntimeError from flowjax.bijections import Affine, Partial, Permute @@ -30,5 +30,5 @@ def test_partial(idx, expected): def test_Permute_argcheck(): - with pytest.raises(jax.lib.xla_extension.XlaRuntimeError): + with pytest.raises(EquinoxRuntimeError): Permute(jnp.array([0, 0])) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 0785b0d9..8416030e 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -18,7 +18,7 @@ def test_BijectionReparam(): - with pytest.raises(jax.lib.xla_extension.XlaRuntimeError, match="Exp"): + with pytest.raises(eqx.EquinoxRuntimeError, match="Exp"): BijectionReparam(-jnp.ones(3), Exp()) param = jnp.array([jnp.inf, 1, 2])