diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py index 85960f1..1b98fed 100644 --- a/neurobayes/utils/utils.py +++ b/neurobayes/utils/utils.py @@ -7,9 +7,13 @@ import jax.numpy as jnp import numpy as np +import matplotlib.pyplot as plt + +import numpyro import warnings + def infer_device(device_preference: str = None): """ Returns a JAX device based on the specified preference. @@ -244,7 +248,7 @@ def flatten_params_dict(params_dict: Dict[str, Any]) -> Dict[str, Any]: def set_fn(func: Callable) -> Callable: """ - Transforms the given deterministic function to use a params dictionary + Transforms a given deterministic function to use a params dictionary for its parameters, excluding the first one (assumed to be the dependent variable). Args: @@ -276,4 +280,11 @@ def set_fn(func: Callable) -> Callable: exec(transformed_code, globals(), local_namespace) # Return the transformed function - return local_namespace[func.__name__] \ No newline at end of file + return local_namespace[func.__name__] + + +def plot_rhats(samples): + sgr = numpyro.diagnostics.split_gelman_rubin + rhats = [sgr(v).flatten() for (k, v) in samples.items() if k.endswith('kernel')] + rhats = np.concatenate(rhats) + plt.hist(rhats, bins=20, color='green', alpha=0.6); \ No newline at end of file