Reassess plotting internals to make use of arviz plotting #1371

Open
wd60622 opened this issue Jan 14, 2025 · 0 comments

wd60622 commented Jan 14, 2025

This function has some similarities to what arviz offers:

def plot_prior_vs_posterior(
    self,
    var_name: str,
    alphabetical_sort: bool = True,
    figsize: tuple[int, int] | None = None,
) -> plt.Figure:
    """
    Plot the prior vs posterior distribution for a specified variable in a 3-column grid layout.

    This function generates KDE plots for each MMM channel, showing the prior predictive
    and posterior distributions with their respective means highlighted.
    It sorts the plots either alphabetically or based on the difference between the
    posterior and prior means, with the largest difference (posterior - prior) at the top.

    Parameters
    ----------
    var_name: str
        The variable to analyze (e.g., 'adstock_alpha').
    alphabetical_sort: bool, optional
        Whether to sort the channels alphabetically (True) or by the difference
        between the posterior and prior means (False). Default is True.
    figsize : tuple of int, optional
        Figure size in inches. If None, it will be calculated based on the number of channels.

    Returns
    -------
    fig : plt.Figure
        The matplotlib figure object

    Raises
    ------
    ValueError
        If the required attributes (fit_result, prior) are not found.
    ValueError
        If var_name is not a string.
    """
    if not hasattr(self, "fit_result") or not hasattr(self, "prior"):
        raise ValueError(
            "Required attributes (fit_result, prior) not found. "
            "Ensure you've called model.fit() and model.sample_prior_predictive()"
        )
    if not isinstance(var_name, str):
        raise ValueError(
            "var_name must be a string. Please provide a single variable name."
        )
    # Determine the number of channels and set up the grid
    num_channels = len(self.channel_columns)
    num_cols = 3
    num_rows = (num_channels + num_cols - 1) // num_cols  # Calculate rows needed
    if figsize is None:
        figsize = (25, 5 * num_rows)
    # Calculate prior and posterior means for sorting
    channel_means = []
    for channel in self.channel_columns:
        prior_mean = self.prior[var_name].sel(channel=channel).mean().values
        posterior_mean = (
            self.fit_result[var_name].sel(channel=channel).mean().values
        )
        difference = posterior_mean - prior_mean
        channel_means.append((channel, prior_mean, posterior_mean, difference))
    # Choose how to sort the channels
    if alphabetical_sort:
        sorted_channels = sorted(channel_means, key=lambda x: x[0])
    else:
        # Otherwise, sort on difference between posterior and prior means
        sorted_channels = sorted(channel_means, key=lambda x: x[3], reverse=True)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    axs = axs.flatten()  # Flatten the array for easy iteration
    # Plot for each channel
    for i, (channel, prior_mean, posterior_mean, difference) in enumerate(
        sorted_channels
    ):
        # Extract prior samples for the current channel
        prior_samples = self.prior[var_name].sel(channel=channel).values.flatten()
        # Plot the prior predictive distribution
        sns.kdeplot(
            prior_samples,
            ax=axs[i],
            label="Prior Predictive",
            color="blue",
            fill=True,
        )
        # Add a vertical line for the mean of the prior distribution
        axs[i].axvline(
            prior_mean,
            color="blue",
            linestyle="--",
            linewidth=2,
            label=f"Prior Mean: {prior_mean:.2f}",
        )
        # Extract posterior samples for the current channel
        posterior_samples = (
            self.fit_result[var_name].sel(channel=channel).values.flatten()
        )
        # Plot the posterior distribution
        sns.kdeplot(
            posterior_samples,
            ax=axs[i],
            label="Posterior Predictive",
            color="red",
            fill=True,
            alpha=0.15,
        )
        # Add a vertical line for the mean of the posterior distribution
        axs[i].axvline(
            posterior_mean,
            color="red",
            linestyle="--",
            linewidth=2,
            label=f"Posterior Mean: {posterior_mean:.2f} (Diff: {difference:.2f})",
        )
        # Set titles and labels
        axs[i].set_title(channel)  # Subplot title is just the channel name
        axs[i].set_xlabel(var_name.capitalize())
        axs[i].set_ylabel("Density")
        axs[i].legend(loc="upper right")
    # Set the overall figure title
    fig.suptitle(f"Prior vs Posterior Distributions | {var_name}", fontsize=16)
    # Hide any unused subplots
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])
    # Adjust layout
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout to fit the title
    return fig
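
For comparison, here is a rough sketch of what the arviz side might look like. It is not a proposed implementation, only an illustration of the overlap. It assumes an arviz 0.x release (where `az.from_dict` and `az.plot_dist_comparison` are available), and the channel names and synthetic draws are made up for the example; a real model would supply its own InferenceData with `prior` and `posterior` groups.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import arviz as az
import numpy as np

rng = np.random.default_rng(0)
channels = ["tv", "radio", "social"]  # hypothetical channel names

# Synthetic stand-in for model draws, shaped (chain, draw, channel).
idata = az.from_dict(
    posterior={"adstock_alpha": rng.beta(4, 6, size=(2, 200, 3))},
    prior={"adstock_alpha": rng.beta(1, 3, size=(1, 200, 3))},
    coords={"channel": channels},
    dims={"adstock_alpha": ["channel"]},
)

# One row per channel; columns show the prior, the posterior, and both overlaid.
axes = az.plot_dist_comparison(idata, var_names=["adstock_alpha"])
fig = axes.ravel()[0].figure
```

This does not reproduce the mean lines or the sort-by-difference option, so those would still need a thin wrapper if we want to keep them.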

Original comment: #1368 (comment)
