Skip to content

Commit

Permalink
smal fixes in the plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 6, 2023
1 parent b6a6516 commit efa3c40
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,8 @@ def plot_nodes(

# compute mu +/- sd at time t-1
# and use the sigmoid transform before plotting
mu_parent = np.insert(
trajectories_df[f"x_{parent_idx}_mu"][:-1], 0, np.nan
)
pi_parent = np.insert(
trajectories_df[f"x_{parent_idx}_pi"][:-1], 0, np.nan
)
mu_parent = trajectories_df[f"x_{parent_idx}_muhat"]
pi_parent = trajectories_df[f"x_{parent_idx}_pihat"]
sd = np.sqrt(1 / pi_parent)
y1 = 1 / (1 + np.exp(-mu_parent + sd))
y2 = 1 / (1 + np.exp(-mu_parent - sd))
Expand Down Expand Up @@ -577,6 +573,28 @@ def plot_nodes(
alpha=0.5,
color=input_colors[ii],
)
else:
child_idx = np.where(
np.array(hgf.input_nodes_idx.idx) == child_idx
)[0][0]
axs[i].scatter(
trajectories_df.time,
trajectories_df[f"observation_input_{child_idx}"],
s=3,
label=f"Value child node - {ii}",
alpha=0.3,
color=input_colors[ii],
edgecolors="grey",
)
axs[i].plot(
trajectories_df.time,
trajectories_df[f"observation_input_{child_idx}"],
linewidth=0.5,
linestyle="--",
alpha=0.3,
color=input_colors[ii],
)
axs[i].legend()

# plotting surprise
# -----------------
Expand Down

0 comments on commit efa3c40

Please sign in to comment.