Skip to content

Commit

Permalink
fix error in posterior updates with multiple value children with unob…
Browse files Browse the repository at this point in the history
…served values (ComputationalPsychiatry#208)
  • Loading branch information
LegrandNico committed Aug 9, 2024
1 parent 5475ae7 commit c7f3798
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions src/pyhgf/updates/posterior/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,21 +348,17 @@ def posterior_update_precision_continuous_node(
# sum the prediction errors from both value and volatility coupling
precision_weigthed_prediction_error = 0.0

# Value coupling updates - update the precision of a value parent
# ---------------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
):
precision_weigthed_prediction_error += (
value_coupling**2 * attributes[value_child_idx]["expected_precision"]
)

# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error *= attributes[value_child_idx][
"observed"
]
# Value coupling updates - update the precision of a value parent
# ---------------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
):
# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error += (
value_coupling**2 * attributes[value_child_idx]["expected_precision"]
) * attributes[value_child_idx]["observed"]

# Volatility coupling updates - update the precision of a volatility parent
# -------------------------------------------------------------------------
Expand All @@ -381,21 +377,17 @@ def posterior_update_precision_continuous_node(
"volatility_prediction_error"
]

# sum over all volatility children
precision_weigthed_prediction_error += (
0.5 * (volatility_coupling * effective_precision) ** 2
+ (volatility_coupling * effective_precision) ** 2
* volatility_prediction_error
- 0.5
* volatility_coupling**2
* effective_precision
* volatility_prediction_error
)

# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error *= attributes[volatility_child_idx][
"observed"
]
# sum over all volatility children
# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error += (
0.5 * (volatility_coupling * effective_precision) ** 2
+ (volatility_coupling * effective_precision) ** 2
* volatility_prediction_error
- 0.5
* volatility_coupling**2
* effective_precision
* volatility_prediction_error
) * attributes[volatility_child_idx]["observed"]

# Compute the new posterior precision
# using value prediction errors from both value and volatility coupling
Expand Down

0 comments on commit c7f3798

Please sign in to comment.