From 26837ba0b4b99483f5062d79fb5b0a13d27eccff Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Mon, 27 May 2024 15:26:38 -0400 Subject: [PATCH] Fix a bug of plotter which contains inf (#165) * filter out inf * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nan filter * fix comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change to isfinite * fix a bug --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- appletree/plot.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/appletree/plot.py b/appletree/plot.py index e587c475..b59bfd9d 100644 --- a/appletree/plot.py +++ b/appletree/plot.py @@ -27,11 +27,23 @@ def __init__(self, backend_file_name, discard=0, thin=1): backend = emcee.backends.HDFBackend(self.backend_file_name, read_only=True) self.chain = backend.get_chain(discard=discard, thin=thin) - self.flat_chain = backend.get_chain(discard=discard, thin=thin, flat=True) self.posterior = backend.get_log_prob(discard=discard, thin=thin) - self.flat_posterior = backend.get_log_prob(discard=discard, thin=thin, flat=True) self.prior = backend.get_blobs(discard=discard, thin=thin) + # We drop iterations with inf and nan posterior + mask = np.isfinite(self.posterior) + mask = np.all(mask, axis=1) + self.chain = self.chain[mask] + self.posterior = self.posterior[mask] + self.prior = self.prior[mask] + + self.flat_chain = backend.get_chain(discard=discard, thin=thin, flat=True) + self.flat_posterior = backend.get_log_prob(discard=discard, thin=thin, flat=True) self.flat_prior = backend.get_blobs(discard=discard, thin=thin, flat=True) + # We drop samples with inf and nan posterior + mask = np.isfinite(self.flat_posterior) + self.flat_chain = self.flat_chain[mask] + self.flat_posterior = self.flat_posterior[mask] + self.flat_prior = self.flat_prior[mask] with h5py.File(self.backend_file_name, "r") as f: self.param_names = f["mcmc"].attrs["parameter_fit"] @@ -199,7 +211,8 @@ def plot_corner(self, fig=None): if fig is None: fig = plt.figure(figsize=(2 * (self.n_param + 2), 2 * (self.n_param + 2))) samples = np.concatenate( - (self.flat_chain, self.flat_posterior[:, None], self.flat_prior[:, None]), axis=1 + (self.flat_chain, self.flat_posterior[:, None], self.flat_prior[:, None]), + axis=1, ) labels = np.concatenate((self.param_names, ["log posterior", "log prior"]))