From efa3c405e12322186b6764f10b6f8cf7f1ea0922 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Wed, 6 Sep 2023 15:19:31 +0200 Subject: [PATCH] smal fixes in the plotting functions --- pyhgf/plots.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/pyhgf/plots.py b/pyhgf/plots.py index 8e98f1b73..807cc078d 100644 --- a/pyhgf/plots.py +++ b/pyhgf/plots.py @@ -491,12 +491,8 @@ def plot_nodes( # compute mu +/- sd at time t-1 # and use the sigmoid transform before plotting - mu_parent = np.insert( - trajectories_df[f"x_{parent_idx}_mu"][:-1], 0, np.nan - ) - pi_parent = np.insert( - trajectories_df[f"x_{parent_idx}_pi"][:-1], 0, np.nan - ) + mu_parent = trajectories_df[f"x_{parent_idx}_muhat"] + pi_parent = trajectories_df[f"x_{parent_idx}_pihat"] sd = np.sqrt(1 / pi_parent) y1 = 1 / (1 + np.exp(-mu_parent + sd)) y2 = 1 / (1 + np.exp(-mu_parent - sd)) @@ -577,6 +573,28 @@ def plot_nodes( alpha=0.5, color=input_colors[ii], ) + else: + child_idx = np.where( + np.array(hgf.input_nodes_idx.idx) == child_idx + )[0][0] + axs[i].scatter( + trajectories_df.time, + trajectories_df[f"observation_input_{child_idx}"], + s=3, + label=f"Value child node - {ii}", + alpha=0.3, + color=input_colors[ii], + edgecolors="grey", + ) + axs[i].plot( + trajectories_df.time, + trajectories_df[f"observation_input_{child_idx}"], + linewidth=0.5, + linestyle="--", + alpha=0.3, + color=input_colors[ii], + ) + axs[i].legend() # plotting surprise # -----------------