From cf8fe67fb16fbdeccf5e72ed87b63362938480ca Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 2 Oct 2023 22:38:16 +0200 Subject: [PATCH 1/5] Refactor --- README.md | 4 + Untitled.ipynb | 442 +++++++++++++++++++++ examples/bivariate_gaussian_smcabc.py | 8 +- examples/bivariate_gaussian_snl.py | 40 +- examples/bivariate_gaussian_snp.py | 36 +- examples/slcp_smcabc.py | 147 ------- examples/slcp_snl_masked_autoregressive.py | 230 ----------- examples/slcp_snl_masked_coupling.py | 208 ---------- pyproject.toml | 2 +- sbijax/_sbi_base.py | 71 +--- sbijax/_sne_base.py | 80 ++-- sbijax/abc/rejection_abc.py | 41 +- sbijax/abc/smc_abc.py | 84 ++-- sbijax/generator.py | 10 +- sbijax/mcmc/nuts.py | 30 +- sbijax/mcmc/slice.py | 27 +- sbijax/nn/__init__.py | 0 sbijax/nn/early_stopping.py | 36 ++ sbijax/snl.py | 321 ++++++++------- sbijax/snl_test.py | 25 +- sbijax/snp.py | 227 ++++++----- sbijax/snp_test.py | 24 +- 22 files changed, 1041 insertions(+), 1052 deletions(-) create mode 100644 Untitled.ipynb delete mode 100644 examples/slcp_smcabc.py delete mode 100644 examples/slcp_snl_masked_autoregressive.py delete mode 100644 examples/slcp_snl_masked_coupling.py create mode 100644 sbijax/nn/__init__.py create mode 100644 sbijax/nn/early_stopping.py diff --git a/README.md b/README.md index 1acce8d..e29b69f 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,15 @@ SbiJAX so far implements - Rejection ABC (`RejectionABC`), - Sequential Monte Carlo ABC (`SMCABC`), - Sequential Neural Likelihood Estimation (`SNL`) +- Surjective Sequential Neural Likelihood Estimation (`SSNL`) +- Sequential Neural Posterior Estimation C (short `SNP`) ## Examples You can find several self-contained examples on how to use the algorithms in `examples`. +## Usage + ## Installation Make sure to have a working `JAX` installation. 
 Depending on whether you want to use CPU/GPU/TPU,
diff --git a/Untitled.ipynb b/Untitled.ipynb
new file mode 100644
index 0000000..9a51e15
--- /dev/null
+++ b/Untitled.ipynb
@@ -0,0 +1,442 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "05a45ed4-97f7-4bba-9581-a37e978180a5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "#os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "00d2daa7-a509-47e1-86c6-752a75b47d0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax.numpy as jnp\n",
+    "import jax.random as jr\n",
+    "from jax import scipy as jsp\n",
+    "\n",
+    "import blackjax as bj\n",
+    "import tensorflow_probability.substrates.jax as tfp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c5c5bdce-fe4c-4fb3-a0c1-15911ba2b40a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "from jax.lib import xla_bridge"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c522c600-31a4-405c-bd12-ebabcc77131d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[CpuDevice(id=0)]\n",
+      "cpu\n",
+      "1\n",
+      "cpu\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(jax.devices())\n",
+    "print(jax.default_backend())\n",
+    "print(jax.device_count())\n",
+    "print(xla_bridge.get_backend().platform)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8198dccd-d852-4dbc-a98f-486b8733a030",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from functools import partial\n",
+    "\n",
+    "import distrax\n",
+    "import haiku as hk\n",
+    "import jax\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import seaborn as sns\n",
+    "from jax import numpy as jnp\n",
+    "from jax import random\n",
+    "from jax import scipy as jsp\n",
+    "from jax import vmap"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4e090efd-dfff-4fd1-b421-71f0166ffe59",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def likelihood_fn(theta):\n",
+    "    mu = jnp.tile(theta[:2], 4)\n",
+    "    s1, s2 = theta[2] ** 2, theta[3] ** 2\n",
+    "    corr = s1 * s2 * jnp.tanh(theta[4])\n",
+    "    cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n",
+    "    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n",
+    "    p = distrax.MultivariateNormalFullCovariance(mu, cov)\n",
+    "    return p"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5bbedba0-346c-4604-ab7f-5318e9d3b15b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lik = likelihood_fn(jr.normal(jr.PRNGKey(123), (5,)))\n",
+    "y = lik.sample(seed=jr.PRNGKey(1), sample_shape=(10,))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "45482e72-6971-48cb-a58c-2a79be31b443",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def likelihood_fn(theta, y):\n",
+    "    mu = jnp.tile(theta[:2], 4)\n",
+    "    s1, s2 = theta[2] ** 2, theta[3] ** 2\n",
+    "    corr = s1 * s2 * jnp.tanh(theta[4])\n",
+    "    cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n",
+    "    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n",
+    "    p = distrax.MultivariateNormalFullCovariance(mu, cov)\n",
+    "    return p.log_prob(y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e07a4d59-80f8-4303-982e-fabf2acf24c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prior = distrax.Independent(\n",
+    "    distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d898ca02-22f3-4102-b667-c2ef7abdd492",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def log_density_fn(theta, y):\n",
+    "    prior_lp = prior.log_prob(theta)\n",
+    "    likelihood_lp = likelihood_fn(theta, y)\n",
+    "\n",
+    "    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)\n",
+    "    return lp\n",
+    "\n",
+    "target_log_prob_fn_partial = partial(log_density_fn, y=y)\n",
+    "target_log_prob_fn = lambda theta: jax.vmap(target_log_prob_fn_partial)(theta)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "55e0fa2c-6aaa-4ab6-922a-1b5f2559b5d1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n_samples = 10000\n",
+    "n_warmup = 5000\n",
+    "n_chains = 4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "263b64ec-8640-4f61-8188-0f0953d53668",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "AttributeError",
+     "evalue": "'functools.partial' object has no attribute 'experimental_default_event_space_bijector'",
+     "output_type": "error",
+     "traceback": [
+      "AttributeError: 'functools.partial' object has no attribute 'experimental_default_event_space_bijector'"
+     ]
+    }
+   ],
+   "source": [
+    "nuts = tfp.mcmc.NoUTurnSampler(\n",
+    "    target_log_prob_fn,\n",
+    "    step_size=0.1,\n",
+    "    max_tree_depth=10,\n",
+    "    max_energy_diff=1000.0,\n",
+    "    unrolled_leapfrog_steps=1,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
+    "    inner_kernel=nuts,\n",
+    "    num_adaptation_steps=int(0.8 * n_warmup),\n",
+    "    target_accept_prob=jnp.asarray(0.75, jnp.float32)\n",
+    ")\n",
+    "\n",
+    "tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(\n",
+    "    nuts,\n",
+    "    initial_running_variance,\n",
+    "    num_estimation_steps=None,\n",
+    "    momentum_distribution_setter_fn=hmc_like_momentum_distribution_setter_fn,\n",
+    "    momentum_distribution_getter_fn=hmc_like_momentum_distribution_getter_fn,\n",
+    "    validate_args=False,\n",
+    "    experimental_shard_axis_names=None,\n",
+    "    name=None\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d826bb25-929a-4273-aab6-d281a54a0844",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "4"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.device_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b59b71ae-963c-43f1-ab00-4d3881354864",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "initial_states = jr.normal(jr.PRNGKey(4), shape=(n_chains, 5))\n",
+    "samples = tfp.mcmc.sample_chain(\n",
+    "    num_results=n_samples - n_warmup,\n",
+    "    current_state=initial_states,\n",
+    "    num_steps_between_results=1,\n",
+    "    kernel=adaptive_sampler,\n",
+    "    num_burnin_steps=n_warmup,\n",
+    "    trace_fn=None,\n",
+    "    seed=jr.PRNGKey(2),\n",
+    ")\n",
+    "samples = samples[n_warmup:, ...]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "99db6c8b-49db-435e-a886-e5968106c512",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "jax.devices()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "e6d8884c-1153-48b3-867a-cac4020e344b",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "KeyboardInterrupt: "
+     ]
+    }
+   ],
+   "source": [
+    "def sfunc(x): \n",
+    "    while True: pass\n",
+    "\n",
+    "jax.pmap(sfunc)(jnp.arange(4))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "9032ad64-c70e-4acb-9541-8ef14a36da2b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "numpy: cpu usage 0.6/8 wall_time:1.7s\n",
+      "vmap: cpu usage 1.8/8 wall_time:1.9s\n",
+      "dot: cpu usage 7.0/8 wall_time:3.7s\n"
+     ]
+    }
+   ],
+   "source": [
+    "import time, os, jax, numpy as np, jax.numpy as jnp\n",
+    "jax.config.update('jax_platform_name', 'cpu')  # ensures we use the CPU\n",
+    "\n",
+    "def timer(name, f, x, shouldBlock=True):\n",
+    "    # warmup\n",
+    "    y = f(x).block_until_ready() if shouldBlock else f(x)\n",
+    "    # running the code\n",
+    "    start_wall = time.perf_counter()\n",
+    "    start_cpu = time.process_time()\n",
+    "    y = f(x).block_until_ready() if shouldBlock else f(x)\n",
+    "    end_wall = time.perf_counter()\n",
+    "    end_cpu = time.process_time()\n",
+    "    # computing the metric and displaying it\n",
+    "    wall_time = end_wall - start_wall\n",
+    "    cpu_time = end_cpu - start_cpu\n",
+    "    cpu_count = os.cpu_count()\n",
+    "    print(f\"{name}: cpu usage {cpu_time/wall_time:.1f}/{cpu_count} wall_time:{wall_time:.1f}s\")\n",
+    "\n",
+    "# test functions\n",
+    "key = jax.random.PRNGKey(0)\n",
+    "x = jax.random.normal(key, shape=(500000000,), dtype=jnp.float64)\n",
+    "x_mat = jax.random.normal(key, shape=(10000,10000), dtype=jnp.float64)\n",
+    "f_numpy = np.cos\n",
+    "f_vmap = jax.jit(jax.vmap(jnp.cos))\n",
+    "f_dot = jax.jit(lambda x: jnp.dot(x,x.T))  # to show that JAX can indeed use all cores\n",
+    "\n",
+    "timer('numpy', f_numpy, x, shouldBlock=False)\n",
+    "timer('vmap', f_vmap, x)\n",
+    "timer('dot', f_dot, x_mat)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "abf84ca0-52d8-4469-82d5-1e9099a1d05b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "8"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "os.cpu_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "1cef2b44-f3ac-43e0-ba1e-604c8fa8bfc4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from jax import local_device_count\n",
+    "print(jax.local_devices()) #1 instead of 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9dd7daeb-449e-4111-9170-f7937b495991",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": 
{ + "display_name": "sbi-dev", + "language": "python", + "name": "sbi-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/bivariate_gaussian_smcabc.py b/examples/bivariate_gaussian_smcabc.py index 1f26d8b..6a1014e 100644 --- a/examples/bivariate_gaussian_smcabc.py +++ b/examples/bivariate_gaussian_smcabc.py @@ -6,7 +6,7 @@ import jax import matplotlib.pyplot as plt import seaborn as sns -from jax import numpy as jnp +from jax import numpy as jnp, random as jr from sbijax import SMCABC @@ -42,8 +42,10 @@ def run(): fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn smc = SMCABC(fns, summary_fn, distance_fn) - smc.fit(23, y_observed) - smc_samples, _ = smc.sample_posterior(10, 1000, 1000, 0.8, 500) + smc_samples, _ = smc.sample_posterior( + jr.PRNGKey(22), y_observed, + 10, 1000, 1000, 0.6, 500 + ) fig, axes = plt.subplots(2) for i in range(2): diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index 0dba1f9..9ea6e3f 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -11,7 +11,7 @@ import optax import seaborn as sns from jax import numpy as jnp -from jax import random +from jax import random as jr from surjectors import ( Chain, MaskedAutoregressive, @@ -31,16 +31,14 @@ def prior_model_fns(): def simulator_fn(seed, theta): - p = distrax.Normal(jnp.zeros_like(theta), 0.1) + p = distrax.Normal(jnp.zeros_like(theta), 1.0) y = theta + p.sample(seed=seed) return y def log_density_fn(theta, y): prior = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1) - likelihood = distrax.MultivariateNormalDiag( - theta, 0.1 * jnp.ones_like(theta) - ) + likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta)) lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y)) return lp @@ -94,26 +92,28 @@ def run(): snl = SNL(fns, make_model(2)) optimizer = optax.adam(1e-3) - params, info = snl.fit( - random.PRNGKey(23), - y_observed, - optimizer=optimizer, - n_rounds=3, - max_n_iter=100, - batch_size=64, - n_early_stopping_patience=5, - sampler="slice", - ) + data, params = None, {} + for i in range(2): + data, _ = snl.simulate_data_and_possibly_append( + jr.fold_in(jr.PRNGKey(12), i), + params=params, + observable=y_observed, + data=data, + ) + params, info = snl.fit( + jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer + ) + + sample_key, rng_key = jr.split(jr.PRNGKey(123)) slice_samples = sample_with_slice( - hk.PRNGSequence(0), log_density, 4, 2000, 1000, prior_simulator_fn + sample_key, log_density, prior_simulator_fn ) slice_samples = slice_samples.reshape(-1, 2) - snl_samples, _ = snl.sample_posterior( - params, 4, 2000, 1000, sampler="slice" - ) - print(f"Took n={snl.n_total_simulations} simulations in total") + sample_key, rng_key = jr.split(rng_key) + snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed) + fig, axes = plt.subplots(2, 2) for i in range(2): sns.histplot( diff --git a/examples/bivariate_gaussian_snp.py b/examples/bivariate_gaussian_snp.py index c5a1e70..3cc0230 100644 --- a/examples/bivariate_gaussian_snp.py +++ b/examples/bivariate_gaussian_snp.py @@ -9,7 +9,7 @@ import optax import seaborn as sns from jax import numpy as jnp -from jax import random +from jax import 
random as jr from surjectors import Chain, TransformedDistribution from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive from surjectors.bijectors.permutation import Permutation @@ -25,7 +25,7 @@ def prior_model_fns(): def simulator_fn(seed, theta): - p = distrax.Normal(jnp.zeros_like(theta), 0.1) + p = distrax.Normal(jnp.zeros_like(theta), 1.0) y = theta + p.sample(seed=seed) return y @@ -72,21 +72,25 @@ def run(): prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - optimizer = optax.adamw(1e-04) snp = SNP(fns, make_model(2)) - params, info = snp.fit( - random.PRNGKey(2), - y_observed, - n_rounds=3, - optimizer=optimizer, - n_early_stopping_patience=10, - batch_size=64, - n_atoms=10, - max_n_iter=100, - ) - - print(f"Took n={snp.n_total_simulations} simulations in total") - snp_samples, _ = snp.sample_posterior(params, 10000) + optimizer = optax.adam(1e-3) + + data, params = None, {} + for i in range(2): + data, _ = snp.simulate_data_and_possibly_append( + jr.fold_in(jr.PRNGKey(1), i), + params=params, + observable=y_observed, + data=data, + ) + params, info = snp.fit( + jr.fold_in(jr.PRNGKey(2), i), + data=data, + optimizer=optimizer, + ) + + rng_key = jr.PRNGKey(23) + snp_samples, _ = snp.sample_posterior(rng_key, params, y_observed) fig, axes = plt.subplots(2) for i, ax in enumerate(axes): sns.histplot(snp_samples[:, i], color="darkblue", ax=ax) diff --git a/examples/slcp_smcabc.py b/examples/slcp_smcabc.py deleted file mode 100644 index f5fc814..0000000 --- a/examples/slcp_smcabc.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -SLCP example from [1] using SMCABC -""" - -from functools import partial - -import distrax -import haiku as hk -import jax -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -from jax import numpy as jnp -from jax import random -from jax import scipy as jsp -from jax import vmap - -from sbijax import SMCABC -from sbijax.mcmc import sample_with_slice - - -def prior_model_fns(): - p = distrax.Independent( - distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1 - ) - return p.sample, p.log_prob - - -def likelihood_fn(theta, y): - mu = jnp.tile(theta[:2], 4) - s1, s2 = theta[2] ** 2, theta[3] ** 2 - corr = s1 * s2 * jnp.tanh(theta[4]) - cov = jnp.array([[s1**2, corr], [corr, s2**2]]) - cov = jsp.linalg.block_diag(*[cov for _ in range(4)]) - p = distrax.MultivariateNormalFullCovariance(mu, cov) - return p.log_prob(y) - - -def simulator_fn(seed, theta): - orig_shape = theta.shape - if theta.ndim == 2: - theta = theta[:, None, :] - us_key, noise_key = random.split(seed) - - def _unpack_params(ps): - m0 = ps[..., [0]] - m1 = ps[..., [1]] - s0 = ps[..., [2]] ** 2 - s1 = ps[..., [3]] ** 2 - r = np.tanh(ps[..., [4]]) - return m0, m1, s0, s1, r - - m0, m1, s0, s1, r = _unpack_params(theta) - us = distrax.Normal(0.0, 1.0).sample( - seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2) - ) - xs = jnp.empty_like(us) - xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) - y = xs.at[:, :, :, 1].set( - s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 - ) - if len(orig_shape) == 2: - y = y.reshape((*theta.shape[:1], 8)) - else: - y = y.reshape((*theta.shape[:2], 8)) - return y - - -def summary_fn(y): - if y.ndim == 2: - y = y[None, ...] 
- sumr = jnp.mean(y, axis=1, keepdims=True) - return sumr - - -def distance_fn(y_simulated, y_observed): - diff = y_simulated - y_observed - dist = jax.vmap(lambda el: jnp.linalg.norm(el))(diff) - return dist - - -def run(): - len_theta = 5 - # this is the thetas used in SNL - # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6]) - y_observed = jnp.array( - [ - [ - -0.9707123, - -2.9461224, - -0.4494722, - -3.4231849, - -0.13285634, - -3.364017, - -0.85367596, - -2.4271638, - ] - ] - ) - prior_simulator_fn, prior_logdensity_fn = prior_model_fns() - fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - - smc = SMCABC(fns, summary_fn, distance_fn) - smc.fit(23, y_observed) - smc_samples, _ = smc.sample_posterior(5, 1000, 10, 0.9, 500) - - def log_density_fn(theta, y): - prior_lp = prior_logdensity_fn(theta) - likelihood_lp = likelihood_fn(theta, y) - - lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp) - return lp - - log_density_partial = partial(log_density_fn, y=y_observed) - log_density = lambda x: vmap(log_density_partial)(x) - - slice_samples = sample_with_slice( - hk.PRNGSequence(12), log_density, 4, 10000, 5000, prior_simulator_fn - ) - slice_samples = slice_samples.reshape(-1, len_theta) - - g = sns.PairGrid(pd.DataFrame(slice_samples)) - g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2) - g.map_diag(plt.hist, color="black") - for ax in g.axes.flatten(): - ax.set_xlim(-5, 5) - ax.set_ylim(-5, 5) - g.fig.set_figheight(5) - g.fig.set_figwidth(5) - plt.show() - - fig, axes = plt.subplots(len_theta, 2) - for i in range(len_theta): - sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0]) - sns.histplot(smc_samples[:, i], color="darkblue", ax=axes[i, 1]) - axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$") - axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$") - for j in range(2): - axes[i, j].set_xlim(-5, 5) - sns.despine() - plt.tight_layout() - plt.show() - - -if __name__ == "__main__": - run() diff --git a/examples/slcp_snl_masked_autoregressive.py b/examples/slcp_snl_masked_autoregressive.py deleted file mode 100644 index 48b1664..0000000 --- a/examples/slcp_snl_masked_autoregressive.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -SLCP example from [1] using SNL and masked coupling bijections or surjections -""" - -import argparse -from functools import partial - -import distrax -import haiku as hk -import jax -import matplotlib.pyplot as plt -import numpy as np -import optax -import pandas as pd -import seaborn as sns -from jax import numpy as jnp -from jax import random -from jax import scipy as jsp -from jax import vmap -from surjectors import Chain, TransformedDistribution -from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive -from surjectors.bijectors.permutation import Permutation -from surjectors.conditioners import MADE, mlp_conditioner -from surjectors.surjectors.affine_masked_autoregressive_inference_funnel import ( # type: ignore # noqa: E501 - AffineMaskedAutoregressiveInferenceFunnel, -) -from surjectors.util import unstack - -from sbijax import SNL -from sbijax.mcmc.slice import sample_with_slice - - -def prior_model_fns(): - p = distrax.Independent( - distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1 - ) - return p.sample, p.log_prob - - -def likelihood_fn(theta, y): - mu = jnp.tile(theta[:2], 4) - s1, s2 = theta[2] ** 2, theta[3] ** 2 - corr = s1 * s2 * jnp.tanh(theta[4]) - cov = jnp.array([[s1**2, corr], [corr, s2**2]]) - cov = jsp.linalg.block_diag(*[cov for _ in range(4)]) - p = 
distrax.MultivariateNormalFullCovariance(mu, cov) - return p.log_prob(y) - - -def simulator_fn(seed, theta): - orig_shape = theta.shape - if theta.ndim == 2: - theta = theta[:, None, :] - us_key, noise_key = random.split(seed) - - def _unpack_params(ps): - m0 = ps[..., [0]] - m1 = ps[..., [1]] - s0 = ps[..., [2]] ** 2 - s1 = ps[..., [3]] ** 2 - r = np.tanh(ps[..., [4]]) - return m0, m1, s0, s1, r - - m0, m1, s0, s1, r = _unpack_params(theta) - us = distrax.Normal(0.0, 1.0).sample( - seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2) - ) - xs = jnp.empty_like(us) - xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) - y = xs.at[:, :, :, 1].set( - s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 - ) - if len(orig_shape) == 2: - y = y.reshape((*theta.shape[:1], 8)) - else: - y = y.reshape((*theta.shape[:2], 8)) - return y - - -def make_model(dim, use_surjectors): - def _bijector_fn(params): - means, log_scales = unstack(params, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) - - def _decoder_fn(n_dim): - decoder_net = mlp_conditioner( - [50, n_dim * 2], - w_init=hk.initializers.TruncatedNormal(stddev=0.001), - ) - - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent( - distrax.Normal(mu, jnp.exp(log_scale)), 1 - ) - - return _fn - - def _flow(method, **kwargs): - layers = [] - n_dimension = dim - order = jnp.arange(n_dimension) - for i in range(5): - if i == 2 and use_surjectors: - n_latent = 6 - layer = AffineMaskedAutoregressiveInferenceFunnel( - n_latent, - _decoder_fn(n_dimension - n_latent), - conditioner=MADE( - n_latent, - [50, n_latent * 2], - 2, - w_init=hk.initializers.TruncatedNormal(0.001), - b_init=jnp.zeros, - activation=jax.nn.tanh, - ), - ) - n_dimension = n_latent - order = order[::-1] - order = order[:n_dimension] - jnp.min(order[:n_dimension]) - else: - layer = MaskedAutoregressive( - bijector_fn=_bijector_fn, - conditioner=MADE( - n_dimension, - [50, n_dimension * 2], - 2, - w_init=hk.initializers.TruncatedNormal(0.001), - b_init=jnp.zeros, - activation=jax.nn.tanh, - ), - ) - order = order[::-1] - layers.append(layer) - layers.append(Permutation(order, 1)) - chain = Chain(layers) - - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)), - reinterpreted_batch_ndims=1, - ) - td = TransformedDistribution(base_distribution, chain) - return td(method, **kwargs) - - td = hk.transform(_flow) - td = hk.without_apply_rng(td) - return td - - -def run(use_surjectors): - len_theta = 5 - # this is the thetas used in SNL - # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6]) - y_observed = jnp.array( - [ - [ - -0.9707123, - -2.9461224, - -0.4494722, - -3.4231849, - -0.13285634, - -3.364017, - -0.85367596, - -2.4271638, - ] - ] - ) - - prior_simulator_fn, prior_fn = prior_model_fns() - fns = (prior_simulator_fn, prior_fn), simulator_fn - - def log_density_fn(theta, y): - prior_lp = prior_fn(theta) - likelihood_lp = likelihood_fn(theta, y) - - lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp) - return lp - - log_density_partial = partial(log_density_fn, y=y_observed) - log_density = lambda x: vmap(log_density_partial)(x) - - model = make_model(y_observed.shape[1], use_surjectors) - snl = SNL(fns, model) - optimizer = optax.adam(1e-3) - params, info = snl.fit( - random.PRNGKey(23), - y_observed, - optimizer, - n_rounds=5, - max_n_iter=100, - batch_size=64, - n_early_stopping_patience=5, - sampler="slice", - ) - - slice_samples = 
sample_with_slice( - hk.PRNGSequence(12), log_density, 4, 5000, 2500, prior_simulator_fn - ) - slice_samples = slice_samples.reshape(-1, len_theta) - snl_samples, _ = snl.sample_posterior(params, 4, 5000, 2500) - - g = sns.PairGrid(pd.DataFrame(slice_samples)) - g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2) - g.map_diag(plt.hist, color="black") - for ax in g.axes.flatten(): - ax.set_xlim(-5, 5) - ax.set_ylim(-5, 5) - g.fig.set_figheight(5) - g.fig.set_figwidth(5) - plt.show() - - fig, axes = plt.subplots(len_theta, 2) - for i in range(len_theta): - sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0]) - sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1]) - axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$") - axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$") - for j in range(2): - axes[i, j].set_xlim(-5, 5) - sns.despine() - plt.tight_layout() - plt.show() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--use-surjectors", action="store_true", default=True) - args = parser.parse_args() - run(args.use_surjectors) diff --git a/examples/slcp_snl_masked_coupling.py b/examples/slcp_snl_masked_coupling.py deleted file mode 100644 index 7bec0d2..0000000 --- a/examples/slcp_snl_masked_coupling.py +++ /dev/null @@ -1,208 +0,0 @@ -""" -SLCP example from [1] using SNL and masked coupling bijections or surjections -""" - -import argparse -from functools import partial - -import distrax -import haiku as hk -import jax -import matplotlib.pyplot as plt -import numpy as np -import optax -import pandas as pd -import seaborn as sns -from jax import numpy as jnp -from jax import random -from jax import scipy as jsp -from surjectors import ( - AffineMaskedCouplingInferenceFunnel, - Chain, - MaskedCoupling, - TransformedDistribution, -) -from surjectors.conditioners import mlp_conditioner -from surjectors.util import make_alternating_binary_mask - -from sbijax import SNL -from sbijax.mcmc import sample_with_slice - - -def prior_model_fns(): - p = distrax.Independent( - distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1 - ) - return p.sample, p.log_prob - - -def likelihood_fn(theta, y): - mu = jnp.tile(theta[:2], 4) - s1, s2 = theta[2] ** 2, theta[3] ** 2 - corr = s1 * s2 * jnp.tanh(theta[4]) - cov = jnp.array([[s1**2, corr], [corr, s2**2]]) - cov = jsp.linalg.block_diag(*[cov for _ in range(4)]) - p = distrax.MultivariateNormalFullCovariance(mu, cov) - return p.log_prob(y) - - -def simulator_fn(seed, theta): - orig_shape = theta.shape - if theta.ndim == 2: - theta = theta[:, None, :] - us_key, noise_key = random.split(seed) - - def _unpack_params(ps): - m0 = ps[..., [0]] - m1 = ps[..., [1]] - s0 = ps[..., [2]] ** 2 - s1 = ps[..., [3]] ** 2 - r = np.tanh(ps[..., [4]]) - return m0, m1, s0, s1, r - - m0, m1, s0, s1, r = _unpack_params(theta) - us = distrax.Normal(0.0, 1.0).sample( - seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2) - ) - xs = jnp.empty_like(us) - xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0) - y = xs.at[:, :, :, 1].set( - s1 * (r * us[:, :, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1 - ) - if len(orig_shape) == 2: - y = y.reshape((*theta.shape[:1], 8)) - else: - y = y.reshape((*theta.shape[:2], 8)) - return y - - -def make_model(dim, use_surjectors): - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) - - def _conditional_fn(n_dim): - decoder_net = mlp_conditioner([32, 
32, n_dim * 2]) - - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent( - distrax.Normal(mu, jnp.exp(log_scale)), 1 - ) - - return _fn - - def _flow(method, **kwargs): - layers = [] - n_dimension = dim - for i in range(5): - mask = make_alternating_binary_mask(n_dimension, i % 2 == 0) - if i == 2 and use_surjectors: - n_latent = 6 - layer = AffineMaskedCouplingInferenceFunnel( - n_latent, - _conditional_fn(n_dimension - n_latent), - mlp_conditioner([32, 32, n_dimension * 2]), - ) - n_dimension = n_latent - else: - layer = MaskedCoupling( - mask=mask, - bijector=_bijector_fn, - conditioner=mlp_conditioner([32, 32, n_dimension * 2]), - ) - layers.append(layer) - chain = Chain(layers) - - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)), - reinterpreted_batch_ndims=1, - ) - td = TransformedDistribution(base_distribution, chain) - return td(method, **kwargs) - - td = hk.transform(_flow) - td = hk.without_apply_rng(td) - return td - - -def run(use_surjectors): - len_theta = 5 - # this is the thetas used in SNL - # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6]) - y_observed = jnp.array( - [ - [ - -0.9707123, - -2.9461224, - -0.4494722, - -3.4231849, - -0.13285634, - -3.364017, - -0.85367596, - -2.4271638, - ] - ] - ) - - prior_simulator_fn, prior_fn = prior_model_fns() - fns = (prior_simulator_fn, prior_fn), simulator_fn - - def log_density_fn(theta, y): - prior_lp = prior_fn(theta) - likelihood_lp = likelihood_fn(theta, y) - - lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp) - return lp - - log_density_partial = partial(log_density_fn, y=y_observed) - log_density = lambda x: jax.vmap(log_density_partial)(x) - - snl = SNL(fns, make_model(y_observed.shape[1], use_surjectors)) - optimizer = optax.adam(1e-3) - params, info = snl.fit( - random.PRNGKey(23), - y_observed, - optimizer, - n_rounds=5, - max_n_iter=100, - batch_size=64, - n_early_stopping_patience=10, - sampler="slice", - ) - - slice_samples = sample_with_slice( - hk.PRNGSequence(12), log_density, 4, 5000, 2500, prior_simulator_fn - ) - slice_samples = slice_samples.reshape(-1, len_theta) - snl_samples, _ = snl.sample_posterior(params, 4, 5000, 2500) - - g = sns.PairGrid(pd.DataFrame(slice_samples)) - g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2) - g.map_diag(plt.hist, color="black") - for ax in g.axes.flatten(): - ax.set_xlim(-5, 5) - ax.set_ylim(-5, 5) - g.fig.set_figheight(5) - g.fig.set_figwidth(5) - plt.show() - - fig, axes = plt.subplots(len_theta, 2) - for i in range(len_theta): - sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0]) - sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1]) - axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$") - axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$") - for j in range(2): - axes[i, j].set_xlim(-5, 5) - sns.despine() - plt.tight_layout() - plt.show() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--use-surjectors", action="store_true", default=True) - args = parser.parse_args() - run(args.use_surjectors) diff --git a/pyproject.toml b/pyproject.toml index 5715787..690350c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,9 @@ dependencies = [ "blackjax-nightly>=0.9.6.post127", "distrax>=0.1.2", "dm-haiku>=0.0.9", - "flax>=0.6.3", "optax>=0.1.3", "surjectors>=0.2.2", + "tfp-nightly>=0.20.0.dev20230404" ] dynamic = ["version"] diff --git 
a/sbijax/_sbi_base.py b/sbijax/_sbi_base.py index 4bd20fa..496f919 100644 --- a/sbijax/_sbi_base.py +++ b/sbijax/_sbi_base.py @@ -1,15 +1,10 @@ import abc -from typing import Optional -import chex -import haiku as hk -from jax import numpy as jnp -from jax import random +from jax import random as jr -from sbijax.generator import named_dataset - -# pylint: disable=too-many-instance-attributes,unused-argument +# pylint: disable=too-many-instance-attributes,unused-argument, +# pylint: disable=too-few-public-methods class SBI(abc.ABC): """ SBI base class @@ -18,66 +13,16 @@ class SBI(abc.ABC): def __init__(self, model_fns): self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0] self.simulator_fn = model_fns[1] - self._len_theta = len(self.prior_sampler_fn(seed=random.PRNGKey(0))) - - self._observed: chex.Array - self._rng_seq: hk.PRNGSequence - self._data: Optional[named_dataset] = None - - @property - def observed(self): - """Get the observation to condition on""" - return self._observed - - @observed.setter - def observed(self, observed): - """Set the observation to condition on""" - self._observed = jnp.atleast_2d(observed) - - @property - def data(self): - """Get the data set""" - return self._data - - @data.setter - def data(self, data): - """Set the data set""" - if not isinstance(data, named_dataset): - raise TypeError("data is not of type 'named_dataset'") - self._data = data - - @property - def rng_seq(self): - """Rng sequence""" - return self._rng_seq - - @rng_seq.setter - def rng_seq(self, rng_seq): - self._rng_seq = rng_seq - - def fit(self, rng_key, observed, **kwargs): - """ - Fit the model - - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - observed: chex.Array - (n \times p)-dimensional array of observations, where `n` is the n - number of samples - """ - - self.rng_seq = hk.PRNGSequence(rng_key) - self.observed = observed + self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123))) @abc.abstractmethod - def sample_posterior(self, **kwargs): + def sample_posterior(self, rng_key, **kwargs): """ Sample from the posterior distribution Parameters ---------- - kwargs - keyword arguments + rng_key: jax.PRNGKey + a random key + kwargs: keyword arguments with sampler specific parameters """ diff --git a/sbijax/_sne_base.py b/sbijax/_sne_base.py index 6171c20..400a1a0 100644 --- a/sbijax/_sne_base.py +++ b/sbijax/_sne_base.py @@ -1,7 +1,9 @@ from abc import ABC from typing import Iterable +import chex from jax import numpy as jnp +from jax import random as jr from sbijax import generator from sbijax._sbi_base import SBI @@ -22,64 +24,84 @@ def __init__(self, model_fns, density_estimator): self._train_iter: Iterable self._val_iter: Iterable - def simulate_new_data_and_append(self, params, n_simulations): + def simulate_data_and_possibly_append( + self, + rng_key, + params, + observable, + data=None, + n_simulations=1000, + **kwargs, + ): """ - Simulate novel data-parameters pairs and append to the - existing data set. 
+ Simulate data from the prior or posterior and append it to an existing data set (if provided) Parameters ---------- + rng_key: jax.PRNGKey + a random key params: pytree - parameter set of the neural network + a dictionary of neural network parameters + observable: jnp.ndarray + an observation + data: NamedTuple + existing data set n_simulations: int - number of data-parameter pairs to draw + number of newly simulated data points + kwargs: keyword arguments + dictionary of key-value pairs passed to `sample_posterior` Returns ------- - Returns the data set. + NamedTuple: + returns a NamedTuple with two fields, y and theta """ - self.data = self._simulate_new_data_and_append( - params, self.data, n_simulations - ) - return self.data - - def _simulate_new_data_and_append( - self, - params, - D, - n_simulations_per_round, - **kwargs, - ): - if D is None: + observable = jnp.atleast_2d(observable) + sample_key, rng_key = jr.split(rng_key) + if data is None: diagnostics = None - self.n_total_simulations += n_simulations_per_round + self.n_total_simulations += n_simulations new_thetas = self.prior_sampler_fn( - seed=next(self._rng_seq), - sample_shape=(n_simulations_per_round,), + seed=sample_key, + sample_shape=(n_simulations,), ) else: + if "n_samples" not in kwargs: + kwargs["n_samples"] = n_simulations new_thetas, diagnostics = self.sample_posterior( + rng_key=sample_key, params=params, - n_simulations_per_round=n_simulations_per_round, + observable=observable, **kwargs, ) + perm_key, rng_key = jr.split(rng_key) + new_thetas = jr.permutation(perm_key, new_thetas) + new_thetas = new_thetas[:n_simulations, :] - new_obs = self.simulator_fn(seed=next(self._rng_seq), theta=new_thetas) + simulate_key, rng_key = jr.split(rng_key) + new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas) new_data = named_dataset(new_obs, new_thetas) - if D is None: + + chex.assert_shape(new_thetas, [n_simulations, None]) + chex.assert_shape(new_data, [n_simulations, None]) + + if data is None: d_new = new_data else: d_new = named_dataset( - *[jnp.vstack([a, b]) for a, b in zip(D, new_data)] + *[jnp.vstack([a, b]) for a, b in zip(data, new_data)] ) return d_new, diagnostics - def as_iterators(self, D, batch_size, percentage_data_as_validation_set): + def as_iterators( + self, rng_key, data, batch_size, percentage_data_as_validation_set + ): """Convert the data set to an iterable for training""" return generator.as_batch_iterators( - next(self._rng_seq), - D, + rng_key, + data, batch_size, 1.0 - percentage_data_as_validation_set, True,
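A minimal usage sketch of the round-based interface introduced here, mirroring the updated sbijax/snl_test.py further below (an `SNL` object `snl` and an observation `y_observed` are assumed to exist):

import jax.random as jr

rng_key = jr.PRNGKey(0)
data, params = None, {}
for n in range(2):
    simulate_key, fit_key, rng_key = jr.split(rng_key, 3)
    # the first round simulates from the prior (data is None); later rounds
    # draw new parameters from the current posterior approximation
    data, diagnostics = snl.simulate_data_and_possibly_append(
        simulate_key,
        params=params,
        observable=y_observed,
        data=data,
        n_simulations=1000,
    )
    params, losses = snl.fit(fit_key, data=data)

diff --git a/sbijax/abc/rejection_abc.py index bda54de..f940022 100644 --- a/sbijax/abc/rejection_abc.py +++ b/sbijax/abc/rejection_abc.py @@ -1,11 +1,11 @@ -import chex from jax import numpy as jnp -from jax import random +from jax import random as jr from sbijax._sbi_base import SBI -# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-instance-attributes,too-many-arguments +# pylint: disable=too-many-locals,too-few-public-methods class RejectionABC(SBI): """ Sisson et al.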
- Handbook of approximate Bayesian computation @@ -17,19 +17,27 @@ def __init__(self, model_fns, summary_fn, kernel_fn): super().__init__(model_fns) self.kernel_fn = kernel_fn self.summary_fn = summary_fn - self.summarized_observed: chex.Array - - def fit(self, rng_key, observed, **kwargs): - super().fit(rng_key, observed) - self.summarized_observed = self.summary_fn(self.observed) # pylint: disable=arguments-differ - def sample_posterior(self, n_samples, n_simulations_per_theta, K, h): + def sample_posterior( + self, + rng_key, + observable, + n_samples, + n_simulations_per_theta, + K, + h, + **kwargs, + ): """ Sample from the approximate posterior Parameters ---------- + rng_key: jax.PRNGKey + a random key + observable: jnp.Array + observation to condition on n_samples: int number of samples to draw for each parameter n_simulations_per_theta: int @@ -46,23 +54,26 @@ def sample_posterior(self, n_samples, n_simulations_per_theta, K, h): (n_samples \times p) """ + observable = jnp.atleast_2d(observable) + thetas = None n = n_samples K = jnp.maximum( K, self.kernel_fn(jnp.zeros((1, 2, 2)), jnp.zeros((1, 2, 2)))[0] ) while n > 0: + p_key, simulate_key, prior_key, rng_key = jr.split(rng_key) n_sim = jnp.minimum(n, 1000) - ps = self.prior_sampler_fn( - seed=next(self._rng_seq), sample_shape=(n_sim,) - ) + ps = self.prior_sampler_fn(seed=prior_key, sample_shape=(n_sim,)) ys = self.simulator_fn( - seed=next(self._rng_seq), + seed=simulate_key, theta=jnp.tile(ps, [n_simulations_per_theta, 1, 1]), ) ys = jnp.swapaxes(ys, 1, 0) - k = self.kernel_fn(self.summary_fn(ys), self.summarized_observed, h) - p = random.uniform(next(self._rng_seq), shape=(len(k),)) + k = self.kernel_fn( + self.summary_fn(ys), self.summary_fn(observable), h + ) + p = jr.uniform(p_key, shape=(len(k),)) mr = k / K idxs = jnp.where(p < mr)[0] if thetas is None: diff --git a/sbijax/abc/smc_abc.py b/sbijax/abc/smc_abc.py index 08da1d8..07a3758 100644 --- a/sbijax/abc/smc_abc.py +++ b/sbijax/abc/smc_abc.py @@ -6,12 +6,14 @@ from blackjax.smc import resampling from blackjax.smc.ess import ess from jax import numpy as jnp +from jax import random as jr from jax import scipy as jsp from sbijax._sbi_base import SBI -# pylint: disable=arguments-differ,too-many-function-args +# pylint: disable=arguments-differ,too-many-function-args,too-many-locals +# pylint: disable=too-few-public-methods class SMCABC(SBI): """ Sisson et al. 
- Handbook of approximate Bayesian computation @@ -26,13 +28,11 @@ def __init__(self, model_fns, summary_fn, distance_fn): self.summarized_observed: chex.Array self.n_total_simulations = 0 - def fit(self, rng_key, observed): - super().fit(rng_key, observed) - self.summarized_observed = self.summary_fn(self.observed) - # pylint: disable=too-many-arguments,arguments-differ def sample_posterior( self, + rng_key, + observable, n_rounds, n_particles, n_simulations_per_theta, @@ -64,15 +64,21 @@ def sample_posterior( an array of samples from the posterior distribution of dimension (n_samples \times p) """ + observable = jnp.atleast_2d(observable) - all_particles, all_n_simulations = [], [] + init_key, rng_key = jr.split(rng_key) particles, log_weights, epsilon = self._init_particles( - n_particles, n_simulations_per_theta + init_key, observable, n_particles, n_simulations_per_theta ) - for _ in range(n_rounds): + all_particles, all_n_simulations = [], [] + for n in range(n_rounds): epsilon *= eps_step + rng_key = jr.fold_in(rng_key, n) + particle_key, rng_key = jr.split(rng_key) particles, log_weights = self._move( + particle_key, + observable, n_particles, particles, log_weights, @@ -82,8 +88,9 @@ def sample_posterior( ) curr_ess = ess(log_weights) if curr_ess < ess_min: + resample_key, rng_key = jr.split(rng_key) particles, log_weights = self._resample( - particles, log_weights, particles.shape[0] + resample_key, particles, log_weights, particles.shape[0] ) all_particles.append(particles.copy()) all_n_simulations.append(self.n_total_simulations) @@ -95,45 +102,50 @@ def _chol_factor(self, particles, cov_scale): chol = jnp.linalg.cholesky(jnp.cov(particles.T) * cov_scale) return chol - def _init_particles(self, n_particles, n_simulations_per_theta): + def _init_particles( + self, rng_key, observable, n_particles, n_simulations_per_theta + ): self.n_total_simulations += n_particles * 10 + + init_key, rng_key = jr.split(rng_key) particles = self.prior_sampler_fn( - seed=next(self._rng_seq), sample_shape=(n_particles * 10,) + seed=init_key, sample_shape=(n_particles * 10,) ) thetas = jnp.tile(particles, [n_simulations_per_theta, 1, 1]) chex.assert_axis_dimension(thetas, 0, n_simulations_per_theta) chex.assert_axis_dimension(thetas, 1, n_particles * 10) - ys = self.simulator_fn( - seed=next(self._rng_seq), - theta=thetas, - ) + simulator_key, rng_key = jr.split(rng_key) + ys = self.simulator_fn(seed=simulator_key, theta=thetas) ys = jnp.swapaxes(ys, 1, 0) chex.assert_axis_dimension(ys, 0, n_particles * 10) chex.assert_axis_dimension(ys, 1, n_simulations_per_theta) summary_statistics = self.summary_fn(ys) distances = self.distance_fn( - summary_statistics, self.summarized_observed + summary_statistics, self.summary_fn(observable) ) - sort_idx = jnp.argsort(distances) + sort_idx = jnp.argsort(distances) particles = particles[sort_idx][:n_particles] log_weights = -jnp.log(jnp.full(n_particles, n_particles)) initial_epsilon = distances[-1] return particles, log_weights, initial_epsilon - def _sample_candidates(self, particles, log_weights, n, cov_chol_factor): + def _sample_candidates( + self, rng_key, particles, log_weights, n, cov_chol_factor + ): n_sim = jnp.maximum(jnp.minimum(n, 1000), 100) self.n_total_simulations += n_sim + sample_key, perturb_key, rng_key = jr.split(rng_key, 3) new_candidate_particles, _ = self._resample( - particles, log_weights, n_sim + sample_key, particles, log_weights, n_sim ) new_candidate_particles = self._perturb( - new_candidate_particles, cov_chol_factor + 
perturb_key, new_candidate_particles, cov_chol_factor ) cand_lps = self.prior_log_density_fn(new_candidate_particles) new_candidate_particles = new_candidate_particles[ @@ -142,22 +154,28 @@ def _sample_candidates(self, particles, log_weights, n, cov_chol_factor): return new_candidate_particles def _simulate_and_distance( - self, new_candidate_particles, n_simulations_per_theta + self, + rng_key, + observable, + new_candidate_particles, + n_simulations_per_theta, ): ys = self.simulator_fn( - seed=next(self._rng_seq), + seed=rng_key, theta=jnp.tile( new_candidate_particles, [n_simulations_per_theta, 1, 1] ), ) ys = jnp.swapaxes(ys, 1, 0) summary_statistics = self.summary_fn(ys) - ds = self.distance_fn(summary_statistics, self.summarized_observed) + ds = self.distance_fn(summary_statistics, self.summary_fn(observable)) return ds # pylint: disable=too-many-arguments def _move( self, + rng_key, + observable, n_particles, particles, log_weights, @@ -169,11 +187,15 @@ def _move( cov_chol_factor = self._chol_factor(particles, cov_scale) n = n_particles while n > 0: + sample_key, simulate_key, rng_key = jr.split(rng_key, 3) new_candidate_particles = self._sample_candidates( - particles, log_weights, n, cov_chol_factor + sample_key, particles, log_weights, n, cov_chol_factor ) ds = self._simulate_and_distance( - new_candidate_particles, n_simulations_per_theta + simulate_key, + observable, + new_candidate_particles, + n_simulations_per_theta, ) idxs = jnp.where(ds < epsilon)[0] @@ -195,10 +217,8 @@ def _move( return new_particles, new_log_weights - def _resample(self, particles, log_weights, n_samples): - idxs = resampling.multinomial( - next(self._rng_seq), jnp.exp(log_weights), n_samples - ) + def _resample(self, rng_key, particles, log_weights, n_samples): + idxs = resampling.multinomial(rng_key, jnp.exp(log_weights), n_samples) particles = particles[idxs] return particles, -jnp.log(jnp.full(n_samples, n_samples)) @@ -223,7 +243,5 @@ def _particle_weight(partcl): def _kernel(self, mus, cov_chol_factor): return distrax.MultivariateNormalTri(loc=mus, scale_tri=cov_chol_factor) - def _perturb(self, mus, cov_chol_factor): - return self._kernel(mus, cov_chol_factor).sample( - seed=next(self._rng_seq) - ) + def _perturb(self, rng_key, mus, cov_chol_factor): + return self._kernel(mus, cov_chol_factor).sample(seed=rng_key) diff --git a/sbijax/generator.py index 0b04990..0f0c7bc 100644 --- a/sbijax/generator.py +++ b/sbijax/generator.py @@ -2,8 +2,7 @@ import chex from jax import lax -from jax import numpy as jnp -from jax import random +from jax import numpy as jnp, random as jr named_dataset = namedtuple("named_dataset", "y theta") @@ -13,6 +12,7 @@ class DataLoader: def __init__(self, num_batches, idxs, get_batch): self.num_batches = num_batches self.idxs = idxs + self.num_samples = len(idxs) self.get_batch = get_batch def __call__(self, idx, idxs=None): @@ -29,12 +29,12 @@ def as_batch_iterators( n_train = int(n * split) if shuffle: - idxs = random.permutation(rng_key, jnp.arange(n)) + idxs = jr.permutation(rng_key, jnp.arange(n)) data = named_dataset(*[el[idxs] for _, el in enumerate(data)]) y_train = named_dataset(*[el[:n_train] for el in data]) y_val = named_dataset(*[el[n_train:] for el in data]) - train_rng_key, val_rng_key = random.split(rng_key) + train_rng_key, val_rng_key = jr.split(rng_key) train_itr = as_batch_iterator(train_rng_key, y_train, batch_size, shuffle) val_itr = as_batch_iterator(val_rng_key, y_val, batch_size, shuffle)
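A minimal sketch of the batch-iterator protocol these changes extend; the new `num_samples` attribute drives the loss weighting in snl.py and snp.py below (the `y` and `theta` arrays here are stand-ins):

import jax.numpy as jnp
import jax.random as jr
from sbijax.generator import as_batch_iterators, named_dataset

y = jnp.zeros((1000, 8))      # stand-in simulated observations
theta = jnp.zeros((1000, 5))  # stand-in parameter draws
data = named_dataset(y, theta)
train_iter, val_iter = as_batch_iterators(
    jr.PRNGKey(1), data, 128, 0.9, True  # batch size, train split, shuffle
)
for j in range(train_iter.num_batches):
    batch = train_iter(j)  # a dict with keys "y" and "theta"
    # the last batch of an epoch can be smaller than the batch size,
    # hence per-batch losses are weighted by this fraction
    weight = batch["y"].shape[0] / train_iter.num_samples

@@ -57,7 +57,7 @@ def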
as_batch_iterator( idxs = jnp.arange(n) if shuffle: - idxs = random.permutation(rng_key, idxs) + idxs = jr.permutation(rng_key, idxs) def get_batch(idx, idxs=idxs): start_idx = idx * batch_size diff --git a/sbijax/mcmc/nuts.py index b1f3484..2781ec1 100644 --- a/sbijax/mcmc/nuts.py +++ b/sbijax/mcmc/nuts.py @@ -1,12 +1,13 @@ import blackjax as bj import distrax import jax -from jax import numpy as jnp -from jax import random +from jax import random as jr -# pylint: disable=too-many-arguments -def sample_with_nuts(rng_seq, lp, len_theta, n_chains, n_samples, n_warmup): +# pylint: disable=too-many-arguments,unused-argument +def sample_with_nuts( + rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs +): """ Sample from a distribution using the No-U-Turn sampler. @@ -16,8 +17,8 @@ a hk.PRNGSequence lp: Callable the logdensity you wish to sample from - len_theta: int - the number of parameters to sample + prior: Callable + a function that returns a prior sample n_chains: int number of chains to sample n_samples: int @@ -42,22 +43,23 @@ def _step(states, rng_key): _, states = jax.lax.scan(_step, initial_state, sampling_keys) return states - initial_states, kernel = _nuts_init(rng_seq, len_theta, n_chains, lp) - states = _inference_loop(next(rng_seq), kernel, initial_states, n_samples) + init_key, rng_key = jr.split(rng_key) + initial_states, kernel = _nuts_init(init_key, n_chains, prior, lp) + + states = _inference_loop(rng_key, kernel, initial_states, n_samples) _ = states.position["theta"].block_until_ready() thetas = states.position["theta"][n_warmup:, :, :] + return thetas # pylint: disable=missing-function-docstring -def _nuts_init(rng_seq, len_theta, n_chains, lp): - initial_positions = distrax.MultivariateNormalDiag( - jnp.zeros(len_theta), - jnp.ones(len_theta), - ).sample(seed=next(rng_seq), sample_shape=(n_chains,)) +def _nuts_init(rng_key, n_chains, prior: distrax.Distribution, lp): + init_key, rng_key = jr.split(rng_key) + initial_positions = prior(seed=init_key, sample_shape=(n_chains,)) initial_positions = {"theta": initial_positions} - init_keys = random.split(next(rng_seq), n_chains) + init_keys = jr.split(rng_key, n_chains) warmup = bj.window_adaptation(bj.nuts, lp) initial_states, kernel_params = jax.vmap( lambda seed, param: warmup.run(seed, param)[0]
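A minimal sketch of the new keyword-only sampler interface with a toy standard-normal target; positions are dicts with a "theta" leaf, matching `_nuts_init` above, and `prior.sample` supplies the initial positions:

import distrax
import jax.random as jr
from jax import numpy as jnp
from sbijax.mcmc import sample_with_nuts

prior = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)

def lp(position):
    # position is a dict with a "theta" leaf
    return jnp.sum(distrax.Normal(0.0, 1.0).log_prob(position["theta"]))

# returns an array of shape (n_samples - n_warmup, n_chains, 2)
thetas = sample_with_nuts(
    jr.PRNGKey(2), lp, prior.sample, n_chains=4, n_samples=2_000, n_warmup=1_000
)

diff --git a/sbijax/mcmc/slice.py index 905e3c5..82f914d 100644 --- a/sbijax/mcmc/slice.py +++ b/sbijax/mcmc/slice.py @@ -1,15 +1,17 @@ import distrax import tensorflow_probability.substrates.jax as tfp +from jax import random as jr # pylint: disable=too-many-arguments,unused-argument def sample_with_slice( - rng_seq, + rng_key, lp, - n_chains, - n_samples, - n_warmup, prior, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, n_thin=2, n_doubling=5, step_size=1, @@ -24,6 +26,8 @@ a hk.PRNGSequence lp: Callable the logdensity you wish to sample from + prior: Callable + a function that returns a prior sample n_chains: int number of chains to sample n_samples: int @@ -37,9 +41,12 @@ a JAX array of dimension n_samples \times n_chains \times len_theta """ - initial_states = _slice_init(rng_seq, n_chains, prior) + init_key, rng_key = jr.split(rng_key) + initial_states = _slice_init(init_key, n_chains, prior) + + sample_key, rng_key = jr.split(rng_key) samples = tfp.mcmc.sample_chain( - num_results=n_samples -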
n_warmup, current_state=initial_states, num_steps_between_results=n_thin, kernel=tfp.mcmc.SliceSampler( @@ -47,14 +54,12 @@ ), num_burnin_steps=n_warmup, trace_fn=None, - seed=next(rng_seq), + seed=sample_key, ) - samples = samples[n_warmup:, ...] return samples # pylint: disable=missing-function-docstring -def _slice_init(rng_seq, n_chains, prior: distrax.Distribution): - initial_positions = prior(seed=next(rng_seq), sample_shape=(n_chains,)) - +def _slice_init(rng_key, n_chains, prior: distrax.Distribution): + initial_positions = prior(seed=rng_key, sample_shape=(n_chains,)) return initial_positions diff --git a/sbijax/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sbijax/nn/early_stopping.py new file mode 100644 index 0000000..57f8f3c --- /dev/null +++ b/sbijax/nn/early_stopping.py @@ -0,0 +1,36 @@ +import dataclasses +import math + + +# pylint: disable=missing-function-docstring +@dataclasses.dataclass +class EarlyStopping: + """ + Early stopping of neural network training + """ + + min_delta: float = 0 + patience: int = 0 + best_metric: float = float("inf") + patience_count: int = 0 + should_stop: bool = False + + def reset(self): + self.best_metric = float("inf") + self.patience_count = 0 + self.should_stop = False + return self + + def update(self, metric): + if ( + math.isinf(self.best_metric) + or self.best_metric - metric > self.min_delta + ): + self.best_metric = metric + self.patience_count = 0 + return True, self + + should_stop = self.patience_count >= self.patience or self.should_stop + self.should_stop = should_stop + self.patience_count = self.patience_count + 1 + return False, self
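A minimal usage sketch of this new helper, as the training loops below consume it (the loss sequence is a stand-in):

from sbijax.nn.early_stopping import EarlyStopping

early_stop = EarlyStopping(min_delta=1e-3, patience=10)
for epoch in range(100):
    validation_loss = max(1.0 - 0.05 * epoch, 0.42)  # stand-in loss curve
    improved, early_stop = early_stop.update(validation_loss)
    if early_stop.should_stop:  # no improvement for `patience` updates
        break

diff --git a/sbijax/snl.py index a0fa6e9..825cfc3 100644 --- a/sbijax/snl.py +++ b/sbijax/snl.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import chex @@ -6,16 +5,16 @@ import numpy as np import optax from absl import logging - -# TODO(simon): this is a bit an annoying dependency to have -from flax.training.early_stopping import EarlyStopping from jax import numpy as jnp +from jax import random as jr from sbijax._sne_base import SNE from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice - # pylint: disable=too-many-arguments,unused-argument +from sbijax.nn.early_stopping import EarlyStopping + + class SNL(SNE): """ Sequential neural likelihood @@ -27,16 +26,11 @@ class SNL(SNE): def fit( self, rng_key, - observed, - optimizer, - n_rounds=10, - n_simulations_per_round=1000, - max_n_iter=1000, + data, + optimizer=optax.adam(0.0003), + n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, - n_samples=10000, - n_warmup=5000, - n_chains=4, n_early_stopping_patience=10, **kwargs, ): @@ -47,28 +41,17 @@ def fit( ---------- rng_seq: hk.PRNGSequence a hk.PRNGSequence - observed: chex.Array - (n \times p)-dimensional array of observations, where `n` is the n - number of samples + data: NamedTuple + data set obtained from calling `simulate_data_and_possibly_append` optimizer: optax.Optimizer an optax optimizer object - n_rounds: int - number of rounds to optimize - n_simulations_per_round: int - number of data simulations per round - max_n_iter: + n_iter: maximal number of training iterations per round batch_size: int batch size used for training the model percentage_data_as_validation_set: percentage of the simulated data that is used for valitation and early stopping - n_samples: int - number of samples to draw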
to approximate the posterior - n_warmup: int - number of samples to discard - n_chains: int - number of chains to sample n_early_stopping_patience: int number of iterations of no improvement of training the flow before stopping optimisation @@ -86,50 +69,177 @@ def fit( information """ - super().fit(rng_key, observed) + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set + ) + params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + ) + + return params, losses + + # pylint: disable=arguments-differ + def _fit_model_single_round( + self, + seed, + train_iter, + val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + ): + init_key, seed = jr.split(seed) + params = self._init_params(init_key, **train_iter(0)) + state = optimizer.init(params) + + @jax.jit + def step(params, state, **batch): + def loss_fn(params): + lp = self.model.apply( + params, method="log_prob", y=batch["y"], x=batch["theta"] + ) + return -jnp.mean(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros([n_iter, 2]) + early_stop = EarlyStopping(1e-3, n_early_stopping_patience) + logging.info("training model") + for i in range(n_iter): + train_loss = 0.0 + for j in range(train_iter.num_batches): + batch = train_iter(j) + batch_loss, params, state = step(params, state, **batch) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples + ) + validation_loss = self._validation_loss(params, val_iter) + losses[i] = jnp.array([train_loss, validation_loss]) + + _, early_stop = early_stop.update(validation_loss) + if early_stop.should_stop: + logging.info("early stopping criterion found") + break + + losses = jnp.vstack(losses)[:i, :] + return params, losses + + def _validation_loss(self, params, val_iter): + @jax.jit + def loss_fn(**batch): + lp = self.model.apply( + params, method="log_prob", y=batch["y"], x=batch["theta"] + ) + return -jnp.mean(lp) + + def body_fn(i): + batch = val_iter(i) + loss = loss_fn(**batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + losses = 0.0 + for i in range(val_iter.num_batches): + losses += body_fn(i) + return losses + + def _init_params(self, rng_key, **init_data): + params = self.model.init( + rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"] + ) + return params + + def simulate_data_and_possibly_append( + self, + rng_key, + params=None, + observable=None, + data=None, + n_simulations=1_000, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): + """ + Simulate data from the posterior and append it to an existing data set + (if provided) + + Parameters + ---------- + rng_key: jax.PRNGKey + a random key + params: pytree + a dictionary of neural network parameters + observable: jnp.ndarray + an observation + data: NamedTuple + existing data set + n_simulations: int + number of newly simulated data points + n_chains: int + number of MCMC chains + n_samples: int + number of samples to draw in total + n_warmup: int + number of draws to discard as warmup + kwargs: keyword arguments + dictionary of key-value pairs passed to `sample_posterior`.
+ The following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps (int) + - n_doubling: number of doubling steps of the interval (int) + - step_size: step size of the initial interval (float) - simulator_fn = partial( - self._simulate_new_data_and_append, - n_simulations_per_round=n_simulations_per_round, + Returns + ------- + NamedTuple: + returns a NamedTuple with two fields, y and theta + """ + return super().simulate_data_and_possibly_append( + rng_key=rng_key, + params=params, + observable=observable, + data=data, + n_simulations=n_simulations, n_chains=n_chains, n_samples=n_samples, n_warmup=n_warmup, + **kwargs, ) - D, params, all_diagnostics, all_losses, all_params = ( - None, - None, - [], - [], - [], - ) - for _ in range(n_rounds): - D, diagnostics = simulator_fn(params, D, **kwargs) - self._train_iter, self._val_iter = self.as_iterators( - D, batch_size, percentage_data_as_validation_set - ) - params, losses = self._fit_model_single_round( - optimizer=optimizer, - max_n_iter=max_n_iter, - n_early_stopping_patience=n_early_stopping_patience, - ) - all_params.append(params.copy()) - all_losses.append(losses) - all_diagnostics.append(diagnostics) - - snl_info = namedtuple("snl_info", "params losses diagnostics") - return params, snl_info(all_params, all_losses, all_diagnostics) - # pylint: disable=arguments-differ - def sample_posterior(self, params, n_chains, n_samples, n_warmup, **kwargs): + def sample_posterior( + self, + rng_key, + params, + observable, + *, + n_chains=4, + n_samples=2_000, + n_warmup=1_000, + **kwargs, + ): """ Sample from the approximate posterior Parameters ---------- + rng_key: jax.PRNGKey + a random key params: pytree a pytree of parameter for the model + observable: jnp.Array + observation to condition on n_chains: int - number of chains to sample + number of MCMC chains n_samples: int number of samples per chain n_warmup: int @@ -148,12 +258,13 @@ def sample_posterior(...) (n_samples \times p) """ + observable = jnp.atleast_2d(observable) part = partial( - self.model.apply, params=params, method="log_prob", y=self.observed + self.model.apply, params=params, method="log_prob", y=observable ) def _log_likelihood_fn(theta): - theta = jnp.tile(theta, [self.observed.shape[0], 1]) + theta = jnp.tile(theta, [observable.shape[0], 1]) return part(x=theta) def _joint_logdensity_fn(theta): @@ -162,96 +273,30 @@ def _joint_logdensity_fn(theta): return jnp.sum(lp) + jnp.sum(lp_prior) if "sampler" in kwargs and kwargs["sampler"] == "slice": + kwargs.pop("sampler", None) def lp__(theta): return jax.vmap(_joint_logdensity_fn)(theta) - kwargs.pop("sampler", None) - samples = sample_with_slice( - self._rng_seq, - lp__, - n_chains, - n_samples, - n_warmup, - self.prior_sampler_fn, - **kwargs, - ) + sampling_fn = sample_with_slice else: def lp__(theta): return _joint_logdensity_fn(**theta) - samples = sample_with_nuts( - self._rng_seq, - lp__, - self._len_theta, - n_chains, - n_samples, - n_warmup, - ) + sampling_fn = sample_with_nuts + + samples = sampling_fn( + rng_key=rng_key, + lp=lp__, + prior=self.prior_sampler_fn, + n_chains=n_chains, + n_samples=n_samples, + n_warmup=n_warmup, + **kwargs, + ) chex.assert_shape(samples, [n_samples - n_warmup, n_chains, None]) diagnostics = mcmc_diagnostics(samples) samples = samples.reshape((n_samples - n_warmup) * n_chains, -1) return samples, diagnostics
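A minimal sketch of posterior sampling under the refactored interface, reusing the assumed `snl`, `params`, and `y_observed` from the sketch above; the extra keyword arguments mirror the options listed in the docstring:

import jax.random as jr

samples, diagnostics = snl.sample_posterior(
    jr.PRNGKey(3),
    params,
    y_observed,
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
    sampler="slice",
    n_thin=2,
)
# samples has shape ((n_samples - n_warmup) * n_chains, len_theta)

- - def _fit_model_single_round( - self, optimizer,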
max_n_iter, n_early_stopping_patience - ): - params = self._init_params(next(self._rng_seq), **self._train_iter(0)) - state = optimizer.init(params) - - @jax.jit - def step(params, state, **batch): - def loss_fn(params): - lp = self.model.apply( - params, method="log_prob", y=batch["y"], x=batch["theta"] - ) - return -jnp.sum(lp) - - loss, grads = jax.value_and_grad(loss_fn)(params) - updates, new_state = optimizer.update(grads, state, params) - new_params = optax.apply_updates(params, updates) - return loss, new_params, new_state - - losses = np.zeros([max_n_iter, 2]) - early_stop = EarlyStopping(1e-3, n_early_stopping_patience) - logging.info("training model") - for i in range(max_n_iter): - train_loss = 0.0 - for j in range(self._train_iter.num_batches): - batch = self._train_iter(j) - batch_loss, params, state = step(params, state, **batch) - train_loss += batch_loss - print(train_loss) - validation_loss = self._validation_loss(params) - losses[i] = jnp.array([train_loss, validation_loss]) - - _, early_stop = early_stop.update(validation_loss) - if early_stop.should_stop: - logging.info("early stopping criterion found") - break - - losses = jnp.vstack(losses)[:i, :] - return params, losses - - def _validation_loss(self, params): - def _loss_fn(**batch): - lp = self.model.apply( - params, method="log_prob", y=batch["y"], x=batch["theta"] - ) - return -jnp.sum(lp) - - losses = jnp.array( - [ - _loss_fn(**self._val_iter(j)) - for j in range(self._val_iter.num_batches) - ] - ) - return jnp.sum(losses) - - def _init_params(self, rng_key, **init_data): - params = self.model.init( - rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"] - ) - return params diff --git a/sbijax/snl_test.py index 9f1992c..8e8783a 100644 --- a/sbijax/snl_test.py +++ b/sbijax/snl_test.py @@ -2,7 +2,6 @@ import distrax import haiku as hk -import optax from jax import numpy as jnp from surjectors import Chain, MaskedCoupling, TransformedDistribution from surjectors.conditioners import mlp_conditioner @@ -68,14 +67,24 @@ def test_snl(): fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn snl = SNL(fns, make_model(2)) - params, info = snl.fit( + data, params = None, {} + for i in range(2): + data, _ = snl.simulate_data_and_possibly_append( + next(rng_seq), + params=params, + observable=y_observed, + data=data, + n_simulations=100, + n_chains=2, + n_samples=200, + n_warmup=100, + ) + params, info = snl.fit(next(rng_seq), data=data) + _ = snl.sample_posterior( next(rng_seq), + params, y_observed, - n_rounds=2, - optimizer=optax.adam(1e-4), - sampler="slice", n_chains=2, - n_samples=100, - n_warmup=50, + n_samples=200, + n_warmup=100, ) - _ = snl.sample_posterior(params, 2, 100, 50, sampler="slice") diff --git a/sbijax/snp.py index d4f62ed..b4f0b96 100644 --- a/sbijax/snp.py +++ b/sbijax/snp.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import jax @@ -7,7 +6,7 @@ from absl import logging -from flax.training.early_stopping import EarlyStopping +from sbijax.nn.early_stopping import EarlyStopping from jax import numpy as jnp -from jax import random +from jax import random as jr from jax import scipy as jsp from sbijax._sne_base import SNE @@ -21,19 +20,21 @@ class SNP(SNE): From the Greenberg paper """ + def __init__(self, model_fns, density_estimator): + super().__init__(model_fns, density_estimator) + self.n_round = 0 + # pylint: disable=arguments-differ,too-many-locals def fit( self, rng_key, - observed, - optimizer, - n_rounds=10, - n_simulations_per_round=1000, - n_atoms=10,
- max_n_iter=1000, + data, + optimizer=optax.adam(0.0003), + n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, + n_atoms=10, **kwargs, ): """ @@ -43,18 +44,11 @@ def fit( ---------- rng_seq: hk.PRNGSequence a hk.PRNGSequence - observed: chex.Array - (n \times p)-dimensional array of observations, where `n` is the n - number of samples + data: NamedTuple + data set obtained from calling `simulate_data_and_possibly_append` optimizer: optax.Optimizer an optax optimizer object - n_rounds: int - number of rounds to optimize - n_simulations_per_round: int - number of data simulations per round - n_atoms : int - number of atoms to approximate the proposal posterior - max_n_iter: + n_iter: maximal number of training iterations per round batch_size: int batch size used for training the model @@ -64,6 +58,8 @@ def fit( n_early_stopping_patience: int number of iterations of no improvement of training the flow before stopping optimisation + n_atoms : int + number of atoms to approximate the proposal posterior Returns ------- @@ -72,81 +68,37 @@ def fit( information """ - super().fit(rng_key, observed) - - simulator_fn = partial( - self._simulate_new_data_and_append, - n_simulations_per_round=n_simulations_per_round, + itr_key, rng_key = jr.split(rng_key) + train_iter, val_iter = self.as_iterators( + itr_key, data, batch_size, percentage_data_as_validation_set ) - D, params, all_losses, all_params = None, None, [], [] - for i_round in range(n_rounds): - D, _ = simulator_fn(params, D, **kwargs) - self._train_iter, self._val_iter = self.as_iterators( - D, batch_size, percentage_data_as_validation_set - ) - params, losses = self._fit_model_single_round( - optimizer=optimizer, - max_iter=max_n_iter, - n_early_stopping_patience=n_early_stopping_patience, - n_round=i_round, - n_atoms=n_atoms, - ) - all_params.append(params.copy()) - all_losses.append(losses) - - snp_info = namedtuple("snl_info", "params losses") - return params, snp_info(all_params, all_losses) - - def sample_posterior(self, params, n_simulations_per_round, **kwargs): - """ - Sample from the approximate posterior - - Parameters - ---------- - params: pytree - a pytree of parameter for the model - n_simulations_per_round: int - number of samples per chain - - Returns - ------- - chex.Array - an array of samples from the posterior distribution of dimension - (n_samples \times p) - """ - - thetas = None - n_curr = n_simulations_per_round - n_total_simulations_round = 0 - while n_curr > 0: - n_sim = jnp.maximum(100, n_curr) - n_total_simulations_round += n_sim - proposal = self.model.apply( - params, - next(self.rng_seq), - method="sample", - sample_shape=(n_sim,), - x=jnp.tile(self.observed, [n_sim, 1]), - ) - proposal_probs = self.prior_log_density_fn(proposal) - proposal_accepted = proposal[jnp.isfinite(proposal_probs)] - if thetas is None: - thetas = proposal_accepted - else: - thetas = jnp.vstack([thetas, proposal_accepted]) - n_curr -= proposal_accepted.shape[0] - self.n_total_simulations += n_total_simulations_round - return ( - thetas[:n_simulations_per_round], - thetas.shape[0] / n_total_simulations_round, + params, losses = self._fit_model_single_round( + seed=rng_key, + train_iter=train_iter, + val_iter=val_iter, + optimizer=optimizer, + n_iter=n_iter, + n_early_stopping_patience=n_early_stopping_patience, + n_atoms=n_atoms, ) + return params, losses + def _fit_model_single_round( - self, optimizer, max_iter, n_early_stopping_patience, n_round, n_atoms + self, + seed, + train_iter, + 
val_iter, + optimizer, + n_iter, + n_early_stopping_patience, + n_atoms, ): - params = self._init_params(next(self._rng_seq), **self._train_iter(0)) + init_key, seed = jr.split(seed) + params = self._init_params(init_key, **train_iter(0)) state = optimizer.init(params) + n_round = self.n_round if n_round == 0: def loss_fn(params, rng, **batch): @@ -157,7 +109,7 @@ def loss_fn(params, rng, **batch): y=batch["theta"], x=batch["y"], ) - return -jnp.sum(lp) + return -jnp.mean(lp) else: @@ -169,7 +121,7 @@ def loss_fn(params, rng, **batch): theta=batch["theta"], y=batch["y"], ) - return -jnp.sum(lp) + return -jnp.mean(lp) @jax.jit def step(params, rng, state, **batch): @@ -178,19 +130,24 @@ def step(params, rng, state, **batch): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - losses = np.zeros([max_iter, 2]) + losses = np.zeros([n_iter, 2]) early_stop = EarlyStopping(1e-3, n_early_stopping_patience) logging.info("training model") - for i in range(max_iter): + for i in range(n_iter): train_loss = 0.0 - for j in range(self._train_iter.num_batches): - batch = self._train_iter(j) + rng_key = jr.fold_in(seed, i) + for j in range(train_iter.num_batches): + train_key, rng_key = jr.split(rng_key) + batch = train_iter(j) batch_loss, params, state = step( - params, next(self.rng_seq), state, **batch + params, train_key, state, **batch + ) + train_loss += batch_loss * ( + batch["y"].shape[0] / train_iter.num_samples ) - train_loss += batch_loss + val_key, rng_key = jr.split(rng_key) validation_loss = self._validation_loss( - params, next(self.rng_seq), n_round, n_atoms + val_key, params, val_iter, n_atoms ) losses[i] = jnp.array([train_loss, validation_loss]) @@ -199,6 +156,7 @@ def step(params, rng, state, **batch): logging.info("early stopping criterion found") break + self.n_round += 1 losses = jnp.vstack(losses)[:i, :] return params, losses @@ -215,9 +173,9 @@ def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y): probs = jnp.ones((n, n)) * (1 - jnp.eye(n)) / (n - 1) choice = partial( - random.choice, a=jnp.arange(n), replace=False, shape=(n_atoms - 1,) + jr.choice, a=jnp.arange(n), replace=False, shape=(n_atoms - 1,) ) - sample_keys = random.split(rng, probs.shape[0]) + sample_keys = jr.split(rng, probs.shape[0]) choices = jax.vmap(lambda key, prob: choice(key, p=prob))( sample_keys, probs ) @@ -242,8 +200,8 @@ def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y): return log_prob_proposal_posterior - def _validation_loss(self, params, seed, n_round, n_atoms): - if n_round == 0: + def _validation_loss(self, rng_key, params, val_iter, n_atoms): + if self.n_round == 0: def loss_fn(rng, **batch): lp = self.model.apply( @@ -253,7 +211,7 @@ def loss_fn(rng, **batch): y=batch["theta"], x=batch["y"], ) - return -jnp.sum(lp) + return -jnp.mean(lp) else: @@ -261,10 +219,69 @@ def loss_fn(rng, **batch): lp = self._proposal_posterior_log_prob( params, rng, n_atoms, batch["theta"], batch["y"] ) - return -jnp.sum(lp) + return -jnp.mean(lp) - loss = 0 - for j in range(self._val_iter.num_batches): - rng, seed = random.split(seed) - loss += jax.jit(loss_fn)(rng, **self._val_iter(j)) + def body_fn(i, rng_key): + batch = val_iter(i) + loss = jax.jit(loss_fn)(rng_key, **batch) + return loss * (batch["y"].shape[0] / val_iter.num_samples) + + loss = 0.0 + for i in range(val_iter.num_batches): + val_key, rng_key = jr.split(rng_key) + loss += body_fn(i, val_key) return loss + + def sample_posterior( + self, rng_key, params, observable, *, 
n_samples=4_000, **kwargs + ): + """ + Sample from the approximate posterior + + Parameters + ---------- + rng_key: jax.PRNGKey + a random key + params: pytree + a pytree of parameters for the model + observable: jnp.Array + observation to condition on + n_samples: int + number of samples to draw + + Returns + ------- + chex.Array + an array of samples from the posterior distribution of dimension + (n_samples \times p) + """ + + observable = jnp.atleast_2d(observable) + + thetas = None + n_curr = n_samples + n_total_simulations_round = 0 + while n_curr > 0: + n_sim = 200 + n_total_simulations_round += n_sim + sample_key, rng_key = jr.split(rng_key) + proposal = self.model.apply( + params, + sample_key, + method="sample", + sample_shape=(n_sim,), + x=jnp.tile(observable, [n_sim, 1]), + ) + proposal_probs = self.prior_log_density_fn(proposal) + proposal_accepted = proposal[jnp.isfinite(proposal_probs)] + if thetas is None: + thetas = proposal_accepted + else: + thetas = jnp.vstack([thetas, proposal_accepted]) + n_curr -= proposal_accepted.shape[0] + + self.n_total_simulations += n_total_simulations_round + return ( + thetas[:n_samples], + thetas.shape[0] / n_total_simulations_round, + ) diff --git a/sbijax/snp_test.py index b284163..81629b0 100644 --- a/sbijax/snp_test.py +++ b/sbijax/snp_test.py @@ -2,7 +2,6 @@ import distrax import haiku as hk -import optax from jax import numpy as jnp from surjectors import Chain, MaskedCoupling, TransformedDistribution from surjectors.conditioners import mlp_conditioner @@ -67,11 +66,24 @@ def test_snp(): fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn snp = SNP(fns, make_model(2)) - params, info = snp.fit( + data, params = None, {} + for i in range(2): + data, _ = snp.simulate_data_and_possibly_append( + next(rng_seq), + params=params, + observable=y_observed, + data=data, + n_simulations=100, + n_chains=2, + n_samples=200, + n_warmup=100, + ) + params, info = snp.fit(next(rng_seq), data=data) + _ = snp.sample_posterior( + next(rng_seq), + params, y_observed, - n_rounds=2, - optimizer=optax.adam(1e-4), - sampler="slice", + n_chains=2, + n_samples=200, + n_warmup=100, ) - _ = snp.sample_posterior(params, 100) From 6c4386c1d34a3b750861d1399dd358e357977b18 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 3 Oct 2023 18:55:43 +0200 Subject: [PATCH 2/5] Increment API --- sbijax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbijax/__init__.py index c510012..ff076eb 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.0.12" +__version__ = "0.1.0" from sbijax.abc.rejection_abc import RejectionABC From 4eae022cfb5cd8dc6a5930843cc7a7fab3a603e6 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 3 Oct 2023 19:03:09 +0200 Subject: [PATCH 3/5] Fix lints --- Untitled.ipynb | 442 -------------------------------------- examples/bivariate_gaussian_smcabc.py | 6 +- sbijax/generator.py | 3 +- 3 files changed, 5 insertions(+), 446 deletions(-) delete mode 100644 Untitled.ipynb diff --git a/Untitled.ipynb deleted file mode 100644 index 9a51e15..0000000 --- a/Untitled.ipynb +++ /dev/null @@ -1,442 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "05a45ed4-97f7-4bba-9581-a37e978180a5", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "#os.environ[\"XLA_FLAGS\"] =
'--xla_force_host_platform_device_count=8'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00d2daa7-a509-47e1-86c6-752a75b47d0b", - "metadata": {}, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "import jax.random as jr\n", - "from jax import scipy as jsp\n", - "\n", - "import blackjax as bj\n", - "import tensorflow_probability.substrates.jax as tfp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c5c5bdce-fe4c-4fb3-a0c1-15911ba2b40a", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "from jax.lib import xla_bridge" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c522c600-31a4-405c-bd12-ebabcc77131d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[CpuDevice(id=0)]\n", - "cpu\n", - "1\n", - "cpu\n" - ] - } - ], - "source": [ - "print(jax.devices())\n", - "print(jax.default_backend())\n", - "print(jax.device_count())\n", - "print(xla_bridge.get_backend().platform)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8198dccd-d852-4dbc-a98f-486b8733a030", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "\n", - "import distrax\n", - "import haiku as hk\n", - "import jax\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "from jax import numpy as jnp\n", - "from jax import random\n", - "from jax import scipy as jsp\n", - "from jax import vmap\n", - "from functools import partial" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e090efd-dfff-4fd1-b421-71f0166ffe59", - "metadata": {}, - "outputs": [], - "source": [ - "def likelihood_fn(theta):\n", - " mu = jnp.tile(theta[:2], 4)\n", - " s1, s2 = theta[2] ** 2, theta[3] ** 2\n", - " corr = s1 * s2 * jnp.tanh(theta[4])\n", - " cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n", - " cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n", - " p = distrax.MultivariateNormalFullCovariance(mu, cov)\n", - " return p" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5bbedba0-346c-4604-ab7f-5318e9d3b15b", - "metadata": {}, - "outputs": [], - "source": [ - "lik = likelihood_fn(jr.normal(jr.PRNGKey(123), (5,)))\n", - "y = lik.sample(seed=jr.PRNGKey(1), sample_shape=(10,))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45482e72-6971-48cb-a58c-2a79be31b443", - "metadata": {}, - "outputs": [], - "source": [ - "def likelihood_fn(theta, y):\n", - " mu = jnp.tile(theta[:2], 4)\n", - " s1, s2 = theta[2] ** 2, theta[3] ** 2\n", - " corr = s1 * s2 * jnp.tanh(theta[4])\n", - " cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n", - " cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n", - " p = distrax.MultivariateNormalFullCovariance(mu, cov)\n", - " return p.log_prob(y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e07a4d59-80f8-4303-982e-fabf2acf24c5", - "metadata": {}, - "outputs": [], - "source": [ - "prior = distrax.Independent(\n", - " distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d898ca02-22f3-4102-b667-c2ef7abdd492", - "metadata": {}, - "outputs": [], - "source": [ - "def log_density_fn(theta, y):\n", - " prior_lp = prior.log_prob(theta)\n", - " likelihood_lp = likelihood_fn(theta, y)\n", - "\n", - " lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)\n", - " return lp\n", - 
"\n", - "target_log_prob_fn_partial = partial(log_density_fn, y=y)\n", - "target_log_prob_fn = lambda theta: jax.vmap(target_log_prob_fn_partial)(theta)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55e0fa2c-6aaa-4ab6-922a-1b5f2559b5d1", - "metadata": {}, - "outputs": [], - "source": [ - "n_samples = 10000\n", - "n_warmup = 5000\n", - "n_chains = 4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "263b64ec-8640-4f61-8188-0f0953d53668", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'functools.partial' object has no attribute 'experimental_default_event_space_bijector'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtfp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmcmc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwindowed_adaptive_nuts\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget_log_prob_fn_partial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_warmup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_step_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_tree_depth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_energy_diff\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43munrolled_leapfrog_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mparallel_iterations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m)\u001b[49m\n", - "File 
\u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:710\u001b[0m, in \u001b[0;36mwindowed_adaptive_nuts\u001b[0;34m(n_draws, joint_dist, n_chains, num_adaptation_steps, current_state, init_step_size, dual_averaging_kwargs, max_tree_depth, max_energy_diff, unrolled_leapfrog_steps, parallel_iterations, trace_fn, return_final_kernel_results, discard_tuning, chain_axis_names, seed, **pins)\u001b[0m\n\u001b[1;32m 703\u001b[0m dual_averaging_kwargs\u001b[38;5;241m.\u001b[39msetdefault(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtarget_accept_prob\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m0.85\u001b[39m)\n\u001b[1;32m 704\u001b[0m proposal_kernel_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 705\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep_size\u001b[39m\u001b[38;5;124m'\u001b[39m: init_step_size,\n\u001b[1;32m 706\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_tree_depth\u001b[39m\u001b[38;5;124m'\u001b[39m: max_tree_depth,\n\u001b[1;32m 707\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_energy_diff\u001b[39m\u001b[38;5;124m'\u001b[39m: max_energy_diff,\n\u001b[1;32m 708\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124munrolled_leapfrog_steps\u001b[39m\u001b[38;5;124m'\u001b[39m: unrolled_leapfrog_steps,\n\u001b[1;32m 709\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparallel_iterations\u001b[39m\u001b[38;5;124m'\u001b[39m: parallel_iterations}\n\u001b[0;32m--> 710\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_windowed_adaptive_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 711\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_draws\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_draws\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 712\u001b[0m \u001b[43m \u001b[49m\u001b[43mjoint_dist\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjoint_dist\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 713\u001b[0m \u001b[43m \u001b[49m\u001b[43mkind\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnuts\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 714\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_chains\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal_kernel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproposal_kernel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 717\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 718\u001b[0m \u001b[43m \u001b[49m\u001b[43mdual_averaging_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdual_averaging_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 719\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrace_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 720\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_final_kernel_results\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_final_kernel_results\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 721\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdiscard_tuning\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdiscard_tuning\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 722\u001b[0m \u001b[43m \u001b[49m\u001b[43mchain_axis_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchain_axis_names\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 723\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 724\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpins\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:890\u001b[0m, in \u001b[0;36m_windowed_adaptive_impl\u001b[0;34m(n_draws, joint_dist, kind, n_chains, proposal_kernel_kwargs, num_adaptation_steps, current_state, dual_averaging_kwargs, trace_fn, return_final_kernel_results, discard_tuning, seed, chain_axis_names, **pins)\u001b[0m\n\u001b[1;32m 885\u001b[0m dual_averaging_kwargs\u001b[38;5;241m.\u001b[39msetdefault(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mexperimental_reduce_chain_axis_names\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 886\u001b[0m chain_axis_names)\n\u001b[1;32m 887\u001b[0m setup_seed, sample_seed \u001b[38;5;241m=\u001b[39m samplers\u001b[38;5;241m.\u001b[39msplit_seed(\n\u001b[1;32m 888\u001b[0m samplers\u001b[38;5;241m.\u001b[39msanitize_seed(seed), n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 889\u001b[0m (target_log_prob_fn, initial_transformed_position, bijector,\n\u001b[0;32m--> 890\u001b[0m step_broadcast, batch_shape, shard_axis_names) \u001b[38;5;241m=\u001b[39m \u001b[43m_setup_mcmc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 891\u001b[0m \u001b[43m \u001b[49m\u001b[43mjoint_dist\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 892\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_chains\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 893\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 894\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msetup_seed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 895\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpins\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 897\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m proposal_kernel_kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep_size\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 898\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_shape\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m (\u001b[38;5;241m0\u001b[39m,): \u001b[38;5;66;03m# Scalar batch has a 0-vector shape.\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:226\u001b[0m, in \u001b[0;36m_setup_mcmc\u001b[0;34m(model, n_chains, init_position, seed, **pins)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Construct bijector and transforms needed for windowed MCMC.\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \n\u001b[1;32m 194\u001b[0m 
\u001b[38;5;124;03mThis pins the initial model, constructs a bijector that unconstrains and\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;124;03m shard_axis_names: Shard axis names for the model\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 225\u001b[0m pinned_model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mexperimental_pin(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpins) \u001b[38;5;28;01mif\u001b[39;00m pins \u001b[38;5;28;01melse\u001b[39;00m model\n\u001b[0;32m--> 226\u001b[0m bijector, step_bijector \u001b[38;5;241m=\u001b[39m \u001b[43m_get_flat_unconstraining_bijector\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpinned_model\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_position \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 229\u001b[0m raw_init_dist \u001b[38;5;241m=\u001b[39m initialization\u001b[38;5;241m.\u001b[39minit_near_unconstrained_zero(pinned_model)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:172\u001b[0m, in \u001b[0;36m_get_flat_unconstraining_bijector\u001b[0;34m(jd_model)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Create a bijector from a joint distribution that flattens and unconstrains.\u001b[39;00m\n\u001b[1;32m 153\u001b[0m \n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03mThe intention is (loosely) to go from a model joint distribution supported on\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;124;03m points, and the second may be used to initialize a step size.\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# TODO(b/180396233): This bijector is in general point-dependent.\u001b[39;00m\n\u001b[0;32m--> 172\u001b[0m event_space_bij \u001b[38;5;241m=\u001b[39m \u001b[43mjd_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental_default_event_space_bijector\u001b[49m()\n\u001b[1;32m 173\u001b[0m flat_bijector \u001b[38;5;241m=\u001b[39m restructure\u001b[38;5;241m.\u001b[39mpack_sequence_as(jd_model\u001b[38;5;241m.\u001b[39mevent_shape_tensor())\n\u001b[1;32m 175\u001b[0m unconstrained_shapes \u001b[38;5;241m=\u001b[39m event_space_bij(\n\u001b[1;32m 176\u001b[0m flat_bijector)\u001b[38;5;241m.\u001b[39minverse_event_shape_tensor(jd_model\u001b[38;5;241m.\u001b[39mevent_shape_tensor())\n", - "\u001b[0;31mAttributeError\u001b[0m: 'functools.partial' object has no attribute 'experimental_default_event_space_bijector'" - ] - } - ], - "source": [ - "nuts = tfp.mcmc.NoUTurnSampler(\n", - " target_log_prob_fn,\n", - " step_size=0.1,\n", - " max_tree_depth=10,\n", - " max_energy_diff=1000.0,\n", - " unrolled_leapfrog_steps=1,\n", - ")\n", - "\n", - "\n", - "nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(\n", - " inner_kernel=nuts,\n", - " num_adaptation_steps=int(0.8 * n_warmup),\n", - " target_accept_prob=jnp.asarray(0.75, jnp.float32)\n", - ")\n", - "\n", - "tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(\n", - " nuts,\n", - " initial_running_variance,\n", - " num_estimation_steps=None,\n", - " momentum_distribution_setter_fn=hmc_like_momentum_distribution_setter_fn,\n", - " momentum_distribution_getter_fn=hmc_like_momentum_distribution_getter_fn,\n", - " validate_args=False,\n", - " 
experimental_shard_axis_names=None,\n", - " name=None\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d826bb25-929a-4273-aab6-d281a54a0844", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.device_count()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b59b71ae-963c-43f1-ab00-4d3881354864", - "metadata": {}, - "outputs": [], - "source": [ - "initial_states = jr.normal(jr.PRNGKey(4), shape=(n_chains, 5))\n", - "samples = tfp.mcmc.sample_chain(\n", - " num_results=n_samples - n_warmup,\n", - " current_state=initial_states,\n", - " num_steps_between_results=1,\n", - " kernel=adaptive_sampler,\n", - " num_burnin_steps=n_warmup,\n", - " trace_fn=None,\n", - " seed=jr.PRNGKey(2),\n", - ")\n", - "samples = samples[n_warmup:, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "99db6c8b-49db-435e-a886-e5968106c512", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "e6d8884c-1153-48b3-867a-cac4020e344b", - "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[24], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msfunc\u001b[39m(x): \n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m: \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43msfunc\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - " \u001b[0;31m[... 
skipping hidden 1 frame]\u001b[0m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/api.py:2253\u001b[0m, in \u001b[0;36m_cpp_pmap..cache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 2251\u001b[0m execute: Optional[functools\u001b[38;5;241m.\u001b[39mpartial] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2252\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(top_trace, core\u001b[38;5;241m.\u001b[39mEvalTrace):\n\u001b[0;32m-> 2253\u001b[0m execute \u001b[38;5;241m=\u001b[39m \u001b[43mpxla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxla_pmap_impl_lazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2254\u001b[0m out \u001b[38;5;241m=\u001b[39m map_bind_continuation(execute(\u001b[38;5;241m*\u001b[39mtracers))\n\u001b[1;32m 2255\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:974\u001b[0m, in \u001b[0;36mxla_pmap_impl_lazy\u001b[0;34m(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)\u001b[0m\n\u001b[1;32m 972\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _emap_apply_fn\n\u001b[1;32m 973\u001b[0m abstract_args \u001b[38;5;241m=\u001b[39m unsafe_map(xla\u001b[38;5;241m.\u001b[39mabstractify, args)\n\u001b[0;32m--> 974\u001b[0m compiled_fun, fingerprint \u001b[38;5;241m=\u001b[39m \u001b[43mparallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 975\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_axis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 976\u001b[0m \u001b[43m \u001b[49m\u001b[43min_axes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_axes_thunk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mabstract_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# Don't re-abstractify args unless logging is enabled for performance.\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39mjax_distributed_debug:\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/linear_util.py:303\u001b[0m, in \u001b[0;36mcache..memoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 301\u001b[0m fun\u001b[38;5;241m.\u001b[39mpopulate_stores(stores)\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 303\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 304\u001b[0m cache[key] \u001b[38;5;241m=\u001b[39m (ans, fun\u001b[38;5;241m.\u001b[39mstores)\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ans\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1245\u001b[0m, in \u001b[0;36mparallel_callable\u001b[0;34m(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *avals)\u001b[0m\n\u001b[1;32m 1232\u001b[0m \u001b[38;5;129m@lu\u001b[39m\u001b[38;5;241m.\u001b[39mcache\n\u001b[1;32m 1233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_callable\u001b[39m(fun: lu\u001b[38;5;241m.\u001b[39mWrappedFun,\n\u001b[1;32m 1234\u001b[0m backend_name: Optional[\u001b[38;5;28mstr\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1243\u001b[0m global_arg_shapes: Sequence[Optional[Tuple[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]]],\n\u001b[1;32m 1244\u001b[0m \u001b[38;5;241m*\u001b[39mavals):\n\u001b[0;32m-> 1245\u001b[0m pmap_computation \u001b[38;5;241m=\u001b[39m \u001b[43mlower_parallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1246\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_axis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1247\u001b[0m \u001b[43m \u001b[49m\u001b[43min_axes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_axes_thunk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mavals\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1248\u001b[0m pmap_executable \u001b[38;5;241m=\u001b[39m pmap_computation\u001b[38;5;241m.\u001b[39mcompile()\n\u001b[1;32m 1249\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WeakRefList([pmap_executable\u001b[38;5;241m.\u001b[39munsafe_call, pmap_executable\u001b[38;5;241m.\u001b[39mfingerprint])\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1414\u001b[0m, in \u001b[0;36mlower_parallel_callable\u001b[0;34m(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)\u001b[0m\n\u001b[1;32m 1409\u001b[0m must_run_on_all_devices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 1411\u001b[0m pci \u001b[38;5;241m=\u001b[39m ParallelCallableInfo(\n\u001b[1;32m 1412\u001b[0m name, backend, axis_name, axis_size, global_axis_size, devices,\n\u001b[1;32m 1413\u001b[0m in_axes, out_axes_thunk, avals)\n\u001b[0;32m-> 1414\u001b[0m jaxpr, consts, replicas, parts, shards \u001b[38;5;241m=\u001b[39m \u001b[43mstage_parallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1415\u001b[0m \u001b[43m \u001b[49m\u001b[43mpci\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1417\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m logger\u001b[38;5;241m.\u001b[39misEnabledFor(logging\u001b[38;5;241m.\u001b[39mDEBUG):\n\u001b[1;32m 1418\u001b[0m logger\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msharded_avals: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, shards\u001b[38;5;241m.\u001b[39msharded_avals)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1321\u001b[0m, in \u001b[0;36mstage_parallel_callable\u001b[0;34m(pci, fun, global_arg_shapes)\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m core\u001b[38;5;241m.\u001b[39mextend_axis_env(pci\u001b[38;5;241m.\u001b[39maxis_name, pci\u001b[38;5;241m.\u001b[39mglobal_axis_size, \u001b[38;5;28;01mNone\u001b[39;00m): \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 1318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m dispatch\u001b[38;5;241m.\u001b[39mlog_elapsed_time(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFinished tracing + transforming \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1319\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfor pmap in \u001b[39m\u001b[38;5;132;01m{elapsed_time}\u001b[39;00m\u001b[38;5;124m sec\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1320\u001b[0m event\u001b[38;5;241m=\u001b[39mdispatch\u001b[38;5;241m.\u001b[39mJAXPR_TRACE_EVENT):\n\u001b[0;32m-> 1321\u001b[0m jaxpr, out_sharded_avals, consts \u001b[38;5;241m=\u001b[39m \u001b[43mpe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrace_to_jaxpr_final\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1322\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_sharded_avals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdebug_info_final\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpmap\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1323\u001b[0m jaxpr \u001b[38;5;241m=\u001b[39m 
dispatch\u001b[38;5;241m.\u001b[39mapply_outfeed_rewriter(jaxpr)\n\u001b[1;32m 1325\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out_sharded_avals) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(pci\u001b[38;5;241m.\u001b[39mout_axes), (\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28mlen\u001b[39m(out_sharded_avals), \u001b[38;5;28mlen\u001b[39m(pci\u001b[38;5;241m.\u001b[39mout_axes))\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:2065\u001b[0m, in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info, keep_inputs)\u001b[0m\n\u001b[1;32m 2063\u001b[0m main\u001b[38;5;241m.\u001b[39mjaxpr_stack \u001b[38;5;241m=\u001b[39m () \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 2064\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m core\u001b[38;5;241m.\u001b[39mnew_sublevel():\n\u001b[0;32m-> 2065\u001b[0m jaxpr, out_avals, consts \u001b[38;5;241m=\u001b[39m \u001b[43mtrace_to_subjaxpr_dynamic\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2066\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_avals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdebug_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdebug_info\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2067\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m fun, main\n\u001b[1;32m 2068\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jaxpr, out_avals, consts\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1998\u001b[0m, in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals, keep_inputs, debug_info)\u001b[0m\n\u001b[1;32m 1996\u001b[0m in_tracers \u001b[38;5;241m=\u001b[39m _input_type_to_tracers(trace\u001b[38;5;241m.\u001b[39mnew_arg, in_avals)\n\u001b[1;32m 1997\u001b[0m in_tracers_ \u001b[38;5;241m=\u001b[39m [t \u001b[38;5;28;01mfor\u001b[39;00m t, keep \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(in_tracers, keep_inputs) \u001b[38;5;28;01mif\u001b[39;00m keep]\n\u001b[0;32m-> 1998\u001b[0m ans \u001b[38;5;241m=\u001b[39m 
\u001b[43mfun\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_wrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43min_tracers_\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1999\u001b[0m out_tracers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(trace\u001b[38;5;241m.\u001b[39mfull_raise, ans)\n\u001b[1;32m 2000\u001b[0m jaxpr, consts \u001b[38;5;241m=\u001b[39m frame\u001b[38;5;241m.\u001b[39mto_jaxpr(out_tracers)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/linear_util.py:167\u001b[0m, in \u001b[0;36mWrappedFun.call_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m gen \u001b[38;5;241m=\u001b[39m gen_static_args \u001b[38;5;241m=\u001b[39m out_store \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 167\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# Some transformations yield from inside context managers, so we have to\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# interrupt them before reraising the exception. 
Otherwise they will only\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# get garbage-collected at some later time, running their cleanup tasks\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# only after this exception is handled, which can corrupt the global\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;66;03m# state.\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m stack:\n", - "Cell \u001b[0;32mIn[24], line 2\u001b[0m, in \u001b[0;36msfunc\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msfunc\u001b[39m(x): \n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m: \u001b[38;5;28;01mpass\u001b[39;00m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "def sfunc(x): \n", - " while True: pass\n", - "\n", - "jax.pmap(sfunc)(jnp.arange(4))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "9032ad64-c70e-4acb-9541-8ef14a36da2b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "numpy: cpu usage 0.6/8 wall_time:1.7s\n", - "vmap: cpu usage 1.8/8 wall_time:1.9s\n", - "dot: cpu usage 7.0/8 wall_time:3.7s\n" - ] - } - ], - "source": [ - "import time, os, jax, numpy as np, jax.numpy as jnp\n", - "jax.config.update('jax_platform_name', 'cpu') # insures we use the CPU\n", - "\n", - "def timer(name, f, x, shouldBlock=True):\n", - " # warmup\n", - " y = f(x).block_until_ready() if shouldBlock else f(x)\n", - " # running the code\n", - " start_wall = time.perf_counter()\n", - " start_cpu = time.process_time()\n", - " y = f(x).block_until_ready() if shouldBlock else f(x)\n", - " end_wall = time.perf_counter()\n", - " end_cpu = time.process_time()\n", - " # computing the metric and displaying it\n", - " wall_time = end_wall - start_wall\n", - " cpu_time = end_cpu - start_cpu\n", - " cpu_count = os.cpu_count()\n", - " print(f\"{name}: cpu usage {cpu_time/wall_time:.1f}/{cpu_count} wall_time:{wall_time:.1f}s\")\n", - "\n", - "# test functions\n", - "key = jax.random.PRNGKey(0)\n", - "x = jax.random.normal(key, shape=(500000000,), dtype=jnp.float64)\n", - "x_mat = jax.random.normal(key, shape=(10000,10000), dtype=jnp.float64)\n", - "f_numpy = np.cos\n", - "f_vmap = jax.jit(jax.vmap(jnp.cos))\n", - "f_dot = jax.jit(lambda x: jnp.dot(x,x.T)) # to show that JAX can indeed use all cores\n", - "\n", - "timer('numpy', f_numpy, x, shouldBlock=False)\n", - "timer('vmap', f_vmap, x)\n", - "timer('dot', f_dot, x_mat)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "abf84ca0-52d8-4469-82d5-1e9099a1d05b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.cpu_count()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "1cef2b44-f3ac-43e0-ba1e-604c8fa8bfc4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" - ] - } - ], - "source": [ - "from jax import local_device_count\n", - "print(jax.local_devices()) #1 instead of 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dd7daeb-449e-4111-9170-f7937b495991", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": 
{ - "display_name": "sbi-dev", - "language": "python", - "name": "sbi-dev" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/bivariate_gaussian_smcabc.py b/examples/bivariate_gaussian_smcabc.py index 6a1014e..99d2e37 100644 --- a/examples/bivariate_gaussian_smcabc.py +++ b/examples/bivariate_gaussian_smcabc.py @@ -6,7 +6,8 @@ import jax import matplotlib.pyplot as plt import seaborn as sns -from jax import numpy as jnp, random as jr +from jax import numpy as jnp +from jax import random as jr from sbijax import SMCABC @@ -43,8 +44,7 @@ def run(): smc = SMCABC(fns, summary_fn, distance_fn) smc_samples, _ = smc.sample_posterior( - jr.PRNGKey(22), y_observed, - 10, 1000, 1000, 0.6, 500 + jr.PRNGKey(22), y_observed, 10, 1000, 1000, 0.6, 500 ) fig, axes = plt.subplots(2) diff --git a/sbijax/generator.py b/sbijax/generator.py index 0f0c7bc..9ba321d 100644 --- a/sbijax/generator.py +++ b/sbijax/generator.py @@ -2,7 +2,8 @@ import chex from jax import lax -from jax import numpy as jnp, random as jr +from jax import numpy as jnp +from jax import random as jr named_dataset = namedtuple("named_dataset", "y theta") From 42528b9cef161173b268267d1e03b29e69f69231 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 3 Oct 2023 19:10:49 +0200 Subject: [PATCH 4/5] Fix lints --- sbijax/snp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sbijax/snp.py b/sbijax/snp.py index b4f0b96..ac58ce0 100644 --- a/sbijax/snp.py +++ b/sbijax/snp.py @@ -4,15 +4,16 @@ import numpy as np import optax from absl import logging -from flax.training.early_stopping import EarlyStopping from jax import numpy as jnp from jax import random as jr from jax import scipy as jsp from sbijax._sne_base import SNE - # pylint: disable=too-many-arguments,unused-argument +from sbijax.nn.early_stopping import EarlyStopping + + class SNP(SNE): """ Sequential neural posterior estimation From 629982d7cb61e283008386d63273c60c3e4df4c2 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 3 Oct 2023 19:44:33 +0200 Subject: [PATCH 5/5] Add SSNL --- examples/slcp_ssnl.py | 242 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 examples/slcp_ssnl.py diff --git a/examples/slcp_ssnl.py b/examples/slcp_ssnl.py new file mode 100644 index 0000000..8a29eec --- /dev/null +++ b/examples/slcp_ssnl.py @@ -0,0 +1,242 @@ +""" +SLCP example from [1] using SNL and masked autoregressive bijections +or surjections +""" + +import argparse +from functools import partial + +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +import pandas as pd +import seaborn as sns +from jax import numpy as jnp +from jax import random as jr +from jax import scipy as jsp +from jax import vmap +from surjectors import ( + AffineMaskedAutoregressiveInferenceFunnel, + Chain, + MaskedAutoregressive, + Permutation, + TransformedDistribution, +) +from surjectors.conditioners import MADE, mlp_conditioner +from surjectors.util import unstack + +from sbijax import SNL +from sbijax.mcmc.slice import sample_with_slice + + +def prior_model_fns(): + p = distrax.Independent( + distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1 + ) + return p.sample, p.log_prob + + +def simulator_fn(seed, 
theta):
+    orig_shape = theta.shape
+    if theta.ndim == 2:
+        theta = theta[:, None, :]
+    us_key, noise_key = jr.split(seed)
+
+    def _unpack_params(ps):
+        m0 = ps[..., [0]]
+        m1 = ps[..., [1]]
+        s0 = ps[..., [2]] ** 2
+        s1 = ps[..., [3]] ** 2
+        # use jnp rather than np so the simulator stays traceable under jit/vmap
+        r = jnp.tanh(ps[..., [4]])
+        return m0, m1, s0, s1, r
+
+    m0, m1, s0, s1, r = _unpack_params(theta)
+    us = distrax.Normal(0.0, 1.0).sample(
+        seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
+    )
+    xs = jnp.empty_like(us)
+    xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0)
+    y = xs.at[:, :, :, 1].set(
+        s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
+    )
+    if len(orig_shape) == 2:
+        y = y.reshape((*theta.shape[:1], 8))
+    else:
+        y = y.reshape((*theta.shape[:2], 8))
+    return y
+
+
+def likelihood_fn(theta, y):
+    mu = jnp.tile(theta[:2], 4)
+    s1, s2 = theta[2] ** 2, theta[3] ** 2
+    corr = s1 * s2 * jnp.tanh(theta[4])
+    cov = jnp.array([[s1**2, corr], [corr, s2**2]])
+    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
+    p = distrax.MultivariateNormalFullCovariance(mu, cov)
+    return p.log_prob(y)
+
+
+def log_density_fn(theta, y):
+    prior_lp = distrax.Independent(
+        distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1
+    ).log_prob(theta)
+    likelihood_lp = likelihood_fn(theta, y)
+
+    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
+    return lp
+
+
+def make_model(dim, use_surjectors):
+    def _bijector_fn(params):
+        means, log_scales = unstack(params, -1)
+        return distrax.ScalarAffine(means, jnp.exp(log_scales))
+
+    def _decoder_fn(n_dim):
+        decoder_net = mlp_conditioner(
+            [50, n_dim * 2],
+            w_init=hk.initializers.TruncatedNormal(stddev=0.001),
+        )
+
+        def _fn(z):
+            params = decoder_net(z)
+            mu, log_scale = jnp.split(params, 2, -1)
+            return distrax.Independent(
+                distrax.Normal(mu, jnp.exp(log_scale)), 1
+            )
+
+        return _fn
+
+    def _flow(method, **kwargs):
+        layers = []
+        n_dimension = dim
+        order = jnp.arange(n_dimension)
+        for i in range(5):
+            if i == 2 and use_surjectors:
+                n_latent = 6
+                layer = AffineMaskedAutoregressiveInferenceFunnel(
+                    n_latent,
+                    _decoder_fn(n_dimension - n_latent),
+                    conditioner=MADE(
+                        n_latent,
+                        [50, n_latent * 2],
+                        2,
+                        w_init=hk.initializers.TruncatedNormal(0.001),
+                        b_init=jnp.zeros,
+                        activation=jax.nn.tanh,
+                    ),
+                )
+                n_dimension = n_latent
+                order = order[::-1]
+                order = order[:n_dimension] - jnp.min(order[:n_dimension])
+            else:
+                layer = MaskedAutoregressive(
+                    bijector_fn=_bijector_fn,
+                    conditioner=MADE(
+                        n_dimension,
+                        [50, n_dimension * 2],
+                        2,
+                        w_init=hk.initializers.TruncatedNormal(0.001),
+                        b_init=jnp.zeros,
+                        activation=jax.nn.tanh,
+                    ),
+                )
+                order = order[::-1]
+            layers.append(layer)
+            layers.append(Permutation(order, 1))
+        chain = Chain(layers)
+
+        base_distribution = distrax.Independent(
+            distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
+            reinterpreted_batch_ndims=1,
+        )
+        td = TransformedDistribution(base_distribution, chain)
+        return td(method, **kwargs)
+
+    td = hk.transform(_flow)
+    td = hk.without_apply_rng(td)
+    return td
+
+
+def run(use_surjectors):
+    len_theta = 5
+    # these are the thetas used in SNL
+    # thetas = jnp.array([-0.7, -2.9, -1.0, -0.9, 0.6])
+    y_observed = jnp.array(
+        [
+            [
+                -0.9707123,
+                -2.9461224,
+                -0.4494722,
+                -3.4231849,
+                -0.13285634,
+                -3.364017,
+                -0.85367596,
+                -2.4271638,
+            ]
+        ]
+    )
+
+    log_density_partial = partial(log_density_fn, y=y_observed)
+    log_density = lambda x: vmap(log_density_partial)(x)
+
+    prior_simulator_fn, prior_fn = prior_model_fns()
+    fns = (prior_simulator_fn, prior_fn), simulator_fn
+
+    snl = SNL(fns, make_model(y_observed.shape[1], use_surjectors))
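+    # Five sequential rounds follow: each round samples thetas with slice
+    # MCMC under the current likelihood approximation, simulates data for
+    # them and appends it to the training set, then refits the flow.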
+    optimizer = optax.adam(1e-3)
+
+    data, params = None, {}
+    for i in range(5):
+        data, _ = snl.simulate_data_and_possibly_append(
+            jr.fold_in(jr.PRNGKey(12), i),
+            params=params,
+            observable=y_observed,
+            data=data,
+            sampler="slice",
+        )
+        params, info = snl.fit(
+            jr.fold_in(jr.PRNGKey(23), i), data=data, optimizer=optimizer
+        )
+
+    sample_key, rng_key = jr.split(jr.PRNGKey(123))
+    snl_samples, _ = snl.sample_posterior(sample_key, params, y_observed)
+
+    sample_key, rng_key = jr.split(rng_key)
+    slice_samples = sample_with_slice(
+        sample_key,
+        log_density,
+        prior_simulator_fn,
+    )
+    slice_samples = slice_samples.reshape(-1, len_theta)
+
+    g = sns.PairGrid(pd.DataFrame(slice_samples))
+    g.map_upper(sns.scatterplot, color="black", marker=".", edgecolor=None, s=2)
+    g.map_diag(plt.hist, color="black")
+    for ax in g.axes.flatten():
+        ax.set_xlim(-5, 5)
+        ax.set_ylim(-5, 5)
+    g.fig.set_figheight(5)
+    g.fig.set_figwidth(5)
+    plt.show()
+
+    fig, axes = plt.subplots(len_theta, 2)
+    for i in range(len_theta):
+        sns.histplot(slice_samples[:, i], color="darkgrey", ax=axes[i, 0])
+        sns.histplot(snl_samples[:, i], color="darkblue", ax=axes[i, 1])
+        axes[i, 0].set_title(rf"Sampled posterior $\theta_{i}$")
+        axes[i, 1].set_title(rf"Approximated posterior $\theta_{i}$")
+        for j in range(2):
+            axes[i, j].set_xlim(-5, 5)
+    sns.despine()
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # store_true combined with default=True would make the flag a no-op;
+    # BooleanOptionalAction (Python >= 3.9) keeps True as the default and
+    # additionally provides a --no-use-surjectors switch
+    parser.add_argument(
+        "--use-surjectors", action=argparse.BooleanOptionalAction, default=True
+    )
+    args = parser.parse_args()
+    run(args.use_surjectors)
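
Usage note: assuming the dependencies the example imports (`sbijax`,
`surjectors`, `distrax`, `dm-haiku`, `optax`) are installed, the new script
can be run directly. With the boolean-optional flag defined in its
`__main__` block, a hypothetical invocation would be:

    # SSNL, i.e. SNL with the surjective funnel layer (the default)
    python examples/slcp_ssnl.py

    # plain SNL baseline with a fully bijective masked autoregressive flow
    python examples/slcp_ssnl.py --no-use-surjectors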