From c495d8ab0c2684a7b821b9b346c0b91695e59eda Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 17 May 2024 14:58:31 +0200 Subject: [PATCH] Add test for freeze_dims_and_data in JAX backend --- tests/sampling/test_jax.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index dd438546c8d..fcd079dad04 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -35,6 +35,7 @@ from pymc import ImputationWarning from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix +from pymc.model.transform.optimization import freeze_dims_and_data from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, @@ -514,6 +515,24 @@ def test_convergence_warnings(caplog, nuts_sampler): def test_dirichlet_multinomial(): + """Test we can draw from a DM in the JAX backend if the shape is constant.""" dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01) dm_draws = pm.draw(dm, mode="JAX") np.testing.assert_equal(dm_draws, np.eye(3) * 5) + + +def test_dirichlet_multinomial_dims(): + """Test we can draw from a DM with a shape defined by dims in the JAX backend, + after freezing those dims. + """ + with pm.Model(coords={"trial": range(3), "item": range(3)}) as m: + dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item")) + + # JAX does not allow us to JIT a function with dynamic shape + with pytest.raises(TypeError): + pm.draw(dm, mode="JAX") + + # Should be fine after freezing the dims that specify the shape + frozen_dm = freeze_dims_and_data(m)["dm"] + dm_draws = pm.draw(frozen_dm, mode="JAX") + np.testing.assert_equal(dm_draws, np.eye(3) * 5)