From 0849cc989ddc4148a7584e2324830c681d04cae3 Mon Sep 17 00:00:00 2001 From: Nicolas Legrand Date: Thu, 20 Jun 2024 17:04:46 +0200 Subject: [PATCH] fix error in posterior updates with multiple value children with unobserved values (#208) --- src/pyhgf/updates/posterior/continuous.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/pyhgf/updates/posterior/continuous.py b/src/pyhgf/updates/posterior/continuous.py index 8ccfdf1dc..9a3968a63 100644 --- a/src/pyhgf/updates/posterior/continuous.py +++ b/src/pyhgf/updates/posterior/continuous.py @@ -131,8 +131,10 @@ def posterior_update_mean_continuous_node( # expected precisions from the value children # sum the precision weigthed prediction errors over all children value_precision_weigthed_prediction_error += ( - (value_coupling * attributes[value_child_idx]["expected_precision"]) - / node_precision + ( + (value_coupling * attributes[value_child_idx]["expected_precision"]) + / node_precision + ) ) * value_prediction_error # Volatility coupling updates - update the mean of a volatility parent @@ -286,14 +288,10 @@ def posterior_update_precision_continuous_node( edges[node_idx].value_children, # type: ignore attributes[node_idx]["value_coupling_children"], ): + # cancel the prediction error if the child value was not observed precision_weigthed_prediction_error += ( value_coupling**2 * attributes[value_child_idx]["expected_precision"] - ) - - # cancel the prediction error if the child value was not observed - precision_weigthed_prediction_error *= attributes[value_child_idx][ - "observed" - ] + ) * attributes[value_child_idx]["observed"] # Volatility coupling updates - update the precision of a volatility parent # ------------------------------------------------------------------------- @@ -313,6 +311,7 @@ def posterior_update_precision_continuous_node( ] # sum over all volatility children + # cancel the prediction error if the child value was not observed precision_weigthed_prediction_error += ( 0.5 * (volatility_coupling * effective_precision) ** 2 + (volatility_coupling * effective_precision) ** 2 @@ -321,12 +320,7 @@ def posterior_update_precision_continuous_node( * volatility_coupling**2 * effective_precision * volatility_prediction_error - ) - - # cancel the prediction error if the child value was not observed - precision_weigthed_prediction_error *= attributes[volatility_child_idx][ - "observed" - ] + ) * attributes[volatility_child_idx]["observed"] # Compute the new posterior precision # using value prediction errors from both value and volatility coupling