diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py
index 9ea6e3f..00539b7 100644
--- a/examples/bivariate_gaussian_snl.py
+++ b/examples/bivariate_gaussian_snl.py
@@ -100,6 +100,7 @@ def run():
             params=params,
             observable=y_observed,
             data=data,
+            sampler="imh",
         )
         params, info = snl.fit(
             jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
diff --git a/examples/slcp_ssnl.py b/examples/slcp_ssnl.py
index 8a29eec..6193fed 100644
--- a/examples/slcp_ssnl.py
+++ b/examples/slcp_ssnl.py
@@ -5,6 +5,7 @@
 import argparse
 from functools import partial
+from timeit import default_timer as timer
 
 import distrax
 import haiku as hk
@@ -189,6 +190,7 @@ def run(use_surjectors):
     optimizer = optax.adam(1e-3)
 
     data, params = None, {}
+    start = timer()
     for i in range(5):
         data, _ = snl.simulate_data_and_possibly_append(
             jr.fold_in(jr.PRNGKey(12), i),
@@ -200,6 +202,8 @@ def run(use_surjectors):
         params, info = snl.fit(
             jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
         )
+    end = timer()
+    print(end - start)
 
     sample_key, rng_key = jr.split(jr.PRNGKey(123))
     snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
diff --git a/pyproject.toml b/pyproject.toml
index 690350c..99363f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 ]
 requires-python = ">=3.8"
 dependencies = [
-    "blackjax-nightly>=0.9.6.post127",
+    "blackjax-nightly>=1.0.0.post17",
     "distrax>=0.1.2",
     "dm-haiku>=0.0.9",
     "optax>=0.1.3",
diff --git a/sbijax/generator.py b/sbijax/generator.py
index 9ba321d..173d22b 100644
--- a/sbijax/generator.py
+++ b/sbijax/generator.py
@@ -1,6 +1,7 @@
 from collections import namedtuple
 
 import chex
+import numpy as np
 from jax import lax
 from jax import numpy as jnp
 from jax import random as jr
@@ -62,7 +63,7 @@ def as_batch_iterator(
 
     def get_batch(idx, idxs=idxs):
         start_idx = idx * batch_size
-        step_size = jnp.minimum(n - start_idx, batch_size)
+        step_size = np.minimum(n - start_idx, batch_size)
         ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size)
         batch = {
             name: lax.index_take(array, (ret_idx,), axes=(0,))
diff --git a/sbijax/mcmc/irmh.py b/sbijax/mcmc/irmh.py
new file mode 100644
index 0000000..05e0c6a
--- /dev/null
+++ b/sbijax/mcmc/irmh.py
@@ -0,0 +1,73 @@
+import blackjax as bj
+import distrax
+import jax
+from jax import random as jr
+
+
+# pylint: disable=too-many-arguments,unused-argument
+def sample_with_imh(
+    rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
+):
+    """
+    Sample from a distribution using the independent Metropolis-Hastings
+    sampler.
+
+    Parameters
+    ----------
+    rng_key: jax.random.PRNGKey
+        a JAX random key
+    lp: Callable
+        the logdensity you wish to sample from
+    prior: Callable
+        a function that returns a prior sample
+    n_chains: int
+        number of chains to sample
+    n_samples: int
+        number of samples per chain
+    n_warmup: int
+        number of samples to discard
+
+    Returns
+    -------
+    jnp.ndarray
+        a JAX array of dimension n_samples x n_chains x len_theta
+    """
+
+    def _inference_loop(rng_key, kernel, initial_state, n_samples):
+        @jax.jit
+        def _step(states, rng_key):
+            keys = jax.random.split(rng_key, n_chains)
+            states, _ = jax.vmap(kernel)(keys, states)
+            return states, states
+
+        sampling_keys = jax.random.split(rng_key, n_samples)
+        _, states = jax.lax.scan(_step, initial_state, sampling_keys)
+        return states
+
+    init_key, rng_key = jr.split(rng_key)
+    initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)
+
+    states = _inference_loop(init_key, kernel, initial_states, n_samples)
+    _ = states.position["theta"].block_until_ready()
+    thetas = states.position["theta"][n_warmup:, :, :]
+
+    return thetas
+
+
+def _irmh_proposal_distribution(shape):
+    def fn(rng_key):
+        return {"theta": jax.random.normal(rng_key, shape=(shape,))}
+
+    return fn
+
+
+# pylint: disable=missing-function-docstring,no-member
+def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
+    init_key, rng_key = jr.split(rng_key)
+    initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
+    kernel = bj.irmh(
+        lp, _irmh_proposal_distribution(initial_positions.shape[1])
+    )
+    initial_positions = {"theta": initial_positions}
+    initial_state = jax.vmap(kernel.init)(initial_positions)
+    return initial_state, kernel.step
diff --git a/sbijax/mcmc/mala.py b/sbijax/mcmc/mala.py
new file mode 100644
index 0000000..7a84e67
--- /dev/null
+++ b/sbijax/mcmc/mala.py
@@ -0,0 +1,63 @@
+import blackjax as bj
+import distrax
+import jax
+from jax import random as jr
+
+
+# pylint: disable=too-many-arguments,unused-argument
+def sample_with_mala(
+    rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
+):
+    """
+    Sample from a distribution using the MALA sampler.
+
+    Parameters
+    ----------
+    rng_key: jax.random.PRNGKey
+        a JAX random key
+    lp: Callable
+        the logdensity you wish to sample from
+    prior: Callable
+        a function that returns a prior sample
+    n_chains: int
+        number of chains to sample
+    n_samples: int
+        number of samples per chain
+    n_warmup: int
+        number of samples to discard
+
+    Returns
+    -------
+    jnp.ndarray
+        a JAX array of dimension n_samples x n_chains x len_theta
+    """
+
+    def _inference_loop(rng_key, kernel, initial_state, n_samples):
+        @jax.jit
+        def _step(states, rng_key):
+            keys = jax.random.split(rng_key, n_chains)
+            states, _ = jax.vmap(kernel)(keys, states)
+            return states, states
+
+        sampling_keys = jax.random.split(rng_key, n_samples)
+        _, states = jax.lax.scan(_step, initial_state, sampling_keys)
+        return states
+
+    init_key, rng_key = jr.split(rng_key)
+    initial_states, kernel = _mala_init(init_key, n_chains, prior, lp)
+
+    states = _inference_loop(init_key, kernel, initial_states, n_samples)
+    _ = states.position["theta"].block_until_ready()
+    thetas = states.position["theta"][n_warmup:, :, :]
+
+    return thetas
+
+
+# pylint: disable=missing-function-docstring,no-member
+def _mala_init(rng_key, n_chains, prior: distrax.Distribution, lp):
+    init_key, rng_key = jr.split(rng_key)
+    initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
+    kernel = bj.mala(lp, 1.0)
+    initial_positions = {"theta": initial_positions}
+    initial_state = jax.vmap(kernel.init)(initial_positions)
+    return initial_state, kernel.step
diff --git a/sbijax/mcmc/rmh.py b/sbijax/mcmc/rmh.py
new file mode 100644
index 0000000..51deb61
--- /dev/null
+++ b/sbijax/mcmc/rmh.py
@@ -0,0 +1,69 @@
+import blackjax as bj
+import distrax
+import jax
+from jax import numpy as jnp
+from jax import random as jr
+
+
+# pylint: disable=too-many-arguments,unused-argument
+def sample_with_rmh(
+    rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
+):
+    """
+    Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
+
+    Parameters
+    ----------
+    rng_key: jax.random.PRNGKey
+        a JAX random key
+    lp: Callable
+        the logdensity you wish to sample from
+    prior: Callable
+        a function that returns a prior sample
+    n_chains: int
+        number of chains to sample
+    n_samples: int
+        number of samples per chain
+    n_warmup: int
+        number of samples to discard
+
+    Returns
+    -------
+    jnp.ndarray
+        a JAX array of dimension n_samples x n_chains x len_theta
+    """
+
+    def _inference_loop(rng_key, kernel, initial_state, n_samples):
+        @jax.jit
+        def _step(states, rng_key):
+            keys = jax.random.split(rng_key, n_chains)
+            states, _ = jax.vmap(kernel)(keys, states)
+            return states, states
+
+        sampling_keys = jax.random.split(rng_key, n_samples)
+        _, states = jax.lax.scan(_step, initial_state, sampling_keys)
+        return states
+
+    init_key, rng_key = jr.split(rng_key)
+    initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)
+
+    states = _inference_loop(init_key, kernel, initial_states, n_samples)
+    _ = states.position["theta"].block_until_ready()
+    thetas = states.position["theta"][n_warmup:, :, :]
+
+    return thetas
+
+
+# pylint: disable=missing-function-docstring,no-member
+def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
+    init_key, rng_key = jr.split(rng_key)
+    initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
+    kernel = bj.rmh(
+        lp,
+        bj.mcmc.random_walk.normal(
+            jnp.full_like(initial_positions.shape[1], 0.25)
+        ),
+    )
+    initial_positions = {"theta": initial_positions}
+    initial_state = jax.vmap(kernel.init)(initial_positions)
+    return initial_state, kernel.step
diff --git a/sbijax/snl.py b/sbijax/snl.py
index f5ef57c..047c0bd 100644
--- a/sbijax/snl.py
+++ b/sbijax/snl.py
@@ -12,6 +12,9 @@
 from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice
 
 # pylint: disable=too-many-arguments,unused-argument
+from sbijax.mcmc.irmh import sample_with_imh
+from sbijax.mcmc.mala import sample_with_mala
+from sbijax.mcmc.rmh import sample_with_rmh
 from sbijax.nn.early_stopping import EarlyStopping
 
 
@@ -283,6 +286,27 @@ def lp__(theta):
                 return jax.vmap(_joint_logdensity_fn)(theta)
 
             sampling_fn = sample_with_slice
+        elif "sampler" in kwargs and kwargs["sampler"] == "rmh":
+            kwargs.pop("sampler", None)
+
+            def lp__(theta):
+                return _joint_logdensity_fn(**theta)
+
+            sampling_fn = sample_with_rmh
+        elif "sampler" in kwargs and kwargs["sampler"] == "imh":
+            kwargs.pop("sampler", None)
+
+            def lp__(theta):
+                return _joint_logdensity_fn(**theta)
+
+            sampling_fn = sample_with_imh
+        elif "sampler" in kwargs and kwargs["sampler"] == "mala":
+            kwargs.pop("sampler", None)
+
+            def lp__(theta):
+                return _joint_logdensity_fn(**theta)
+
+            sampling_fn = sample_with_mala
         else:
 
             def lp__(theta):
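Minimal usage sketch of the new `sampler` keyword added by this diff, mirroring examples/bivariate_gaussian_snl.py above; `fns`, `make_model`, and `y_observed` are assumed to be defined as in that example and are not part of this sketch.

```python
# Sketch: exercising the new sampler selection ("imh", "rmh", or "mala"),
# following examples/bivariate_gaussian_snl.py from this diff.
# Assumptions: `fns`, `make_model`, and `y_observed` come from that example.
import optax
from jax import random as jr

from sbijax import SNL

snl = SNL(fns, make_model(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(5):
    data, _ = snl.simulate_data_and_possibly_append(
        jr.fold_in(jr.PRNGKey(12), i),
        params=params,
        observable=y_observed,
        data=data,
        # any of "imh", "rmh", "mala" dispatches to the new kernels;
        # omitting the keyword keeps the previous default behaviour.
        sampler="imh",
    )
    params, info = snl.fit(
        jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
    )

sample_key, rng_key = jr.split(jr.PRNGKey(123))
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
```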