From b064bda3dce18b47522029b0f1184eb096385274 Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Wed, 6 Mar 2024 02:59:44 -0500 Subject: [PATCH] Plotter for MCMC diagnostics (#146) * add functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more documentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rm duplicated import * change to isinstance * modify * remove duplicated legends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix bug * add kwargs * set power * Add test of plot functions * Remove redundant codes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: dachengx --- .gitignore | 1 + appletree/__init__.py | 4 +- appletree/plot.py | 327 ++++++++++++++++++++++++++++++++++++++++++ tests/test_context.py | 4 +- tests/test_plot.py | 19 +++ 5 files changed, 352 insertions(+), 3 deletions(-) create mode 100644 appletree/plot.py create mode 100644 tests/test_plot.py diff --git a/.gitignore b/.gitignore index c77f53a3..c9cf8f36 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ *.egg-info *.eggs *.svg +*.h5 build docs/build docs/source/_build diff --git a/appletree/__init__.py b/appletree/__init__.py index 6589b272..cac0b972 100644 --- a/appletree/__init__.py +++ b/appletree/__init__.py @@ -35,6 +35,8 @@ from .context import * +from .plot import * + # check CUDA support setup from warnings import warn @@ -59,4 +61,4 @@ print("Using aptext package from https://github.com/XENONnT/applefiles") except ImportError: HAVE_APTEXT = False - print("Can not find aptext") + print("Cannot find aptext") diff --git a/appletree/plot.py b/appletree/plot.py new file mode 100644 index 00000000..e587c475 --- /dev/null 
class Plotter:
    """Diagnostic plots for an MCMC chain stored in an emcee HDF5 backend file."""

    def __init__(self, backend_file_name, discard=0, thin=1):
        """Plotter for the MCMC chain.

        Args:
            backend_file_name: the file name of the backend file.
            discard: the number of iterations to discard.
            thin: use samples every thin steps.

        """
        self.backend_file_name = backend_file_name
        backend = emcee.backends.HDFBackend(self.backend_file_name, read_only=True)

        # Keep both the (iteration, walker, ...) view and the flattened view of each quantity.
        self.chain = backend.get_chain(discard=discard, thin=thin)
        self.flat_chain = backend.get_chain(discard=discard, thin=thin, flat=True)
        self.posterior = backend.get_log_prob(discard=discard, thin=thin)
        self.flat_posterior = backend.get_log_prob(discard=discard, thin=thin, flat=True)
        # The blobs stored in the backend are used as the log prior throughout this class.
        self.prior = backend.get_blobs(discard=discard, thin=thin)
        self.flat_prior = backend.get_blobs(discard=discard, thin=thin, flat=True)

        # Parameter names and prior configuration are read from attributes of the "mcmc" group.
        with h5py.File(self.backend_file_name, "r") as f:
            self.param_names = f["mcmc"].attrs["parameter_fit"]
            self.param_prior = json.loads(f["mcmc"].attrs["par_config"])

        # Sample with the highest log posterior among all kept samples.
        param_mpe = self.flat_chain[np.argmax(self.flat_posterior), :]
        self.param_mpe = {key: param_mpe[i] for i, key in enumerate(self.param_names)}

        self.n_iter, self.n_walker, self.n_param = self.chain.shape

    def make_all_plots(self, save=False, save_path=".", fmt=("png", "pdf"), **save_kwargs):
        """Make all plots and save them if save is True.

        The plot styles are default. save_kwargs will be passed to fig.savefig().

        Args:
            save: whether to write each figure to disk.
            save_path: the directory in which the figures are saved.
            fmt: a file extension or an iterable of file extensions to save as.
            save_kwargs: the keyword arguments passed to fig.savefig().

        """
        # A single extension given as a string is treated as a one-element list.
        formats = [fmt] if isinstance(fmt, str) else list(fmt)

        plots = (
            ("burn_in", self.plot_burn_in),
            ("marginal_posterior", self.plot_marginal_posterior),
            ("corner", self.plot_corner),
            ("autocorr", self.plot_autocorr),
        )
        for name, make_plot in plots:
            fig, _ = make_plot()
            if save:
                for ext in formats:
                    fig.savefig(f"{save_path}/{name}.{ext}", **save_kwargs)

    @staticmethod
    def _norm_pdf(x, mean, std):
        """Return the normal density with the given mean and std evaluated at x."""
        return np.exp(-((x - mean) ** 2) / std**2 / 2) / np.sqrt(2 * np.pi) / std

    @staticmethod
    def _uniform_pdf(x, lower, upper):
        """Return the constant density 1 / (upper - lower) with the shape of x.

        The value is not set to zero outside [lower, upper].

        """
        # Force a float result: full_like would otherwise inherit an integer dtype
        # from x and truncate the density to 0.
        return np.full_like(x, 1 / (upper - lower), dtype=float)

    @staticmethod
    def _thn_pdf(x, mu, sigma_pos, sigma_neg):
        """Return the two-half-normal density evaluated at x."""
        # Convert errors to sigmas
        sigma_pos, sigma_neg = errors_to_two_half_norm_sigmas((sigma_pos, sigma_neg))
        return np.exp(TwoHalfNorm.logpdf(x, mu, sigma_pos, sigma_neg))

    @classmethod
    def _prior_pdf(cls, prior, x):
        """Evaluate the prior described by a parameter configuration entry at x.

        Args:
            prior: a dict with the keys "prior_type" and "prior_args".
            x: the points at which the density is evaluated.
        Returns:
            The density at x, or None if the prior type has no density defined
            in this class (e.g. "free").

        """
        pdfs = {
            "norm": cls._norm_pdf,
            "uniform": cls._uniform_pdf,
            "twohalfnorm": cls._thn_pdf,
        }
        pdf = pdfs.get(prior["prior_type"])
        if pdf is None:
            return None
        return pdf(x, **prior["prior_args"])

    @staticmethod
    def _unique_legend_entries(axes):
        """Collect legend handles and labels over all axes, keeping the first of each label."""
        entries = {}
        for ax in axes:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                entries.setdefault(label, handle)
        return list(entries.values()), list(entries.keys())

    @staticmethod
    def _next_pow_two(n):
        """Return the smallest power of two that is not smaller than n (at least 1)."""
        i = 1
        while i < n:
            i = i << 1
        return i

    @staticmethod
    def _autocorr_func_1d(x, norm=True):
        """Return the FFT-based autocorrelation function of a 1D series.

        Args:
            x: the series; it must be one-dimensional.
            norm: if True, normalize so that the value at zero lag is 1.
        Returns:
            An array with the same length as x.
        Raises:
            ValueError: if x is not one-dimensional.

        """
        x = np.atleast_1d(x)
        if len(x.shape) != 1:
            raise ValueError("invalid dimensions for 1D autocorrelation function")
        n = Plotter._next_pow_two(len(x))
        # Transform with length 2 * n so the series is zero-padded.
        f = np.fft.fft(x - np.mean(x), n=2 * n)
        acf = np.fft.ifft(f * np.conjugate(f))[: len(x)].real
        acf /= 4 * n
        if norm:
            acf /= acf[0]
        return acf

    @staticmethod
    def _auto_window(taus, c):
        """Return the first index m at which m < c * taus[m] fails.

        If the condition holds nowhere, the last index is returned.

        """
        m = np.arange(len(taus)) < c * taus
        # NOTE(review): when the condition holds at every index, np.argmin returns 0
        # rather than the last index; kept as in the original implementation.
        if np.any(m):
            return np.argmin(m)
        return len(taus) - 1

    @staticmethod
    def _autocorr_time(y, c=5.0):
        """Estimate the integrated autocorrelation time from several series.

        Args:
            y: a 2D array; each row is one series (here, one walker).
            c: the factor used when selecting the summation window.
        Returns:
            The estimated autocorrelation time.

        """
        f = np.zeros(y.shape[1])
        for yy in y:
            f += Plotter._autocorr_func_1d(yy)
        f /= len(y)
        taus = 2.0 * np.cumsum(f) - 1.0
        window = Plotter._auto_window(taus, c)
        return taus[window]

    def plot_burn_in(self, fig=None, **plot_kwargs):
        """Plot the burn-in of the chain, the log posterior and the log prior.

        Args:
            fig: the figure to plot on. If None, a new figure will be created.
            plot_kwargs: the keyword arguments passed to plt.plot().
        Returns:
            fig: the figure.
            axes: the axes of the figure.

        """
        n_cols = 2
        # Two extra panels hold the log posterior and the log prior.
        n_rows = int(np.ceil((self.n_param + 2) / n_cols))

        if fig is None:
            fig = plt.figure(figsize=(10, 1.5 * n_rows))
        plot_kwargs.setdefault("lw", 0.1)

        axes = []
        for i in range(self.n_param):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.plot(self.chain[:, :, i], **plot_kwargs)
            ax.set_ylabel(self.param_names[i])
            ax.set_xlim(0, self.n_iter)
            axes.append(ax)

        log_probs = ((self.posterior, "log posterior"), (self.prior, "log prior"))
        for offset, (values, label) in enumerate(log_probs, start=1):
            ax = fig.add_subplot(n_rows, n_cols, self.n_param + offset)
            ax.plot(values, **plot_kwargs)
            ax.set_ylabel(label)
            ax.set_xlim(0, self.n_iter)
            # Only show the 100 units below the maximum.
            ax.set_ylim(values.max() - 100, values.max())
            axes.append(ax)

        # Set xlabels of the last two axes
        for ax in axes[-2:]:
            ax.set_xlabel("Number of iterations")

        # Lay out the figure that was drawn on, which need not be pyplot's current figure.
        fig.tight_layout()
        return fig, axes

    def plot_marginal_posterior(self, fig=None, **hist_kwargs):
        """Plot the marginal posterior distribution of each parameter.

        The prior is overlaid for every parameter whose prior type has a density
        defined in this class.

        Args:
            fig: the figure to plot on. If None, a new figure will be created.
            hist_kwargs: the keyword arguments passed to plt.hist().
        Returns:
            fig: the figure.
            axes: the axes of the figure.

        """
        n_cols = 2
        n_rows = int(np.ceil(self.n_param / n_cols))

        if fig is None:
            fig = plt.figure(figsize=(10, 2 * n_rows))
        hist_kwargs.setdefault("histtype", "step")
        hist_kwargs.setdefault("bins", 50)
        hist_kwargs.setdefault("color", "k")

        axes = []
        for i in range(self.n_param):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.hist(self.flat_chain[:, i], density=True, label="Posterior", **hist_kwargs)
            # Evaluate the prior over the x range chosen by the histogram.
            x = np.linspace(*ax.get_xlim(), 100)
            density = self._prior_pdf(self.param_prior[self.param_names[i]], x)
            if density is not None:
                ax.plot(x, density, color="grey", ls="--", label="Prior")
            ax.set_xlabel(self.param_names[i])
            ax.set_ylabel("PDF")
            ax.set_ylim(0, None)
            ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
            axes.append(ax)

        # Set legend; gather entries from every panel so "Prior" is listed even if
        # the last panel has no prior curve.
        handles, labels = self._unique_legend_entries(axes)
        if handles:
            fig.legend(
                loc="lower center",
                handles=handles,
                labels=labels,
                bbox_to_anchor=(0.5, 1.0),
            )

        fig.tight_layout()
        return fig, axes

    def plot_corner(self, fig=None):
        """Plot the corner plot of the chain, the log posterior and the log prior.

        Off-diagonal panels are shaded by the correlation coefficient of the pair.
        Diagonal panels of fitted parameters show the maximum-posterior sample in
        red and, where a density is defined, the prior in blue.

        Args:
            fig: the figure to plot on. If None, a new figure will be created.
        Returns:
            fig: the figure.
            axes: the axes of the figure.

        """
        n_dim = self.n_param + 2
        if fig is None:
            fig = plt.figure(figsize=(2 * n_dim, 2 * n_dim))
        samples = np.concatenate(
            (self.flat_chain, self.flat_posterior[:, None], self.flat_prior[:, None]), axis=1
        )
        labels = np.concatenate((self.param_names, ["log posterior", "log prior"]))

        corner.corner(
            samples,
            labels=labels,
            quantiles=norm.cdf([-1, 0, 1]),
            hist_kwargs={"density": True},
            fig=fig,
        )

        axes = np.array(fig.axes).reshape((n_dim, n_dim))
        corr_matrix = np.corrcoef(samples, rowvar=False)
        normalize = matplotlib.colors.Normalize(vmin=-1, vmax=1)
        m = cm.ScalarMappable(norm=normalize, cmap=cm.coolwarm)

        # Shade the lower triangle by correlation.
        for yi in range(n_dim):
            for xi in range(yi):
                axes[yi, xi].set_facecolor(m.to_rgba(corr_matrix[yi, xi], alpha=0.5))

        for i, key in enumerate(labels):
            if key not in self.param_prior:
                continue
            ax = axes[i, i]
            x = np.linspace(*ax.get_xbound(), 101)
            if key in self.param_names:
                ax.axvline(self.param_mpe[key], color="r")
            density = self._prior_pdf(self.param_prior[key], x)
            if density is not None:
                ax.plot(x, density, color="b")

        return fig, axes

    def plot_autocorr(self, fig=None, **plot_kwargs):
        """Plot the autocorrelation time of each parameter, as the diagnostic of the convergence.

        The estimate is computed on the first N iterations for 10 geometrically
        spaced values of N and compared with the line N / 50. A warning is issued
        if the chain has fewer than 1000 iterations.

        Args:
            fig: the figure to plot on. If None, a new figure will be created.
            plot_kwargs: the keyword arguments passed to plt.plot().
        Returns:
            fig: the figure.
            axes: the axes of the figure.

        """
        n_cols = 2
        n_rows = int(np.ceil(self.n_param / n_cols))

        if fig is None:
            fig = plt.figure(figsize=(10, 3 * n_rows))
        plot_kwargs.setdefault("marker", "o")

        if self.n_iter < 1000:
            warn("The chain is too short to compute the autocorrelation time!")

        N = np.geomspace(100, self.n_iter, 10).astype(int)
        axes = []
        for i in range(self.n_param):
            # Transpose so that each row is the series of one walker.
            chain = self.chain[:, :, i].T
            tau = np.array([self._autocorr_time(chain[:, :n]) for n in N], dtype=float)

            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.plot(N, tau, label="Sample estimation", **plot_kwargs)
            ax.plot(N, N / 50, "k--", label="N / 50")
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_ylabel(f"Auto correlation of {self.param_names[i]}")
            axes.append(ax)

        # Set xlabels of the last two axes; slicing also covers a single-parameter fit.
        for ax in axes[-2:]:
            ax.set_xlabel("Number of iterations")

        # Set legend
        handles, labels = self._unique_legend_entries(axes)
        if handles:
            fig.legend(
                loc="lower center",
                handles=handles,
                labels=labels,
                bbox_to_anchor=(0.5, 1.0),
            )

        fig.tight_layout()
        return fig, axes
def test_plot():
    """Run a short Rn220 fit and draw every diagnostic plot from its backend file."""
    backend_h5 = "rn220.h5"

    # Point the fit at a local HDF5 backend so the plotter can read the chain back.
    config = load_json("rn220.json")
    config["backend_h5"] = backend_h5

    ctx = apt.Context(config)
    ctx.print_context_summary(short=False)
    ctx.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))

    Plotter(backend_h5).make_all_plots()