From 49e861f1946cb6066e9dfd46cf5d2d4ef4f77870 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov Date: Tue, 26 Nov 2024 18:58:56 -0800 Subject: [PATCH] Add utility for plotting a histogram of rhat vals for weights --- neurobayes/utils/utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py index 85960f1..1b98fed 100644 --- a/neurobayes/utils/utils.py +++ b/neurobayes/utils/utils.py @@ -7,9 +7,13 @@ import jax.numpy as jnp import numpy as np +import matplotlib.pyplot as plt + +import numpyro import warnings + def infer_device(device_preference: str = None): """ Returns a JAX device based on the specified preference. @@ -244,7 +248,7 @@ def flatten_params_dict(params_dict: Dict[str, Any]) -> Dict[str, Any]: def set_fn(func: Callable) -> Callable: """ - Transforms the given deterministic function to use a params dictionary + Transforms a given deterministic function to use a params dictionary for its parameters, excluding the first one (assumed to be the dependent variable). Args: @@ -276,4 +280,11 @@ def set_fn(func: Callable) -> Callable: exec(transformed_code, globals(), local_namespace) # Return the transformed function - return local_namespace[func.__name__] \ No newline at end of file + return local_namespace[func.__name__] + + +def plot_rhats(samples): + sgr = numpyro.diagnostics.split_gelman_rubin + rhats = [sgr(v).flatten() for (k, v) in samples.items() if k.endswith('kernel')] + rhats = np.concatenate(rhats) + plt.hist(rhats, bins=20, color='green', alpha=0.6); \ No newline at end of file