Skip to content

Commit

Permalink
Add MH variants and MALA (#18)
Browse files Browse the repository at this point in the history
* Add MH variants and MALA
  • Loading branch information
dirmeier authored Feb 14, 2024
1 parent 0d9e5cd commit 9259714
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def run():
params=params,
observable=y_observed,
data=data,
sampler="imh",
)
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
Expand Down
4 changes: 4 additions & 0 deletions examples/slcp_ssnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import argparse
from functools import partial
from timeit import default_timer as timer

import distrax
import haiku as hk
Expand Down Expand Up @@ -189,6 +190,7 @@ def run(use_surjectors):
optimizer = optax.adam(1e-3)

data, params = None, {}
start = timer()
for i in range(5):
data, _ = snl.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
Expand All @@ -200,6 +202,8 @@ def run(use_surjectors):
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
)
end = timer()
print(end - start)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
]
requires-python = ">=3.8"
dependencies = [
"blackjax-nightly>=0.9.6.post127",
"blackjax-nightly>=1.0.0.post17",
"distrax>=0.1.2",
"dm-haiku>=0.0.9",
"optax>=0.1.3",
Expand Down
3 changes: 2 additions & 1 deletion sbijax/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import namedtuple

import chex
import numpy as np
from jax import lax
from jax import numpy as jnp
from jax import random as jr
Expand Down Expand Up @@ -62,7 +63,7 @@ def as_batch_iterator(

def get_batch(idx, idxs=idxs):
start_idx = idx * batch_size
step_size = jnp.minimum(n - start_idx, batch_size)
step_size = np.minimum(n - start_idx, batch_size)
ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size)
batch = {
name: lax.index_take(array, (ret_idx,), axes=(0,))
Expand Down
73 changes: 73 additions & 0 deletions sbijax/mcmc/irmh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import blackjax as bj
import distrax
import jax
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_imh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the indepdendent Metropolis-Hastings
sampler.
Parameters
----------
rng_seq: hk.PRNGSequence
a hk.PRNGSequence
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard
Returns
-------
jnp.ndarrau
a JAX array of dimension n_samples \times n_chains \times len_theta
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)

states = _inference_loop(init_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


def _irmh_proposal_distribution(shape):
def fn(rng_key):
return {"theta": jax.random.normal(rng_key, shape=(shape,))}

return fn


# pylint: disable=missing-function-docstring,no-member
def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.irmh(
lp, _irmh_proposal_distribution(initial_positions.shape[1])
)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
63 changes: 63 additions & 0 deletions sbijax/mcmc/mala.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import blackjax as bj
import distrax
import jax
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_mala(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the MALA sampler.
Parameters
----------
rng_seq: hk.PRNGSequence
a hk.PRNGSequence
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard
Returns
-------
jnp.ndarrau
a JAX array of dimension n_samples \times n_chains \times len_theta
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mala_init(init_key, n_chains, prior, lp)

states = _inference_loop(init_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


# pylint: disable=missing-function-docstring,no-member
def _mala_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.mala(lp, 1.0)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
69 changes: 69 additions & 0 deletions sbijax/mcmc/rmh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import blackjax as bj
import distrax
import jax
from jax import numpy as jnp
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_rmh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
Parameters
----------
rng_seq: hk.PRNGSequence
a hk.PRNGSequence
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard
Returns
-------
jnp.ndarrau
a JAX array of dimension n_samples \times n_chains \times len_theta
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)

states = _inference_loop(init_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


# pylint: disable=missing-function-docstring,no-member
def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.rmh(
lp,
bj.mcmc.random_walk.normal(
jnp.full_like(initial_positions.shape[1], 0.25)
),
)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
24 changes: 24 additions & 0 deletions sbijax/snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice

# pylint: disable=too-many-arguments,unused-argument
from sbijax.mcmc.irmh import sample_with_imh
from sbijax.mcmc.mala import sample_with_mala
from sbijax.mcmc.rmh import sample_with_rmh
from sbijax.nn.early_stopping import EarlyStopping


Expand Down Expand Up @@ -283,6 +286,27 @@ def lp__(theta):
return jax.vmap(_joint_logdensity_fn)(theta)

sampling_fn = sample_with_slice
elif "sampler" in kwargs and kwargs["sampler"] == "rmh":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_rmh
elif "sampler" in kwargs and kwargs["sampler"] == "imh":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_imh
elif "sampler" in kwargs and kwargs["sampler"] == "mala":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_mala
else:

def lp__(theta):
Expand Down

0 comments on commit 9259714

Please sign in to comment.