diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e6ae812..bb061bd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -41,7 +41,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -61,7 +61,7 @@ jobs: - precommit strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 3c918b9..7dd45f1 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: [3.11] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28f3d5b..3b0bd04 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: args: [--py38-plus] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.2.0 hooks: - id: black args: ["--config=pyproject.toml"] diff --git a/examples/bivariate_gaussian_smcabc.py b/examples/bivariate_gaussian_smcabc.py index 99d2e37..d90b575 100644 --- a/examples/bivariate_gaussian_smcabc.py +++ b/examples/bivariate_gaussian_smcabc.py @@ -1,5 +1,5 @@ """ -Example using ABC +Example using sequential Monte Carlo ABC on a bivariate Gaussian """ import distrax diff --git a/examples/bivariate_gaussian_snasss.py b/examples/bivariate_gaussian_snasss.py new file mode 100644 index 0000000..821bd49 --- /dev/null +++ b/examples/bivariate_gaussian_snasss.py @@ -0,0 +1,123 @@ +""" +Example using sequential neural approximate (slice) summary statistics on a +bivariate Gaussian with repeated dimensions +""" + +import distrax +import haiku as hk +import jax.nn +import matplotlib.pyplot as plt +import optax +import seaborn as sns +from jax import numpy as jnp +from jax import random as jr +from surjectors import ( + Chain, + MaskedAutoregressive, + Permutation, + TransformedDistribution, +) +from surjectors.nn import MADE +from surjectors.util import unstack + +from sbijax import SNASSS +from sbijax.nn.snasss_net import SNASSSNet + +W = jr.normal(jr.PRNGKey(0), (2, 10)) + + +def prior_model_fns(): + p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) + return p.sample, p.log_prob + + +def simulator_fn(seed, theta): + y = theta @ W + y = y + distrax.Normal(jnp.zeros_like(y), 0.1).sample(seed=seed) + return y + + +def make_model(dim): + def _bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _flow(method, **kwargs): + layers = [] + order = jnp.arange(dim) + for i in range(5): + layer = MaskedAutoregressive( + bijector_fn=_bijector_fn, + conditioner=MADE( + dim, + [50, 50, dim * 2], + 2, + w_init=hk.initializers.TruncatedNormal(0.001), + b_init=jnp.zeros, + activation=jax.nn.tanh, + ), + ) + order = order[::-1] + layers.append(layer) + layers.append(Permutation(order, 1)) + chain = Chain(layers) + + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(dim), jnp.ones(dim)), + 1, + ) + td = 
TransformedDistribution(base_distribution, chain) + return td(method, **kwargs) + + td = hk.transform(_flow) + td = hk.without_apply_rng(td) + return td + + +def make_critic(dim): + @hk.without_apply_rng + @hk.transform + def _net(method, **kwargs): + net = SNASSSNet([64, 64, dim], [64, 64, 1], [64, 64, 1]) + return net(method, **kwargs) + + return _net + + +def run(): + y_observed = jnp.array([[2.0, -2.0]]) @ W + + prior_simulator_fn, prior_logdensity_fn = prior_model_fns() + fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn + + estim = SNASSS(fns, make_model(2), make_critic(2)) + optimizer = optax.adam(1e-3) + + data, params = None, {} + for i in range(2): + data, _ = estim.simulate_data_and_possibly_append( + jr.fold_in(jr.PRNGKey(1), i), + params=params, + observable=y_observed, + data=data, + ) + params, _ = estim.fit( + jr.fold_in(jr.PRNGKey(2), i), + data=data, + optimizer=optimizer, + batch_size=100, + ) + + rng_key = jr.PRNGKey(23) + snp_samples, _ = estim.sample_posterior(rng_key, params, y_observed) + fig, axes = plt.subplots(2) + for i, ax in enumerate(axes): + sns.histplot(snp_samples[:, i], color="darkblue", ax=ax) + ax.set_xlim([-3.0, 3.0]) + sns.despine() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + run() diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index 00539b7..a38f145 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -1,5 +1,5 @@ """ -Example using SNL and masked autoregressive flows flows +Example using sequential neural likelihood estimation on a bivariate Gaussian """ from functools import partial @@ -13,12 +13,12 @@ from jax import numpy as jnp from jax import random as jr from surjectors import ( + MADE, Chain, MaskedAutoregressive, Permutation, TransformedDistribution, ) -from surjectors.conditioners import MADE from surjectors.util import unstack from sbijax import SNL diff --git a/examples/bivariate_gaussian_snp.py b/examples/bivariate_gaussian_snp.py index 3cc0230..bc058b8 100644 --- a/examples/bivariate_gaussian_snp.py +++ b/examples/bivariate_gaussian_snp.py @@ -1,5 +1,5 @@ """ -Example using SNP and masked autoregressive flows +Example using sequential posterior estimation on a bivariate Gaussian """ import distrax @@ -10,10 +10,13 @@ import seaborn as sns from jax import numpy as jnp from jax import random as jr -from surjectors import Chain, TransformedDistribution -from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive -from surjectors.bijectors.permutation import Permutation -from surjectors.conditioners import MADE +from surjectors import ( + Chain, + MaskedAutoregressive, + Permutation, + TransformedDistribution, +) +from surjectors.nn import MADE from surjectors.util import unstack from sbijax import SNP diff --git a/examples/slcp_ssnl.py b/examples/slcp_ssnl.py index 6193fed..0be7b35 100644 --- a/examples/slcp_ssnl.py +++ b/examples/slcp_ssnl.py @@ -1,6 +1,5 @@ """ -SLCP example from [1] using SNL and masked autoregressive bijections -or surjections +Example using SSNL on the SLCP experiment """ import argparse diff --git a/pyproject.toml b/pyproject.toml index 99363f7..53b9993 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,16 +15,17 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 
3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "blackjax-nightly>=1.0.0.post17", "distrax>=0.1.2", "dm-haiku>=0.0.9", "optax>=0.1.3", - "surjectors>=0.2.2", + "surjectors>=0.3.0", "tfp-nightly>=0.20.0.dev20230404" ] dynamic = ["version"] @@ -59,8 +60,7 @@ test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax' [tool.black] line-length = 80 -extend-ignore = "E203" -target-version = ['py39'] +target-version = ['py311'] exclude = ''' /( \.eggs @@ -84,7 +84,7 @@ include_trailing_comma = true [tool.flake8] max-line-length = 80 -extend-ignore = ["E203", "W503", "E731"] +extend-ignore = ["E203", "W503", "E731", "E231"] per-file-ignores = [ '__init__.py:F401', ] diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 672daac..589dc6a 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,10 +2,12 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.1.3" +__version__ = "0.1.5" from sbijax.abc.rejection_abc import RejectionABC from sbijax.abc.smc_abc import SMCABC +from sbijax.snass import SNASS +from sbijax.snasss import SNASSS from sbijax.snl import SNL from sbijax.snp import SNP diff --git a/sbijax/_sne_base.py b/sbijax/_sne_base.py index 595cc96..b72af97 100644 --- a/sbijax/_sne_base.py +++ b/sbijax/_sne_base.py @@ -1,5 +1,4 @@ from abc import ABC -from typing import Iterable import chex from jax import numpy as jnp @@ -21,8 +20,6 @@ def __init__(self, model_fns, density_estimator): super().__init__(model_fns) self.model = density_estimator self.n_total_simulations = 0 - self._train_iter: Iterable - self._val_iter: Iterable def simulate_data_and_possibly_append( self, diff --git a/sbijax/abc/smc_abc.py b/sbijax/abc/smc_abc.py index 07a3758..c3a7e03 100644 --- a/sbijax/abc/smc_abc.py +++ b/sbijax/abc/smc_abc.py @@ -208,9 +208,7 @@ def _move( ) n -= len(idxs) - new_particles = new_particles[ - :n_particles, - ] + new_particles = new_particles[:n_particles,] new_log_weights = self._new_log_weights( new_particles, particles, log_weights, cov_chol_factor ) diff --git a/sbijax/generator.py b/sbijax/generator.py index 173d22b..ca2ab1b 100644 --- a/sbijax/generator.py +++ b/sbijax/generator.py @@ -1,7 +1,6 @@ from collections import namedtuple import chex -import numpy as np from jax import lax from jax import numpy as jnp from jax import random as jr @@ -11,13 +10,20 @@ # pylint: disable=missing-class-docstring,too-few-public-methods class DataLoader: - def __init__(self, num_batches, idxs, get_batch): + def __init__(self, num_batches, idxs=None, get_batch=None, batches=None): self.num_batches = num_batches self.idxs = idxs - self.num_samples = len(idxs) + if idxs is not None: + self.num_samples = len(idxs) + else: + self.num_samples = self.num_batches * batches[0]["y"].shape[0] self.get_batch = get_batch + self.batches = batches def __call__(self, idx, idxs=None): + if self.batches is not None: + return self.batches[idx] + if idxs is None: idxs = self.idxs return self.get_batch(idx, idxs) @@ -63,7 +69,7 @@ def as_batch_iterator( def get_batch(idx, idxs=idxs): start_idx = idx * batch_size - step_size = np.minimum(n - start_idx, batch_size) + step_size = jnp.minimum(n - start_idx, batch_size) ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size) batch = { name: lax.index_take(array, (ret_idx,), axes=(0,)) diff --git a/sbijax/nn/snass_net.py b/sbijax/nn/snass_net.py new file mode 100644 index 0000000..47b26a9 --- /dev/null 
+++ b/sbijax/nn/snass_net.py @@ -0,0 +1,50 @@ +from typing import Callable, List + +import haiku as hk +import jax +from jax import numpy as jnp + + +# pylint: disable=missing-function-docstring,missing-class-docstring +class SNASSNet(hk.Module): + """ + A network for SNASS + """ + + def __init__( + self, + summary_net_dimensions: List[int] = None, + critic_net_dimensions: List[int] = None, + summary_net: Callable = None, + critic_net: Callable = None, + ): + super().__init__() + if summary_net_dimensions is not None: + assert critic_net_dimensions is not None + assert summary_net is None + assert critic_net is None + self._summary = hk.nets.MLP( + output_sizes=summary_net_dimensions, activation=jax.nn.relu + ) + self._critic = hk.nets.MLP( + output_sizes=critic_net_dimensions, activation=jax.nn.relu + ) + else: + assert summary_net is not None + assert critic_net is not None + self._summary = summary_net + self._critic = critic_net + + def __call__(self, method, **kwargs): + return getattr(self, method)(**kwargs) + + def forward(self, y, theta): + s = self.summary(y) + c = self.critic(s, theta) + return s, c + + def summary(self, y): + return self._summary(y) + + def critic(self, y, theta): + return self._critic(jnp.concatenate([y, theta], axis=-1)) diff --git a/sbijax/nn/snasss_net.py b/sbijax/nn/snasss_net.py new file mode 100644 index 0000000..d283123 --- /dev/null +++ b/sbijax/nn/snasss_net.py @@ -0,0 +1,57 @@ +from typing import Callable, List + +import haiku as hk +import jax +from jax import numpy as jnp + + +# pylint: disable=missing-function-docstring,missing-class-docstring, +# pylint: too-many-arguments +class SNASSSNet(hk.Module): + def __init__( + self, + summary_net_dimensions: List[int] = None, + sec_summary_net_dimensions: List[int] = None, + critic_net_dimensions: List[int] = None, + summary_net: Callable = None, + sec_summary_net: Callable = None, + critic_net: Callable = None, + ): + super().__init__() + if summary_net_dimensions is not None: + assert critic_net_dimensions is not None + assert summary_net is None + assert critic_net is None + self._summary = hk.nets.MLP( + output_sizes=summary_net_dimensions, activation=jax.nn.relu + ) + self._secondary_summary = hk.nets.MLP( + output_sizes=sec_summary_net_dimensions, activation=jax.nn.relu + ) + self._critic = hk.nets.MLP( + output_sizes=critic_net_dimensions, activation=jax.nn.relu + ) + else: + assert summary_net is not None + assert critic_net is not None + self._summary = summary_net + self._secondary_summary = sec_summary_net + self._critic = critic_net + + def __call__(self, method, **kwargs): + return getattr(self, method)(**kwargs) + + def forward(self, y, theta): + s = self.summary(y) + s2 = self.secondary_summary(s, theta) + c = self.critic(s2, y[:, [0]]) + return s, s2, c + + def summary(self, y): + return self._summary(y) + + def secondary_summary(self, y, theta): + return self._secondary_summary(jnp.concatenate([y, theta], axis=-1)) + + def critic(self, y, theta): + return self._critic(jnp.concatenate([y, theta], axis=-1)) diff --git a/sbijax/snass.py b/sbijax/snass.py new file mode 100644 index 0000000..ac37c9f --- /dev/null +++ b/sbijax/snass.py @@ -0,0 +1,245 @@ +from functools import partial + +import jax +import numpy as np +import optax +from absl import logging +from jax import numpy as jnp +from jax import random as jr + +from sbijax.generator import DataLoader +from sbijax.nn.early_stopping import EarlyStopping +from sbijax.snl import SNL + + +def _jsd_summary_loss(params, rng, apply_fn, **batch): 
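+    # This maximises a Jensen-Shannon lower bound on the mutual information
+    # between the learned summary s(y) and theta: f_pos scores matched
+    # (summary, theta) pairs, f_neg scores randomly shuffled pairs, and
+    # E[-softplus(-f_pos)] - E[softplus(f_neg)] is the JSD critic objective
+    # of Chen et al. (2021); its negation is returned for minimisation.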
+ y, theta = batch["y"], batch["theta"] + m, _ = y.shape + summr = apply_fn(params, method="summary", y=y) + idx_pos = jnp.tile(jnp.arange(m), 10) + idx_neg = jax.vmap(lambda x: jr.permutation(x, m))( + jr.split(rng, 10) + ).reshape(-1) + f_pos = apply_fn(params, method="critic", y=summr, theta=theta) + f_neg = apply_fn( + params, method="critic", y=summr[idx_pos], theta=theta[idx_neg] + ) + a, b = -jax.nn.softplus(-f_pos), jax.nn.softplus(f_neg) + mi = a.mean() - b.mean() + return -mi + + +# pylint: disable=too-many-arguments,unused-argument +class SNASS(SNL): + """Sequential neural approximate summary statistics. + + References: + .. [1] Yanzhi Chen et al. "Neural Approximate Sufficient Statistics for + Implicit Models". ICLR, 2021 + """ + + def __init__(self, model_fns, density_estimator, snass_net): + super().__init__(model_fns, density_estimator) + self.sc_net = snass_net + + # pylint: disable=arguments-differ,too-many-locals + def fit( + self, + rng_key, + data, + optimizer=optax.adam(0.0003), + n_iter=1000, + batch_size=128, + percentage_data_as_validation_set=0.1, + n_early_stopping_patience=10, + **kwargs, + ): + """Fit a SNASS model. + + Args: + rng_seq: a hk.PRNGSequence + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated data + that is used for valitation and early stopping + n_early_stopping_patience: number of iterations of no improvement + of training the flow before stopping optimisation + kwargs: keyword arguments with sampler specific parameters. For + sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: + tuple of parameters and a tuple of the training information + """ + + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set + ) + + snet_params, snet_losses = self._fit_summary_net( + rng_key=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + train_iter = self._as_summary(train_iter, snet_params) + val_iter = self._as_summary(val_iter, snet_params) + + nde_params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + return {"params": nde_params, "s_params": snet_params}, ( + losses, + snet_losses, + ) + + def _as_summary(self, iters, params): + @jax.jit + def as_batch(y, theta): + return { + "y": self.sc_net.apply(params, method="summary", y=y), + "theta": theta, + } + + return DataLoader( + num_batches=iters.num_batches, + batches=[as_batch(**iters(i)) for i in range(iters.num_batches)], + ) + + def _fit_summary_net( + self, + rng_key, + train_iter, + val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + ): + + init_key, rng_key = jr.split(rng_key) + params = self._init_summary_net_params(init_key, **train_iter(0)) + state = optimizer.init(params) + loss_fn = jax.jit( + partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) + ) + + @jax.jit + def step(rng, params, state, 
**batch): + loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros([n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + best_params, best_loss = None, np.inf + logging.info("training summary net") + for i in range(n_iter): + train_loss = 0.0 + epoch_key, rng_key = jr.split(rng_key) + for j in range(train_iter.num_batches): + batch = train_iter(j) + batch_loss, params, state = step( + jr.fold_in(epoch_key, j), params, state, **batch + ) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples + ) + val_key, rng_key = jr.split(rng_key) + validation_loss = self._summary_validation_loss( + params, val_key, val_iter + ) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + if validation_loss < best_loss: + best_loss = validation_loss + best_params = params.copy() + + losses = jnp.vstack(losses)[: (i + 1), :] + return best_params, losses + + def _init_summary_net_params(self, rng_key, **init_data): + params = self.sc_net.init(rng_key, method="forward", **init_data) + return params + + def _summary_validation_loss(self, params, rng_key, val_iter): + loss_fn = jax.jit( + partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) + ) + + def body_fn(i, batch_key): + batch = val_iter(i) + loss = loss_fn(params, batch_key, **batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + losses = 0.0 + for i in range(val_iter.num_batches): + batch_key, rng_key = jr.split(rng_key) + losses += body_fn(i, batch_key) + return losses + + def sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + """Sample from the approximate posterior. + + Args: + rng_key: a random key + params: a pytree of parameter for the model + observable: observation to condition on + n_chains: number of MCMC chains + n_samples: number of samples per chain + n_warmup: number of samples to discard + kwargs: keyword arguments with sampler specific parameters. 
For + sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: + an array of samples from the posterior distribution of dimension + (n_samples \times p) + """ + + observable = jnp.atleast_2d(observable) + summary = self.sc_net.apply( + params["s_params"], method="summary", y=observable + ) + return super().sample_posterior( + rng_key, + params["params"], + summary, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) diff --git a/sbijax/snasss.py b/sbijax/snasss.py new file mode 100644 index 0000000..82242fd --- /dev/null +++ b/sbijax/snasss.py @@ -0,0 +1,269 @@ +from functools import partial + +import jax +import numpy as np +import optax +from absl import logging +from jax import numpy as jnp +from jax import random as jr + +from sbijax.generator import DataLoader +from sbijax.nn.early_stopping import EarlyStopping +from sbijax.snl import SNL + + +def _sample_unit_sphere(rng_key, n, dim): + u = jr.normal(rng_key, (n, dim)) + norm = jnp.linalg.norm(u, ord=2, axis=-1, keepdims=True) + return u / norm + + +# pylint: disable=too-many-locals +def _jsd_summary_loss(params, rng_key, apply_fn, **batch): + y, theta = batch["y"], batch["theta"] + n, p = theta.shape + + phi_key, rng_key = jr.split(rng_key) + summr = apply_fn(params, method="summary", y=y) + summr = jnp.tile(summr, [10, 1]) + theta = jnp.tile(theta, [10, 1]) + + phi = _sample_unit_sphere(phi_key, 10, p) + phi = jnp.repeat(phi, n, axis=0) + + second_summr = apply_fn( + params, method="secondary_summary", y=summr, theta=phi + ) + theta_prime = jnp.sum(theta * phi, axis=1).reshape(-1, 1) + + idx_pos = jnp.tile(jnp.arange(n), 10) + perm_key, rng_key = jr.split(rng_key) + idx_neg = jax.vmap(lambda x: jr.permutation(x, n))( + jr.split(perm_key, 10) + ).reshape(-1) + f_pos = apply_fn(params, method="critic", y=second_summr, theta=theta_prime) + f_neg = apply_fn( + params, + method="critic", + y=second_summr[idx_pos], + theta=theta_prime[idx_neg], + ) + a, b = -jax.nn.softplus(-f_pos), jax.nn.softplus(f_neg) + mi = a.mean() - b.mean() + return -mi + + +# pylint: disable=too-many-arguments,unused-argument +class SNASSS(SNL): + """Sequential neural approximate slice sufficient statistics. + + References: + .. [1] Yanzhi Chen et al. "Is Learning Summary Statistics Necessary for + Likelihood-free Inference". ICML, 2023 + """ + + def __init__(self, model_fns, density_estimator, summary_net): + super().__init__(model_fns, density_estimator) + self.sc_net = summary_net + + # pylint: disable=arguments-differ,too-many-locals + def fit( + self, + rng_key, + data, + optimizer=optax.adam(0.0003), + n_iter=1000, + batch_size=128, + percentage_data_as_validation_set=0.1, + n_early_stopping_patience=10, + **kwargs, + ): + """Fit a SNASSS model. + + Args: + rng_seq: a hk.PRNGSequence + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated data + that is used for valitation and early stopping + n_early_stopping_patience: number of iterations of no improvement + of training the flow before stopping optimisation + kwargs: keyword arguments with sampler specific parameters. 
For + sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: + tuple of parameters and a tuple of the training information + """ + + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set + ) + + snet_params, snet_losses = self._fit_summary_net( + rng_key=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + train_iter = self._as_summary(train_iter, snet_params) + val_iter = self._as_summary(val_iter, snet_params) + + nde_params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + return {"params": nde_params, "s_params": snet_params}, ( + losses, + snet_losses, + ) + + def _as_summary(self, iters, params): + @jax.jit + def as_batch(y, theta): + return { + "y": self.sc_net.apply(params, method="summary", y=y), + "theta": theta, + } + + return DataLoader( + num_batches=iters.num_batches, + batches=[as_batch(**iters(i)) for i in range(iters.num_batches)], + ) + + def _fit_summary_net( + self, + rng_key, + train_iter, + val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + ): + + init_key, rng_key = jr.split(rng_key) + params = self._init_summary_net_params(init_key, **train_iter(0)) + state = optimizer.init(params) + loss_fn = jax.jit( + partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) + ) + + @jax.jit + def step(rng, params, state, **batch): + loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros([n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + best_params, best_loss = None, np.inf + logging.info("training summary net") + for i in range(n_iter): + train_loss = 0.0 + epoch_key, rng_key = jr.split(rng_key) + for j in range(train_iter.num_batches): + batch = train_iter(j) + batch_loss, params, state = step( + jr.fold_in(epoch_key, j), params, state, **batch + ) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples + ) + val_key, rng_key = jr.split(rng_key) + validation_loss = self._summary_validation_loss( + params, val_key, val_iter + ) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + if validation_loss < best_loss: + best_loss = validation_loss + best_params = params.copy() + + losses = jnp.vstack(losses)[: (i + 1), :] + return best_params, losses + + def _init_summary_net_params(self, rng_key, **init_data): + params = self.sc_net.init(rng_key, method="forward", **init_data) + return params + + def _summary_validation_loss(self, params, rng_key, val_iter): + loss_fn = jax.jit( + partial(_jsd_summary_loss, apply_fn=self.sc_net.apply) + ) + + def body_fn(i, batch_key): + batch = val_iter(i) + loss = loss_fn(params, batch_key, **batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + losses = 0.0 + for i in range(val_iter.num_batches): + 
batch_key, rng_key = jr.split(rng_key) + losses += body_fn(i, batch_key) + return losses + + def sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + """Sample from the approximate posterior. + + Args: + rng_key: a random key + params: a pytree of parameter for the model + observable: observation to condition on + n_chains: number of MCMC chains + n_samples: number of samples per chain + n_warmup: number of samples to discard + kwargs: keyword arguments with sampler specific parameters. For + sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: + an array of samples from the posterior distribution of dimension + (n_samples \times p) + """ + + observable = jnp.atleast_2d(observable) + summary = self.sc_net.apply( + params["s_params"], method="summary", y=observable + ) + return super().sample_posterior( + rng_key, + params["params"], + summary, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) diff --git a/sbijax/snl.py b/sbijax/snl.py index 047c0bd..97b20fb 100644 --- a/sbijax/snl.py +++ b/sbijax/snl.py @@ -10,17 +10,15 @@ from sbijax._sne_base import SNE from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice - -# pylint: disable=too-many-arguments,unused-argument from sbijax.mcmc.irmh import sample_with_imh from sbijax.mcmc.mala import sample_with_mala from sbijax.mcmc.rmh import sample_with_rmh from sbijax.nn.early_stopping import EarlyStopping +# pylint: disable=too-many-arguments,unused-argument class SNL(SNE): - """ - Sequential neural likelihood + """Sequential neural likelihood. 
From the Papamakarios paper """ @@ -177,7 +175,7 @@ def simulate_data_and_possibly_append( **kwargs, ): """ - Simulate data from the posteriorand append it to an existing data set + Simulate data from the posterior and append it to an existing data set (if provided) Parameters @@ -266,6 +264,28 @@ def sample_posterior( """ observable = jnp.atleast_2d(observable) + return self._sample_posterior( + rng_key, + params, + observable, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ) + + def _sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + part = partial( self.model.apply, params=params, method="log_prob", y=observable ) diff --git a/sbijax/snl_test.py b/sbijax/snl_test.py index 8f4c704..4789338 100644 --- a/sbijax/snl_test.py +++ b/sbijax/snl_test.py @@ -6,7 +6,7 @@ from jax import numpy as jnp from jax import random as jr from surjectors import Chain, MaskedCoupling, TransformedDistribution -from surjectors.conditioners import mlp_conditioner +from surjectors.nn import make_mlp from surjectors.util import make_alternating_binary_mask from sbijax import SNL @@ -44,8 +44,8 @@ def _flow(method, **kwargs): mask = make_alternating_binary_mask(dim, i % 2 == 0) layer = MaskedCoupling( mask=mask, - bijector=_bijector_fn, - conditioner=mlp_conditioner([8, 8, dim * 2]), + bijector_fn=_bijector_fn, + conditioner=make_mlp([8, 8, dim * 2]), ) layers.append(layer) chain = Chain(layers) diff --git a/sbijax/snp_test.py b/sbijax/snp_test.py index 81629b0..c6e4837 100644 --- a/sbijax/snp_test.py +++ b/sbijax/snp_test.py @@ -4,7 +4,7 @@ import haiku as hk from jax import numpy as jnp from surjectors import Chain, MaskedCoupling, TransformedDistribution -from surjectors.conditioners import mlp_conditioner +from surjectors.nn import make_mlp from surjectors.util import make_alternating_binary_mask from sbijax import SNP @@ -42,8 +42,8 @@ def _flow(method, **kwargs): mask = make_alternating_binary_mask(dim, i % 2 == 0) layer = MaskedCoupling( mask=mask, - bijector=_bijector_fn, - conditioner=mlp_conditioner([8, 8, dim * 2]), + bijector_fn=_bijector_fn, + conditioner=make_mlp([8, 8, dim * 2]), ) layers.append(layer) chain = Chain(layers)
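For readers who want to try the new SNASS estimator directly, the following is a minimal, untested sketch that mirrors examples/bivariate_gaussian_snasss.py. It reuses W, make_model, prior_model_fns and simulator_fn from that example (imported as a sibling module, so it assumes it is run from the examples/ directory), and the summary/critic layer sizes are illustrative choices, not values taken from this patch.

"""
Usage sketch (untested): SNASS on the bivariate Gaussian from
examples/bivariate_gaussian_snasss.py.
"""

import haiku as hk
import optax
from jax import numpy as jnp
from jax import random as jr

# helpers reused from the SNASSS example added in this patch; assumes this
# script sits next to it in examples/
from bivariate_gaussian_snasss import (
    W,
    make_model,
    prior_model_fns,
    simulator_fn,
)

from sbijax import SNASS
from sbijax.nn.snass_net import SNASSNet


def make_critic(dim):
    # SNASSNet only takes summary and critic layer sizes (no secondary summary)
    @hk.without_apply_rng
    @hk.transform
    def _net(method, **kwargs):
        net = SNASSNet([64, 64, dim], [64, 64, 1])
        return net(method, **kwargs)

    return _net


def run():
    y_observed = jnp.array([[2.0, -2.0]]) @ W
    fns = prior_model_fns(), simulator_fn

    estim = SNASS(fns, make_model(2), make_critic(2))
    data, params = None, {}
    for i in range(2):
        data, _ = estim.simulate_data_and_possibly_append(
            jr.fold_in(jr.PRNGKey(1), i),
            params=params,
            observable=y_observed,
            data=data,
        )
        params, _ = estim.fit(
            jr.fold_in(jr.PRNGKey(2), i),
            data=data,
            optimizer=optax.adam(1e-3),
            batch_size=100,
        )

    samples, _ = estim.sample_posterior(jr.PRNGKey(23), params, y_observed)
    return samples


if __name__ == "__main__":
    run()

Relative to the SNASSS example, only the estimator class and the critic network change: the flow over the two-dimensional summary and the training loop stay the same.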