Skip to content

Commit

Permalink
Slice summary methods (#19)
Browse files Browse the repository at this point in the history
* Increment version
* Add SNASSSand SNASSS
* Fix tests after surjectors update
  • Loading branch information
dirmeier authored Feb 21, 2024
1 parent 9259714 commit 9f59b0e
Show file tree
Hide file tree
Showing 20 changed files with 813 additions and 44 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -41,7 +41,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -61,7 +61,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
rev: 24.2.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
Expand Down
2 changes: 1 addition & 1 deletion examples/bivariate_gaussian_smcabc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example using ABC
Example using sequential Monte Carlo ABC on a bivariate Gaussian
"""

import distrax
Expand Down
123 changes: 123 additions & 0 deletions examples/bivariate_gaussian_snasss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Example using sequential neural approximate (slice) summary statistics on a
bivariate Gaussian with repeated dimensions
"""

import distrax
import haiku as hk
import jax.nn
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr
from surjectors import (
Chain,
MaskedAutoregressive,
Permutation,
TransformedDistribution,
)
from surjectors.nn import MADE
from surjectors.util import unstack

from sbijax import SNASSS
from sbijax.nn.snasss_net import SNASSSNet

W = jr.normal(jr.PRNGKey(0), (2, 10))


def prior_model_fns():
p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
return p.sample, p.log_prob


def simulator_fn(seed, theta):
y = theta @ W
y = y + distrax.Normal(jnp.zeros_like(y), 0.1).sample(seed=seed)
return y


def make_model(dim):
def _bijector_fn(params):
means, log_scales = unstack(params, -1)
return distrax.ScalarAffine(means, jnp.exp(log_scales))

def _flow(method, **kwargs):
layers = []
order = jnp.arange(dim)
for i in range(5):
layer = MaskedAutoregressive(
bijector_fn=_bijector_fn,
conditioner=MADE(
dim,
[50, 50, dim * 2],
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
activation=jax.nn.tanh,
),
)
order = order[::-1]
layers.append(layer)
layers.append(Permutation(order, 1))
chain = Chain(layers)

base_distribution = distrax.Independent(
distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
1,
)
td = TransformedDistribution(base_distribution, chain)
return td(method, **kwargs)

td = hk.transform(_flow)
td = hk.without_apply_rng(td)
return td


def make_critic(dim):
@hk.without_apply_rng
@hk.transform
def _net(method, **kwargs):
net = SNASSSNet([64, 64, dim], [64, 64, 1], [64, 64, 1])
return net(method, **kwargs)

return _net


def run():
y_observed = jnp.array([[2.0, -2.0]]) @ W

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SNASSS(fns, make_model(2), make_critic(2))
optimizer = optax.adam(1e-3)

data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.fold_in(jr.PRNGKey(1), i),
params=params,
observable=y_observed,
data=data,
)
params, _ = estim.fit(
jr.fold_in(jr.PRNGKey(2), i),
data=data,
optimizer=optimizer,
batch_size=100,
)

rng_key = jr.PRNGKey(23)
snp_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes):
sns.histplot(snp_samples[:, i], color="darkblue", ax=ax)
ax.set_xlim([-3.0, 3.0])
sns.despine()
plt.tight_layout()
plt.show()


if __name__ == "__main__":
run()
4 changes: 2 additions & 2 deletions examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example using SNL and masked autoregressive flows flows
Example using sequential neural likelihood estimation on a bivariate Gaussian
"""

from functools import partial
Expand All @@ -13,12 +13,12 @@
from jax import numpy as jnp
from jax import random as jr
from surjectors import (
MADE,
Chain,
MaskedAutoregressive,
Permutation,
TransformedDistribution,
)
from surjectors.conditioners import MADE
from surjectors.util import unstack

from sbijax import SNL
Expand Down
13 changes: 8 additions & 5 deletions examples/bivariate_gaussian_snp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example using SNP and masked autoregressive flows
Example using sequential posterior estimation on a bivariate Gaussian
"""

import distrax
Expand All @@ -10,10 +10,13 @@
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr
from surjectors import Chain, TransformedDistribution
from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
from surjectors.bijectors.permutation import Permutation
from surjectors.conditioners import MADE
from surjectors import (
Chain,
MaskedAutoregressive,
Permutation,
TransformedDistribution,
)
from surjectors.nn import MADE
from surjectors.util import unstack

from sbijax import SNP
Expand Down
3 changes: 1 addition & 2 deletions examples/slcp_ssnl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
SLCP example from [1] using SNL and masked autoregressive bijections
or surjections
Example using SSNL on the SLCP experiment
"""

import argparse
Expand Down
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
requires-python = ">=3.8"
requires-python = ">=3.9"
dependencies = [
"blackjax-nightly>=1.0.0.post17",
"distrax>=0.1.2",
"dm-haiku>=0.0.9",
"optax>=0.1.3",
"surjectors>=0.2.2",
"surjectors>=0.3.0",
"tfp-nightly>=0.20.0.dev20230404"
]
dynamic = ["version"]
Expand Down Expand Up @@ -59,8 +60,7 @@ test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'

[tool.black]
line-length = 80
extend-ignore = "E203"
target-version = ['py39']
target-version = ['py311']
exclude = '''
/(
\.eggs
Expand All @@ -84,7 +84,7 @@ include_trailing_comma = true

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731"]
extend-ignore = ["E203", "W503", "E731", "E231"]
per-file-ignores = [
'__init__.py:F401',
]
Expand Down
4 changes: 3 additions & 1 deletion sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.1.3"
__version__ = "0.1.5"


from sbijax.abc.rejection_abc import RejectionABC
from sbijax.abc.smc_abc import SMCABC
from sbijax.snass import SNASS
from sbijax.snasss import SNASSS
from sbijax.snl import SNL
from sbijax.snp import SNP
3 changes: 0 additions & 3 deletions sbijax/_sne_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC
from typing import Iterable

import chex
from jax import numpy as jnp
Expand All @@ -21,8 +20,6 @@ def __init__(self, model_fns, density_estimator):
super().__init__(model_fns)
self.model = density_estimator
self.n_total_simulations = 0
self._train_iter: Iterable
self._val_iter: Iterable

def simulate_data_and_possibly_append(
self,
Expand Down
4 changes: 1 addition & 3 deletions sbijax/abc/smc_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ def _move(
)
n -= len(idxs)

new_particles = new_particles[
:n_particles,
]
new_particles = new_particles[:n_particles,]
new_log_weights = self._new_log_weights(
new_particles, particles, log_weights, cov_chol_factor
)
Expand Down
14 changes: 10 additions & 4 deletions sbijax/generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import namedtuple

import chex
import numpy as np
from jax import lax
from jax import numpy as jnp
from jax import random as jr
Expand All @@ -11,13 +10,20 @@

# pylint: disable=missing-class-docstring,too-few-public-methods
class DataLoader:
def __init__(self, num_batches, idxs, get_batch):
def __init__(self, num_batches, idxs=None, get_batch=None, batches=None):
self.num_batches = num_batches
self.idxs = idxs
self.num_samples = len(idxs)
if idxs is not None:
self.num_samples = len(idxs)
else:
self.num_samples = self.num_batches * batches[0]["y"].shape[0]
self.get_batch = get_batch
self.batches = batches

def __call__(self, idx, idxs=None):
if self.batches is not None:
return self.batches[idx]

if idxs is None:
idxs = self.idxs
return self.get_batch(idx, idxs)
Expand Down Expand Up @@ -63,7 +69,7 @@ def as_batch_iterator(

def get_batch(idx, idxs=idxs):
start_idx = idx * batch_size
step_size = np.minimum(n - start_idx, batch_size)
step_size = jnp.minimum(n - start_idx, batch_size)
ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size)
batch = {
name: lax.index_take(array, (ret_idx,), axes=(0,))
Expand Down
Loading

0 comments on commit 9f59b0e

Please sign in to comment.