From 4eae022cfb5cd8dc6a5930843cc7a7fab3a603e6 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 3 Oct 2023 19:03:09 +0200 Subject: [PATCH] Fix lints --- Untitled.ipynb | 442 -------------------------- examples/bivariate_gaussian_smcabc.py | 6 +- sbijax/generator.py | 3 +- 3 files changed, 5 insertions(+), 446 deletions(-) delete mode 100644 Untitled.ipynb diff --git a/Untitled.ipynb b/Untitled.ipynb deleted file mode 100644 index 9a51e15..0000000 --- a/Untitled.ipynb +++ /dev/null @@ -1,442 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "05a45ed4-97f7-4bba-9581-a37e978180a5", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "#os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00d2daa7-a509-47e1-86c6-752a75b47d0b", - "metadata": {}, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "import jax.random as jr\n", - "from jax import scipy as jsp\n", - "\n", - "import blackjax as bj\n", - "import tensorflow_probability.substrates.jax as tfp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c5c5bdce-fe4c-4fb3-a0c1-15911ba2b40a", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "from jax.lib import xla_bridge" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c522c600-31a4-405c-bd12-ebabcc77131d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[CpuDevice(id=0)]\n", - "cpu\n", - "1\n", - "cpu\n" - ] - } - ], - "source": [ - "print(jax.devices())\n", - "print(jax.default_backend())\n", - "print(jax.device_count())\n", - "print(xla_bridge.get_backend().platform)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8198dccd-d852-4dbc-a98f-486b8733a030", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "\n", - "import distrax\n", - "import haiku as hk\n", - "import jax\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "from jax import numpy as jnp\n", - "from jax import random\n", - "from jax import scipy as jsp\n", - "from jax import vmap\n", - "from functools import partial" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e090efd-dfff-4fd1-b421-71f0166ffe59", - "metadata": {}, - "outputs": [], - "source": [ - "def likelihood_fn(theta):\n", - " mu = jnp.tile(theta[:2], 4)\n", - " s1, s2 = theta[2] ** 2, theta[3] ** 2\n", - " corr = s1 * s2 * jnp.tanh(theta[4])\n", - " cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n", - " cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n", - " p = distrax.MultivariateNormalFullCovariance(mu, cov)\n", - " return p" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5bbedba0-346c-4604-ab7f-5318e9d3b15b", - "metadata": {}, - "outputs": [], - "source": [ - "lik = likelihood_fn(jr.normal(jr.PRNGKey(123), (5,)))\n", - "y = lik.sample(seed=jr.PRNGKey(1), sample_shape=(10,))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45482e72-6971-48cb-a58c-2a79be31b443", - "metadata": {}, - "outputs": [], - "source": [ - "def likelihood_fn(theta, y):\n", - " mu = jnp.tile(theta[:2], 4)\n", - " s1, s2 = theta[2] ** 2, theta[3] ** 2\n", - " corr = s1 * s2 * jnp.tanh(theta[4])\n", - " cov = jnp.array([[s1**2, corr], [corr, s2**2]])\n", - " cov = jsp.linalg.block_diag(*[cov for _ in range(4)])\n", - " p = distrax.MultivariateNormalFullCovariance(mu, cov)\n", - " return p.log_prob(y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e07a4d59-80f8-4303-982e-fabf2acf24c5", - "metadata": {}, - "outputs": [], - "source": [ - "prior = distrax.Independent(\n", - " distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d898ca02-22f3-4102-b667-c2ef7abdd492", - "metadata": {}, - "outputs": [], - "source": [ - "def log_density_fn(theta, y):\n", - " prior_lp = prior.log_prob(theta)\n", - " likelihood_lp = likelihood_fn(theta, y)\n", - "\n", - " lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)\n", - " return lp\n", - "\n", - "target_log_prob_fn_partial = partial(log_density_fn, y=y)\n", - "target_log_prob_fn = lambda theta: jax.vmap(target_log_prob_fn_partial)(theta)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55e0fa2c-6aaa-4ab6-922a-1b5f2559b5d1", - "metadata": {}, - "outputs": [], - "source": [ - "n_samples = 10000\n", - "n_warmup = 5000\n", - "n_chains = 4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "263b64ec-8640-4f61-8188-0f0953d53668", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'functools.partial' object has no attribute 'experimental_default_event_space_bijector'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtfp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmcmc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwindowed_adaptive_nuts\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget_log_prob_fn_partial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_warmup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_step_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_tree_depth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_energy_diff\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000.0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43munrolled_leapfrog_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mparallel_iterations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:710\u001b[0m, in \u001b[0;36mwindowed_adaptive_nuts\u001b[0;34m(n_draws, joint_dist, n_chains, num_adaptation_steps, current_state, init_step_size, dual_averaging_kwargs, max_tree_depth, max_energy_diff, unrolled_leapfrog_steps, parallel_iterations, trace_fn, return_final_kernel_results, discard_tuning, chain_axis_names, seed, **pins)\u001b[0m\n\u001b[1;32m 703\u001b[0m dual_averaging_kwargs\u001b[38;5;241m.\u001b[39msetdefault(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtarget_accept_prob\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m0.85\u001b[39m)\n\u001b[1;32m 704\u001b[0m proposal_kernel_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 705\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep_size\u001b[39m\u001b[38;5;124m'\u001b[39m: init_step_size,\n\u001b[1;32m 706\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_tree_depth\u001b[39m\u001b[38;5;124m'\u001b[39m: max_tree_depth,\n\u001b[1;32m 707\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_energy_diff\u001b[39m\u001b[38;5;124m'\u001b[39m: max_energy_diff,\n\u001b[1;32m 708\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124munrolled_leapfrog_steps\u001b[39m\u001b[38;5;124m'\u001b[39m: unrolled_leapfrog_steps,\n\u001b[1;32m 709\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparallel_iterations\u001b[39m\u001b[38;5;124m'\u001b[39m: parallel_iterations}\n\u001b[0;32m--> 710\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_windowed_adaptive_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 711\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_draws\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_draws\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 712\u001b[0m \u001b[43m \u001b[49m\u001b[43mjoint_dist\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjoint_dist\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 713\u001b[0m \u001b[43m \u001b[49m\u001b[43mkind\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnuts\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 714\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_chains\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal_kernel_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproposal_kernel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 716\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 717\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_adaptation_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 718\u001b[0m \u001b[43m \u001b[49m\u001b[43mdual_averaging_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdual_averaging_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 719\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrace_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 720\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_final_kernel_results\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_final_kernel_results\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 721\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiscard_tuning\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdiscard_tuning\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 722\u001b[0m \u001b[43m \u001b[49m\u001b[43mchain_axis_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchain_axis_names\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 723\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 724\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpins\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:890\u001b[0m, in \u001b[0;36m_windowed_adaptive_impl\u001b[0;34m(n_draws, joint_dist, kind, n_chains, proposal_kernel_kwargs, num_adaptation_steps, current_state, dual_averaging_kwargs, trace_fn, return_final_kernel_results, discard_tuning, seed, chain_axis_names, **pins)\u001b[0m\n\u001b[1;32m 885\u001b[0m dual_averaging_kwargs\u001b[38;5;241m.\u001b[39msetdefault(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mexperimental_reduce_chain_axis_names\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 886\u001b[0m chain_axis_names)\n\u001b[1;32m 887\u001b[0m setup_seed, sample_seed \u001b[38;5;241m=\u001b[39m samplers\u001b[38;5;241m.\u001b[39msplit_seed(\n\u001b[1;32m 888\u001b[0m samplers\u001b[38;5;241m.\u001b[39msanitize_seed(seed), n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 889\u001b[0m (target_log_prob_fn, initial_transformed_position, bijector,\n\u001b[0;32m--> 890\u001b[0m step_broadcast, batch_shape, shard_axis_names) \u001b[38;5;241m=\u001b[39m \u001b[43m_setup_mcmc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 891\u001b[0m \u001b[43m \u001b[49m\u001b[43mjoint_dist\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 892\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_chains\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 893\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcurrent_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 894\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msetup_seed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 895\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpins\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 897\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m proposal_kernel_kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep_size\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 898\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_shape\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m (\u001b[38;5;241m0\u001b[39m,): \u001b[38;5;66;03m# Scalar batch has a 0-vector shape.\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:226\u001b[0m, in \u001b[0;36m_setup_mcmc\u001b[0;34m(model, n_chains, init_position, seed, **pins)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Construct bijector and transforms needed for windowed MCMC.\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \n\u001b[1;32m 194\u001b[0m \u001b[38;5;124;03mThis pins the initial model, constructs a bijector that unconstrains and\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;124;03m shard_axis_names: Shard axis names for the model\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 225\u001b[0m pinned_model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mexperimental_pin(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpins) \u001b[38;5;28;01mif\u001b[39;00m pins \u001b[38;5;28;01melse\u001b[39;00m model\n\u001b[0;32m--> 226\u001b[0m bijector, step_bijector \u001b[38;5;241m=\u001b[39m \u001b[43m_get_flat_unconstraining_bijector\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpinned_model\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m init_position \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 229\u001b[0m raw_init_dist \u001b[38;5;241m=\u001b[39m initialization\u001b[38;5;241m.\u001b[39minit_near_unconstrained_zero(pinned_model)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:172\u001b[0m, in \u001b[0;36m_get_flat_unconstraining_bijector\u001b[0;34m(jd_model)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Create a bijector from a joint distribution that flattens and unconstrains.\u001b[39;00m\n\u001b[1;32m 153\u001b[0m \n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03mThe intention is (loosely) to go from a model joint distribution supported on\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;124;03m points, and the second may be used to initialize a step size.\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# TODO(b/180396233): This bijector is in general point-dependent.\u001b[39;00m\n\u001b[0;32m--> 172\u001b[0m event_space_bij \u001b[38;5;241m=\u001b[39m \u001b[43mjd_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental_default_event_space_bijector\u001b[49m()\n\u001b[1;32m 173\u001b[0m flat_bijector \u001b[38;5;241m=\u001b[39m restructure\u001b[38;5;241m.\u001b[39mpack_sequence_as(jd_model\u001b[38;5;241m.\u001b[39mevent_shape_tensor())\n\u001b[1;32m 175\u001b[0m unconstrained_shapes \u001b[38;5;241m=\u001b[39m event_space_bij(\n\u001b[1;32m 176\u001b[0m flat_bijector)\u001b[38;5;241m.\u001b[39minverse_event_shape_tensor(jd_model\u001b[38;5;241m.\u001b[39mevent_shape_tensor())\n", - "\u001b[0;31mAttributeError\u001b[0m: 'functools.partial' object has no attribute 'experimental_default_event_space_bijector'" - ] - } - ], - "source": [ - "nuts = tfp.mcmc.NoUTurnSampler(\n", - " target_log_prob_fn,\n", - " step_size=0.1,\n", - " max_tree_depth=10,\n", - " max_energy_diff=1000.0,\n", - " unrolled_leapfrog_steps=1,\n", - ")\n", - "\n", - "\n", - "nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(\n", - " inner_kernel=nuts,\n", - " num_adaptation_steps=int(0.8 * n_warmup),\n", - " target_accept_prob=jnp.asarray(0.75, jnp.float32)\n", - ")\n", - "\n", - "tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(\n", - " nuts,\n", - " initial_running_variance,\n", - " num_estimation_steps=None,\n", - " momentum_distribution_setter_fn=hmc_like_momentum_distribution_setter_fn,\n", - " momentum_distribution_getter_fn=hmc_like_momentum_distribution_getter_fn,\n", - " validate_args=False,\n", - " experimental_shard_axis_names=None,\n", - " name=None\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d826bb25-929a-4273-aab6-d281a54a0844", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.device_count()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b59b71ae-963c-43f1-ab00-4d3881354864", - "metadata": {}, - "outputs": [], - "source": [ - "initial_states = jr.normal(jr.PRNGKey(4), shape=(n_chains, 5))\n", - "samples = tfp.mcmc.sample_chain(\n", - " num_results=n_samples - n_warmup,\n", - " current_state=initial_states,\n", - " num_steps_between_results=1,\n", - " kernel=adaptive_sampler,\n", - " num_burnin_steps=n_warmup,\n", - " trace_fn=None,\n", - " seed=jr.PRNGKey(2),\n", - ")\n", - "samples = samples[n_warmup:, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "99db6c8b-49db-435e-a886-e5968106c512", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "e6d8884c-1153-48b3-867a-cac4020e344b", - "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[24], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msfunc\u001b[39m(x): \n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m: \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43msfunc\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/api.py:2253\u001b[0m, in \u001b[0;36m_cpp_pmap..cache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 2251\u001b[0m execute: Optional[functools\u001b[38;5;241m.\u001b[39mpartial] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2252\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(top_trace, core\u001b[38;5;241m.\u001b[39mEvalTrace):\n\u001b[0;32m-> 2253\u001b[0m execute \u001b[38;5;241m=\u001b[39m \u001b[43mpxla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mxla_pmap_impl_lazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2254\u001b[0m out \u001b[38;5;241m=\u001b[39m map_bind_continuation(execute(\u001b[38;5;241m*\u001b[39mtracers))\n\u001b[1;32m 2255\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:974\u001b[0m, in \u001b[0;36mxla_pmap_impl_lazy\u001b[0;34m(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)\u001b[0m\n\u001b[1;32m 972\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _emap_apply_fn\n\u001b[1;32m 973\u001b[0m abstract_args \u001b[38;5;241m=\u001b[39m unsafe_map(xla\u001b[38;5;241m.\u001b[39mabstractify, args)\n\u001b[0;32m--> 974\u001b[0m compiled_fun, fingerprint \u001b[38;5;241m=\u001b[39m \u001b[43mparallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 975\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_axis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 976\u001b[0m \u001b[43m \u001b[49m\u001b[43min_axes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_axes_thunk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mabstract_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# Don't re-abstractify args unless logging is enabled for performance.\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39mjax_distributed_debug:\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/linear_util.py:303\u001b[0m, in \u001b[0;36mcache..memoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 301\u001b[0m fun\u001b[38;5;241m.\u001b[39mpopulate_stores(stores)\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 303\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 304\u001b[0m cache[key] \u001b[38;5;241m=\u001b[39m (ans, fun\u001b[38;5;241m.\u001b[39mstores)\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ans\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1245\u001b[0m, in \u001b[0;36mparallel_callable\u001b[0;34m(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *avals)\u001b[0m\n\u001b[1;32m 1232\u001b[0m \u001b[38;5;129m@lu\u001b[39m\u001b[38;5;241m.\u001b[39mcache\n\u001b[1;32m 1233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_callable\u001b[39m(fun: lu\u001b[38;5;241m.\u001b[39mWrappedFun,\n\u001b[1;32m 1234\u001b[0m backend_name: Optional[\u001b[38;5;28mstr\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1243\u001b[0m global_arg_shapes: Sequence[Optional[Tuple[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]]],\n\u001b[1;32m 1244\u001b[0m \u001b[38;5;241m*\u001b[39mavals):\n\u001b[0;32m-> 1245\u001b[0m pmap_computation \u001b[38;5;241m=\u001b[39m \u001b[43mlower_parallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1246\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_axis_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1247\u001b[0m \u001b[43m \u001b[49m\u001b[43min_axes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_axes_thunk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mavals\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1248\u001b[0m pmap_executable \u001b[38;5;241m=\u001b[39m pmap_computation\u001b[38;5;241m.\u001b[39mcompile()\n\u001b[1;32m 1249\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WeakRefList([pmap_executable\u001b[38;5;241m.\u001b[39munsafe_call, pmap_executable\u001b[38;5;241m.\u001b[39mfingerprint])\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1414\u001b[0m, in \u001b[0;36mlower_parallel_callable\u001b[0;34m(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)\u001b[0m\n\u001b[1;32m 1409\u001b[0m must_run_on_all_devices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 1411\u001b[0m pci \u001b[38;5;241m=\u001b[39m ParallelCallableInfo(\n\u001b[1;32m 1412\u001b[0m name, backend, axis_name, axis_size, global_axis_size, devices,\n\u001b[1;32m 1413\u001b[0m in_axes, out_axes_thunk, avals)\n\u001b[0;32m-> 1414\u001b[0m jaxpr, consts, replicas, parts, shards \u001b[38;5;241m=\u001b[39m \u001b[43mstage_parallel_callable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1415\u001b[0m \u001b[43m \u001b[49m\u001b[43mpci\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_arg_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1417\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m logger\u001b[38;5;241m.\u001b[39misEnabledFor(logging\u001b[38;5;241m.\u001b[39mDEBUG):\n\u001b[1;32m 1418\u001b[0m logger\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msharded_avals: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, shards\u001b[38;5;241m.\u001b[39msharded_avals)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/pxla.py:1321\u001b[0m, in \u001b[0;36mstage_parallel_callable\u001b[0;34m(pci, fun, global_arg_shapes)\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m core\u001b[38;5;241m.\u001b[39mextend_axis_env(pci\u001b[38;5;241m.\u001b[39maxis_name, pci\u001b[38;5;241m.\u001b[39mglobal_axis_size, \u001b[38;5;28;01mNone\u001b[39;00m): \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 1318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m dispatch\u001b[38;5;241m.\u001b[39mlog_elapsed_time(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFinished tracing + transforming \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1319\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfor pmap in \u001b[39m\u001b[38;5;132;01m{elapsed_time}\u001b[39;00m\u001b[38;5;124m sec\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1320\u001b[0m event\u001b[38;5;241m=\u001b[39mdispatch\u001b[38;5;241m.\u001b[39mJAXPR_TRACE_EVENT):\n\u001b[0;32m-> 1321\u001b[0m jaxpr, out_sharded_avals, consts \u001b[38;5;241m=\u001b[39m \u001b[43mpe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrace_to_jaxpr_final\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1322\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_sharded_avals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdebug_info_final\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpmap\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1323\u001b[0m jaxpr \u001b[38;5;241m=\u001b[39m dispatch\u001b[38;5;241m.\u001b[39mapply_outfeed_rewriter(jaxpr)\n\u001b[1;32m 1325\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out_sharded_avals) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(pci\u001b[38;5;241m.\u001b[39mout_axes), (\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28mlen\u001b[39m(out_sharded_avals), \u001b[38;5;28mlen\u001b[39m(pci\u001b[38;5;241m.\u001b[39mout_axes))\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:2065\u001b[0m, in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info, keep_inputs)\u001b[0m\n\u001b[1;32m 2063\u001b[0m main\u001b[38;5;241m.\u001b[39mjaxpr_stack \u001b[38;5;241m=\u001b[39m () \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 2064\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m core\u001b[38;5;241m.\u001b[39mnew_sublevel():\n\u001b[0;32m-> 2065\u001b[0m jaxpr, out_avals, consts \u001b[38;5;241m=\u001b[39m \u001b[43mtrace_to_subjaxpr_dynamic\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2066\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_avals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdebug_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdebug_info\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2067\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m fun, main\n\u001b[1;32m 2068\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jaxpr, out_avals, consts\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1998\u001b[0m, in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals, keep_inputs, debug_info)\u001b[0m\n\u001b[1;32m 1996\u001b[0m in_tracers \u001b[38;5;241m=\u001b[39m _input_type_to_tracers(trace\u001b[38;5;241m.\u001b[39mnew_arg, in_avals)\n\u001b[1;32m 1997\u001b[0m in_tracers_ \u001b[38;5;241m=\u001b[39m [t \u001b[38;5;28;01mfor\u001b[39;00m t, keep \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(in_tracers, keep_inputs) \u001b[38;5;28;01mif\u001b[39;00m keep]\n\u001b[0;32m-> 1998\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[43mfun\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_wrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43min_tracers_\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1999\u001b[0m out_tracers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(trace\u001b[38;5;241m.\u001b[39mfull_raise, ans)\n\u001b[1;32m 2000\u001b[0m jaxpr, consts \u001b[38;5;241m=\u001b[39m frame\u001b[38;5;241m.\u001b[39mto_jaxpr(out_tracers)\n", - "File \u001b[0;32m~/miniconda3/envs/sbi-dev/lib/python3.9/site-packages/jax/linear_util.py:167\u001b[0m, in \u001b[0;36mWrappedFun.call_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m gen \u001b[38;5;241m=\u001b[39m gen_static_args \u001b[38;5;241m=\u001b[39m out_store \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 167\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# Some transformations yield from inside context managers, so we have to\u001b[39;00m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# interrupt them before reraising the exception. Otherwise they will only\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# get garbage-collected at some later time, running their cleanup tasks\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# only after this exception is handled, which can corrupt the global\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;66;03m# state.\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m stack:\n", - "Cell \u001b[0;32mIn[24], line 2\u001b[0m, in \u001b[0;36msfunc\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msfunc\u001b[39m(x): \n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m: \u001b[38;5;28;01mpass\u001b[39;00m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "def sfunc(x): \n", - " while True: pass\n", - "\n", - "jax.pmap(sfunc)(jnp.arange(4))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "9032ad64-c70e-4acb-9541-8ef14a36da2b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "numpy: cpu usage 0.6/8 wall_time:1.7s\n", - "vmap: cpu usage 1.8/8 wall_time:1.9s\n", - "dot: cpu usage 7.0/8 wall_time:3.7s\n" - ] - } - ], - "source": [ - "import time, os, jax, numpy as np, jax.numpy as jnp\n", - "jax.config.update('jax_platform_name', 'cpu') # insures we use the CPU\n", - "\n", - "def timer(name, f, x, shouldBlock=True):\n", - " # warmup\n", - " y = f(x).block_until_ready() if shouldBlock else f(x)\n", - " # running the code\n", - " start_wall = time.perf_counter()\n", - " start_cpu = time.process_time()\n", - " y = f(x).block_until_ready() if shouldBlock else f(x)\n", - " end_wall = time.perf_counter()\n", - " end_cpu = time.process_time()\n", - " # computing the metric and displaying it\n", - " wall_time = end_wall - start_wall\n", - " cpu_time = end_cpu - start_cpu\n", - " cpu_count = os.cpu_count()\n", - " print(f\"{name}: cpu usage {cpu_time/wall_time:.1f}/{cpu_count} wall_time:{wall_time:.1f}s\")\n", - "\n", - "# test functions\n", - "key = jax.random.PRNGKey(0)\n", - "x = jax.random.normal(key, shape=(500000000,), dtype=jnp.float64)\n", - "x_mat = jax.random.normal(key, shape=(10000,10000), dtype=jnp.float64)\n", - "f_numpy = np.cos\n", - "f_vmap = jax.jit(jax.vmap(jnp.cos))\n", - "f_dot = jax.jit(lambda x: jnp.dot(x,x.T)) # to show that JAX can indeed use all cores\n", - "\n", - "timer('numpy', f_numpy, x, shouldBlock=False)\n", - "timer('vmap', f_vmap, x)\n", - "timer('dot', f_dot, x_mat)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "abf84ca0-52d8-4469-82d5-1e9099a1d05b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.cpu_count()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "1cef2b44-f3ac-43e0-ba1e-604c8fa8bfc4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" - ] - } - ], - "source": [ - "from jax import local_device_count\n", - "print(jax.local_devices()) #1 instead of 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dd7daeb-449e-4111-9170-f7937b495991", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "sbi-dev", - "language": "python", - "name": "sbi-dev" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/bivariate_gaussian_smcabc.py b/examples/bivariate_gaussian_smcabc.py index 6a1014e..99d2e37 100644 --- a/examples/bivariate_gaussian_smcabc.py +++ b/examples/bivariate_gaussian_smcabc.py @@ -6,7 +6,8 @@ import jax import matplotlib.pyplot as plt import seaborn as sns -from jax import numpy as jnp, random as jr +from jax import numpy as jnp +from jax import random as jr from sbijax import SMCABC @@ -43,8 +44,7 @@ def run(): smc = SMCABC(fns, summary_fn, distance_fn) smc_samples, _ = smc.sample_posterior( - jr.PRNGKey(22), y_observed, - 10, 1000, 1000, 0.6, 500 + jr.PRNGKey(22), y_observed, 10, 1000, 1000, 0.6, 500 ) fig, axes = plt.subplots(2) diff --git a/sbijax/generator.py b/sbijax/generator.py index 0f0c7bc..9ba321d 100644 --- a/sbijax/generator.py +++ b/sbijax/generator.py @@ -2,7 +2,8 @@ import chex from jax import lax -from jax import numpy as jnp, random as jr +from jax import numpy as jnp +from jax import random as jr named_dataset = namedtuple("named_dataset", "y theta")