Skip to content

Commit

Permalink
remove duplicated legends
Browse files Browse the repository at this point in the history
  • Loading branch information
zihaoxu98 committed Mar 6, 2024
1 parent f499352 commit 88343a0
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions appletree/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,26 @@ def plot_marginal_posterior(self, fig=None, **hist_kwargs):
axes = []
for i in range(self.n_param):
ax = fig.add_subplot(n_rows, n_cols, i + 1)
ax.hist(self.flat_chain[:, i], density=True, **hist_kwargs)
ax.hist(self.flat_chain[:, i], density=True, label='Posterior', **hist_kwargs)
prior = self.param_prior[self.param_names[i]]
prior_type = prior["prior_type"]
args = prior["prior_args"]
if prior_type != "free":
x = np.linspace(*ax.get_xlim(), 100)
ax.plot(x, pdf[prior_type](x, **args), color="grey", ls="--")
ax.plot(x, pdf[prior_type](x, **args), color="grey", ls="--", label='Prior')
ax.set_xlabel(self.param_names[i])
ax.set_ylabel("PDF")
ax.set_ylim(0, None)
axes.append(ax)

# Set legend
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(
loc='lower center',
handles=handles, labels=labels,
bbox_to_anchor=(0.5, 1.0),
)

plt.tight_layout()
return fig, axes

Expand Down Expand Up @@ -295,12 +303,19 @@ def autocorr_new(y, c=5.0):
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_ylabel(f"Auto correlation of {self.param_names[i]}")
ax.legend()
axes.append(ax)

# Set xlabels of the last two axes
axes[-1].set_xlabel("Number of iterations")
axes[-2].set_xlabel("Number of iterations")

# Set legend
handles, labels = ax[-1].get_legend_handles_labels()
fig.legend(
loc='lower center',
handles=handles, labels=labels,
bbox_to_anchor=(0.5, 1.0),
)

plt.tight_layout()
return fig, axes

0 comments on commit 88343a0

Please sign in to comment.