Add MH variants and MALA #18

Merged
merged 2 commits on Feb 14, 2024
1 change: 1 addition & 0 deletions examples/bivariate_gaussian_snl.py
@@ -100,6 +100,7 @@ def run():
params=params,
observable=y_observed,
data=data,
sampler="imh",
)
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
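The single added keyword selects the MCMC kernel used when sampling the SNL posterior. Per the dispatch added to sbijax/snl.py further down, "imh", "rmh" and "mala" are now accepted alongside the existing slice and NUTS samplers. A sketch of the same call with a different kernel; everything except the sampler string is unchanged from the example scripts:

data, _ = snl.simulate_data_and_possibly_append(
    jr.fold_in(jr.PRNGKey(12), i),  # per-round key, as in the example scripts
    params=params,
    observable=y_observed,
    data=data,
    sampler="mala",  # or "rmh" / "imh"; omit to keep the default sampler
)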
4 changes: 4 additions & 0 deletions examples/slcp_ssnl.py
@@ -5,6 +5,7 @@

import argparse
from functools import partial
from timeit import default_timer as timer

import distrax
import haiku as hk
@@ -189,6 +190,7 @@ def run(use_surjectors):
optimizer = optax.adam(1e-3)

data, params = None, {}
start = timer()
for i in range(5):
data, _ = snl.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(12), i),
@@ -200,6 +202,8 @@
params, info = snl.fit(
jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
)
end = timer()
print(end - start)

sample_key, rng_key = jr.split(jr.PRNGKey(123))
snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
]
requires-python = ">=3.8"
dependencies = [
"blackjax-nightly>=0.9.6.post127",
"blackjax-nightly>=1.0.0.post17",
"distrax>=0.1.2",
"dm-haiku>=0.0.9",
"optax>=0.1.3",
3 changes: 2 additions & 1 deletion sbijax/generator.py
@@ -1,6 +1,7 @@
from collections import namedtuple

import chex
import numpy as np
from jax import lax
from jax import numpy as jnp
from jax import random as jr
@@ -62,7 +63,7 @@ def as_batch_iterator(

def get_batch(idx, idxs=idxs):
start_idx = idx * batch_size
-        step_size = jnp.minimum(n - start_idx, batch_size)
+        step_size = np.minimum(n - start_idx, batch_size)
ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size)
batch = {
name: lax.index_take(array, (ret_idx,), axes=(0,))
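A note on the one-line change above (my reading of it, not stated in the PR): lax.dynamic_slice_in_dim needs its slice length to be a static value, so computing step_size with np.minimum keeps it a concrete NumPy scalar, whereas a traced jnp.minimum result would raise a concretization error as soon as get_batch is jitted or scanned over. A minimal sketch of the distinction, with made-up sizes:

import numpy as np
from jax import lax
from jax import numpy as jnp

x = jnp.arange(10)
n, batch_size, start_idx = 10, 4, 8

# a slice length computed with NumPy stays a plain scalar, i.e. static
size = np.minimum(n - start_idx, batch_size)
print(lax.dynamic_slice_in_dim(x, start_idx, size))  # [8 9]
# the same expression written with jnp.minimum becomes a traced value under
# jit and is then rejected as a slice length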
73 changes: 73 additions & 0 deletions sbijax/mcmc/irmh.py
@@ -0,0 +1,73 @@
import blackjax as bj
import distrax
import jax
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_imh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the independent Metropolis-Hastings sampler.

Parameters
----------
rng_key: jax.Array
a JAX random key
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard

Returns
-------
jnp.ndarray
a JAX array of shape (n_samples - n_warmup, n_chains, len_theta)
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)

states = _inference_loop(rng_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


def _irmh_proposal_distribution(shape):
def fn(rng_key):
return {"theta": jax.random.normal(rng_key, shape=(shape,))}

return fn


# pylint: disable=missing-function-docstring,no-member
def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.irmh(
lp, _irmh_proposal_distribution(initial_positions.shape[1])
)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
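A minimal, self-contained sketch of driving sample_with_imh directly; the standard-normal target and prior below are toy stand-ins, and the dict-valued position {"theta": ...} mirrors how _mh_init and _irmh_proposal_distribution build theirs:

import distrax
from jax import numpy as jnp
from jax import random as jr

from sbijax.mcmc.irmh import sample_with_imh


def log_density(theta):
    # theta is a kernel position, i.e. a dict {"theta": array of shape (dim,)}
    return jnp.sum(-0.5 * theta["theta"] ** 2)


def prior_sample(seed, sample_shape):
    # callable as prior(seed=..., sample_shape=(n_chains,)), as _mh_init expects
    base = distrax.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2))
    return base.sample(seed=seed, sample_shape=sample_shape)


thetas = sample_with_imh(
    jr.PRNGKey(0),
    log_density,
    prior_sample,
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
)
print(thetas.shape)  # (1000, 4, 2): post-warmup draws x chains x theta dimension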
63 changes: 63 additions & 0 deletions sbijax/mcmc/mala.py
@@ -0,0 +1,63 @@
import blackjax as bj
import distrax
import jax
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_mala(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the MALA sampler.

Parameters
----------
rng_key: jax.Array
a JAX random key
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard

Returns
-------
jnp.ndarray
a JAX array of shape (n_samples - n_warmup, n_chains, len_theta)
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mala_init(init_key, n_chains, prior, lp)

states = _inference_loop(rng_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


# pylint: disable=missing-function-docstring,no-member
def _mala_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.mala(lp, 1.0)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
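Two caveats are implicit in the code above: the log-density must be differentiable in theta, since blackjax's MALA kernel drifts the proposal along its gradient, and the step size is hard-coded to 1.0 in _mala_init rather than exposed through **kwargs. A short sketch of the call, reusing the toy log_density and prior_sample from the irmh example above:

from jax import random as jr

from sbijax.mcmc.mala import sample_with_mala

thetas = sample_with_mala(
    jr.PRNGKey(1), log_density, prior_sample, n_chains=4, n_samples=2_000, n_warmup=1_000
)
print(thetas.shape)  # (1000, 4, 2)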
69 changes: 69 additions & 0 deletions sbijax/mcmc/rmh.py
@@ -0,0 +1,69 @@
import blackjax as bj
import distrax
import jax
from jax import numpy as jnp
from jax import random as jr


# pylint: disable=too-many-arguments,unused-argument
def sample_with_rmh(
rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs
):
"""
Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.

Parameters
----------
rng_key: jax.Array
a JAX random key
lp: Callable
the logdensity you wish to sample from
prior: Callable
a function that returns a prior sample
n_chains: int
number of chains to sample
n_samples: int
number of samples per chain
n_warmup: int
number of samples to discard

Returns
-------
jnp.ndarray
a JAX array of shape (n_samples - n_warmup, n_chains, len_theta)
"""

def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states, _ = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

init_key, rng_key = jr.split(rng_key)
initial_states, kernel = _mh_init(init_key, n_chains, prior, lp)

states = _inference_loop(rng_key, kernel, initial_states, n_samples)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]

return thetas


# pylint: disable=missing-function-docstring,no-member
def _mh_init(rng_key, n_chains, prior: distrax.Distribution, lp):
init_key, rng_key = jr.split(rng_key)
initial_positions = prior(seed=init_key, sample_shape=(n_chains,))
kernel = bj.rmh(
lp,
bj.mcmc.random_walk.normal(
jnp.full(initial_positions.shape[1], 0.25)
),
)
initial_positions = {"theta": initial_positions}
initial_state = jax.vmap(kernel.init)(initial_positions)
return initial_state, kernel.step
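Here the proposal is a Gaussian random-walk step with a per-dimension scale of 0.25, fixed inside _mh_init; changing it currently requires editing that constant, as **kwargs are not used. A short sketch of the call, again reusing the toy target from the irmh example:

from jax import random as jr

from sbijax.mcmc.rmh import sample_with_rmh

thetas = sample_with_rmh(
    jr.PRNGKey(2), log_density, prior_sample, n_chains=4, n_samples=2_000, n_warmup=1_000
)
print(thetas.shape)  # (1000, 4, 2)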
24 changes: 24 additions & 0 deletions sbijax/snl.py
@@ -12,6 +12,9 @@
from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice

# pylint: disable=too-many-arguments,unused-argument
from sbijax.mcmc.irmh import sample_with_imh
from sbijax.mcmc.mala import sample_with_mala
from sbijax.mcmc.rmh import sample_with_rmh
from sbijax.nn.early_stopping import EarlyStopping


@@ -283,6 +286,27 @@ def lp__(theta):
return jax.vmap(_joint_logdensity_fn)(theta)

sampling_fn = sample_with_slice
elif "sampler" in kwargs and kwargs["sampler"] == "rmh":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_rmh
elif "sampler" in kwargs and kwargs["sampler"] == "imh":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_imh
elif "sampler" in kwargs and kwargs["sampler"] == "mala":
kwargs.pop("sampler", None)

def lp__(theta):
return _joint_logdensity_fn(**theta)

sampling_fn = sample_with_mala
else:

def lp__(theta):
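Putting the dispatch together: the sampler keyword is popped from **kwargs at sampling time, the joint log-density is wrapped so the MH-family kernels evaluate it on dict-valued positions, and the chosen sample_with_* function does the rest. A hedged sketch of selecting one of the new kernels from user code; the call mirrors the bivariate Gaussian example, and the chain-related keywords are assumed to be forwarded unchanged to the sample_with_* signatures above:

snl_samples, _ = snl.sample_posterior(
    sample_key,
    params,
    y_observed,
    sampler="rmh",  # one of "imh", "rmh", "mala"; omit for the default sampler
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
)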