Skip to content

Commit

Permalink
credibility interval plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
joel-becker committed Jul 29, 2023
1 parent 84aebf1 commit 6f28dcf
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,11 @@ def plot_model_output(
#line_color="red",
mean_line_width=2,
#background_color="black"
show_grid=False, # new option to control grid lines
show_grid=True, # new option to control grid lines
grid_alpha=0.1, # option to control the transparency of grid lines
grid_style='dashed', # option to control the style of grid lines
plot_credibility_interval=True,
credibility_intervals=[0.5, 0.8, 0.95] # list of quantiles for intervals
):
"""
This function plots the output of the financial life model.
Expand Down Expand Up @@ -395,18 +397,36 @@ def plot_model_output(
# Transpose the data
value = np.transpose(value)

# Plot mean path
ax.plot(
value.mean(axis=1),
color=css.primary_color,
alpha=mean_line_alpha,
linewidth=mean_line_width,
) # plot mean path

ax.plot(
value,
color=css.primary_color,
alpha=alpha
) # plot individual paths
)

# Plot credibility intervals or individual paths
if plot_credibility_interval:
sorted_values = np.sort(value, axis=1) # sort values for each time step
for i in range(len(credibility_intervals)-1, -1, -1): # iterate over intervals in reverse order
lower_quantile = (1-credibility_intervals[i]) / 2 # calculate lower quantile
upper_quantile = 1 - lower_quantile # calculate upper quantile
ax.fill_between(
range(value.shape[0]), # x values (time steps)
np.quantile(sorted_values, lower_quantile, axis=1), # lower y values (lower quantile)
np.quantile(sorted_values, upper_quantile, axis=1), # upper y values (upper quantile)
color=css.primary_color,
alpha=0.5 * (i + 1) / len(credibility_intervals), # transparency
label=f'{int(credibility_intervals[i]*100)}% credibility interval' # label for legend
)

ax.legend() # show legend
else:
ax.plot(
value,
color=css.primary_color,
alpha=alpha
)

ax.set_title(key.replace("_", " ").title()) # prettify title
ax.yaxis.set_major_formatter(formatter) # format y axis with comma separator
Expand Down

0 comments on commit 6f28dcf

Please sign in to comment.