diff --git a/sbijax/__init__.py b/sbijax/__init__.py index ff076eb..867e4cf 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.0" +__version__ = "0.1.1" from sbijax.abc.rejection_abc import RejectionABC diff --git a/sbijax/_sne_base.py b/sbijax/_sne_base.py index 400a1a0..e5f10da 100644 --- a/sbijax/_sne_base.py +++ b/sbijax/_sne_base.py @@ -59,8 +59,55 @@ def simulate_data_and_possibly_append( """ observable = jnp.atleast_2d(observable) - sample_key, rng_key = jr.split(rng_key) + new_data, diagnostics = self.simulate_data( + rng_key, + params=params, + observable=observable, + n_simulations=n_simulations, + **kwargs, + ) if data is None: + d_new = new_data + else: + d_new = self.stack_data(data, new_data) + return d_new, diagnostics + + def simulate_data( + self, + rng_key, + *, + params=None, + observable=None, + n_simulations=1000, + **kwargs, + ): + """ + Simulate data from the posterior or prior and append it to an + existing data set (if provided) + + Parameters + ---------- + rng_key: jax.PRNGKey + a random key + params: Optional[pytree] + a dictionary of neural network parameters. If None, will draw from + prior. If parameters given, will draw from amortized posterior + using 'observable; + observable: Optional[jnp.ndarray] + an observation. Needs to be gfiven if posterior draws are desired + n_simulations: int + number of newly simulated data + kwargs: keyword arguments + dictionary of ey value pairs passed to `sample_posterior` + + Returns + ------- + NamedTuple: + returns a NamedTuple of two axis, y and theta + """ + + sample_key, rng_key = jr.split(rng_key) + if params is None or len(params) == 0: diagnostics = None self.n_total_simulations += n_simulations new_thetas = self.prior_sampler_fn( @@ -68,12 +115,17 @@ def simulate_data_and_possibly_append( sample_shape=(n_simulations,), ) else: + if observable is None: + raise ValueError( + "need to have access to 'observable' " + "when sampling from posterior" + ) if "n_samples" not in kwargs: kwargs["n_samples"] = n_simulations new_thetas, diagnostics = self.sample_posterior( rng_key=sample_key, params=params, - observable=observable, + observable=jnp.atleast_2d(observable), **kwargs, ) perm_key, rng_key = jr.split(rng_key) @@ -82,18 +134,33 @@ def simulate_data_and_possibly_append( simulate_key, rng_key = jr.split(rng_key) new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas) + chex.assert_shape(new_thetas, [n_simulations, None]) + chex.assert_shape(new_obs, [n_simulations, None]) + new_data = named_dataset(new_obs, new_thetas) - chex.assert_shape(new_thetas, [n_simulations, None]) - chex.assert_shape(new_data, [n_simulations, None]) + return new_data, diagnostics - if data is None: - d_new = new_data - else: - d_new = named_dataset( - *[jnp.vstack([a, b]) for a, b in zip(data, new_data)] - ) - return d_new, diagnostics + @staticmethod + def stack_data(data, also_data): + """ + Stack two data sets. + + Parameters + ---------- + data: NamedTuple + one data set + also_data: : NamedTuple + + Returns + ------- + NamedTuple: + returns the stack of the two data sets + """ + + return named_dataset( + *[jnp.vstack([a, b]) for a, b in zip(data, also_data)] + ) def as_iterators( self, rng_key, data, batch_size, percentage_data_as_validation_set diff --git a/sbijax/snl_test.py b/sbijax/snl_test.py index 8e8783a..24cdfcb 100644 --- a/sbijax/snl_test.py +++ b/sbijax/snl_test.py @@ -1,8 +1,10 @@ # pylint: skip-file - +import chex import distrax import haiku as hk +import pytest from jax import numpy as jnp +from jax import random as jr from surjectors import Chain, MaskedCoupling, TransformedDistribution from surjectors.conditioners import mlp_conditioner from surjectors.util import make_alternating_binary_mask @@ -88,3 +90,34 @@ def test_snl(): n_samples=200, n_warmup=100, ) + + +def test_stack_data(): + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + snl = SNL(fns, make_model(2)) + n = 100 + data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n) + also_data, _ = snl.simulate_data(jr.PRNGKey(2), n_simulations=n) + stacked_data = snl.stack_data(data, also_data) + + chex.assert_trees_all_equal(data[0], stacked_data[0][:n]) + chex.assert_trees_all_equal(data[1], stacked_data[1][:n]) + chex.assert_trees_all_equal(also_data[0], stacked_data[0][n:]) + chex.assert_trees_all_equal(also_data[1], stacked_data[1][n:]) + + +def test_simulate_data_from_posterior_fail(): + rng_seq = hk.PRNGSequence(0) + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + snl = SNL(fns, make_model(2)) + n = 100 + + data, _ = snl.simulate_data(jr.PRNGKey(1), n_simulations=n) + params, _ = snl.fit(next(rng_seq), data=data, n_iter=10) + with pytest.raises(ValueError): + snl.simulate_data(jr.PRNGKey(2), n_simulations=n, params=params)