Skip to content

Commit

Permalink
Add utility for plotting a histogram of rhat vals for weights
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Nov 27, 2024
1 parent db9f3ba commit 49e861f
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions neurobayes/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
import jax.numpy as jnp

import numpy as np
import matplotlib.pyplot as plt

import numpyro

import warnings


def infer_device(device_preference: str = None):
"""
Returns a JAX device based on the specified preference.
Expand Down Expand Up @@ -244,7 +248,7 @@ def flatten_params_dict(params_dict: Dict[str, Any]) -> Dict[str, Any]:

def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
Transforms a given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
Expand Down Expand Up @@ -276,4 +280,11 @@ def set_fn(func: Callable) -> Callable:
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]
return local_namespace[func.__name__]


def plot_rhats(samples):
sgr = numpyro.diagnostics.split_gelman_rubin
rhats = [sgr(v).flatten() for (k, v) in samples.items() if k.endswith('kernel')]
rhats = np.concatenate(rhats)
plt.hist(rhats, bins=20, color='green', alpha=0.6);

0 comments on commit 49e861f

Please sign in to comment.