diff --git a/appletree/plot.py b/appletree/plot.py index 00c3fa10..900c8071 100644 --- a/appletree/plot.py +++ b/appletree/plot.py @@ -160,18 +160,26 @@ def plot_marginal_posterior(self, fig=None, **hist_kwargs): axes = [] for i in range(self.n_param): ax = fig.add_subplot(n_rows, n_cols, i + 1) - ax.hist(self.flat_chain[:, i], density=True, **hist_kwargs) + ax.hist(self.flat_chain[:, i], density=True, label='Posterior', **hist_kwargs) prior = self.param_prior[self.param_names[i]] prior_type = prior["prior_type"] args = prior["prior_args"] if prior_type != "free": x = np.linspace(*ax.get_xlim(), 100) - ax.plot(x, pdf[prior_type](x, **args), color="grey", ls="--") + ax.plot(x, pdf[prior_type](x, **args), color="grey", ls="--", label='Prior') ax.set_xlabel(self.param_names[i]) ax.set_ylabel("PDF") ax.set_ylim(0, None) axes.append(ax) + # Set legend + handles, labels = axes[-1].get_legend_handles_labels() + fig.legend( + loc='lower center', + handles=handles, labels=labels, + bbox_to_anchor=(0.5, 1.0), + ) + plt.tight_layout() return fig, axes @@ -295,12 +303,19 @@ def autocorr_new(y, c=5.0): ax.set_xscale("log") ax.set_yscale("log") ax.set_ylabel(f"Auto correlation of {self.param_names[i]}") - ax.legend() axes.append(ax) # Set xlabels of the last two axes axes[-1].set_xlabel("Number of iterations") axes[-2].set_xlabel("Number of iterations") + # Set legend + handles, labels = ax[-1].get_legend_handles_labels() + fig.legend( + loc='lower center', + handles=handles, labels=labels, + bbox_to_anchor=(0.5, 1.0), + ) + plt.tight_layout() return fig, axes