From 9841cda949e968e4d34f04d2245b0a1b2e7c387f Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Fri, 22 Sep 2023 23:52:19 +0200 Subject: [PATCH] prediction update apply to the target node only --- docs/source/api.rst | 2 - src/pyhgf/networks.py | 22 ++-- src/pyhgf/updates/binary.py | 109 ++----------------- src/pyhgf/updates/continuous.py | 139 ++----------------------- src/pyhgf/updates/prediction/binary.py | 37 +++---- tests/test_binary.py | 7 +- tests/test_continuous.py | 14 ++- 7 files changed, 50 insertions(+), 280 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 053ae6dcf..f6196e212 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -27,7 +27,6 @@ Core functionnalities to update *binary* nodes. binary_node_prediction_error binary_node_prediction binary_input_prediction_error - binary_input_prediction Updating continuous nodes ========================= @@ -42,7 +41,6 @@ Core functionnalities to update *continuous* nodes. continuous_node_prediction_error continuous_node_prediction continuous_input_prediction_error - continuous_input_prediction Updating categorical nodes ========================== diff --git a/src/pyhgf/networks.py b/src/pyhgf/networks.py index 9aa1afe62..9f3eb654d 100644 --- a/src/pyhgf/networks.py +++ b/src/pyhgf/networks.py @@ -12,14 +12,12 @@ from pyhgf.math import gaussian_surprise from pyhgf.typing import Indexes, UpdateSequence from pyhgf.updates.binary import ( - binary_input_prediction, binary_input_prediction_error, binary_node_prediction, binary_node_prediction_error, ) from pyhgf.updates.categorical import categorical_input_update from pyhgf.updates.continuous import ( - continuous_input_prediction, continuous_input_prediction_error, continuous_node_prediction, continuous_node_prediction_error, @@ -314,15 +312,9 @@ def get_update_sequence( ) for node_idx in node_idxs: - # if the node has no parent, exit here - if (hgf.edges[node_idx].value_parents is None) & ( - 
hgf.edges[node_idx].volatility_parents is None - ): - continue - - # if this node is part of a familly (the parent has multiple children) + # if this node is part of a family (the parent has multiple children) # and this is not the last of the children, exit here - # we apply this principle for every value / volatility parent + # we apply this principle for every value / volatility family is_youngest = False if node_idx in hgf.input_nodes_idx.idx: is_youngest = True # always update input nodes @@ -355,10 +347,8 @@ def get_update_sequence( ][0] if model_kind == "binary": update_fn = binary_input_prediction_error - prediction_fn = binary_input_prediction elif model_kind == "continuous": update_fn = continuous_input_prediction_error - prediction_fn = continuous_input_prediction elif model_kind == "categorical": continue @@ -380,12 +370,16 @@ def get_update_sequence( # create a new update and prediction sequence step and add it to the list # only the youngest of the family is updated, but all nodes get predictions + # ensure that the node is the youngest of the family (this also implicitly + # ensures it is not an orphan; otherwise the node only gets a prediction) if is_youngest: new__update_sequence = node_idx, update_fn update_sequence.append(new__update_sequence) - new_prediction_sequence = node_idx, prediction_fn - prediction_sequence.append(new_prediction_sequence) + # no prediction step for an input node + if node_idx not in hgf.input_nodes_idx.idx: + new_prediction_sequence = node_idx, prediction_fn + prediction_sequence.append(new_prediction_sequence) # search recursively for the next update steps - make sure that all the # children have been updated before updating the parent(s) diff --git a/src/pyhgf/updates/binary.py b/src/pyhgf/updates/binary.py index d34054f11..eda817f1d 100644 --- a/src/pyhgf/updates/binary.py +++ b/src/pyhgf/updates/binary.py @@ -6,8 +6,7 @@ from jax import jit from pyhgf.typing import Edges -from pyhgf.updates.continuous
import predict_mean, predict_precision -from pyhgf.updates.prediction.binary import predict_input_value_parent +from pyhgf.updates.prediction.binary import predict_binary_state_node from pyhgf.updates.prediction_error.binary import ( prediction_error_input_value_parent, prediction_error_value_parent, @@ -84,11 +83,7 @@ def binary_node_prediction_error( def binary_node_prediction( attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args ) -> Dict: - """Update the value parent(s) of a binary node. - - In a three-level HGF, this step will update the node :math:`x_2`. - - Then returns the new node tuple `(parameters, value_parents, volatility_parents)`. + """Update the expected mean and precision of a binary state node. Parameters ---------- @@ -111,7 +106,7 @@ def binary_node_prediction( Returns ------- attributes : - The updated node structure. + The new node structure with updated mean and expected precision. References ---------- @@ -120,28 +115,12 @@ def binary_node_prediction( arXiv. 
https://doi.org/10.48550/ARXIV.2305.10937 """ - # using the current node index, unwrap parameters and parents - value_parent_idxs = edges[node_idx].value_parents - - # Return here if no parents node are found - if value_parent_idxs is None: - return attributes - - ################################################################ - # Update the predictions of the continuous value parents (x-2) # - ################################################################ - if value_parent_idxs is not None: - for value_parent_idx in value_parent_idxs: - pihat_value_parent = predict_precision( - attributes, edges, time_step, value_parent_idx - ) - muhat_value_parent = predict_mean( - attributes, edges, time_step, value_parent_idx - ) + # Get the new expected value for the mean and precision + pihat, muhat = predict_binary_state_node(attributes, edges, time_step, node_idx) - # update the parent nodes' parameters - attributes[value_parent_idx]["pihat"] = pihat_value_parent - attributes[value_parent_idx]["muhat"] = muhat_value_parent + # Update the node's attributes + attributes[node_idx]["pihat"] = pihat + attributes[node_idx]["muhat"] = muhat return attributes @@ -224,75 +203,3 @@ def binary_input_prediction_error( attributes[node_idx]["surprise"] = surprise return attributes - - -@partial(jit, static_argnames=("edges", "node_idx")) -def binary_input_prediction( - attributes: Dict, - time_step: float, - node_idx: int, - edges: Edges, - value: float, -) -> Dict: - """Update the input node structure given one binary observation. - - This function is the entry-level of the binary node. It updates the parents of - the input node (:math:`x_1`). - - Parameters - ---------- - value : - The new observed value. - time_step : - The interval between the previous time point and the current time point. - attributes : - The attributes of the probabilistic nodes. - .. note:: - `"psis"` is the value coupling strength. It should have the same length as the - volatility parents' indexes. 
`"kappas"` is the volatility coupling strength. - It should have the same length as the volatility parents' indexes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - node_idx : - Pointer to the node that needs to be updated. After continuous updates, the - parameters of value and volatility parents (if any) will be different. - - Returns - ------- - attributes : - The updated parameters structure. - - See Also - -------- - update_continuous_parents, update_continuous_input_parents - - References - ---------- - .. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., & - Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1). - arXiv. https://doi.org/10.48550/ARXIV.2305.10937 - - """ - # list value and volatility parents - value_parent_idxs = edges[node_idx].value_parents - volatility_parent_idxs = edges[node_idx].volatility_parents - - if (value_parent_idxs is None) and (volatility_parent_idxs is None): - return attributes - - ####################################################### - # Update the value parent(s) of the binary input node # - ####################################################### - if value_parent_idxs is not None: - for value_parent_idx in value_parent_idxs: - pihat_value_parent, muhat_value_parent = predict_input_value_parent( - attributes, edges, time_step, value_parent_idx - ) - - # Update value parent's parameters - attributes[value_parent_idx]["pihat"] = pihat_value_parent - attributes[value_parent_idx]["muhat"] = muhat_value_parent - - return attributes diff --git a/src/pyhgf/updates/continuous.py b/src/pyhgf/updates/continuous.py index 6b48ab1f5..5ab579825 100644 --- a/src/pyhgf/updates/continuous.py +++ b/src/pyhgf/updates/continuous.py @@ -125,19 +125,7 @@ def continuous_node_prediction_error( def continuous_node_prediction( 
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args ) -> Dict: - """Prediction step for the value and volatility parents of a continuous node. - - Updating the node's parents is a two-step process: - 1. Update value parent(s). - 2. Update volatility parent(s). - - If a value/volatility parent has multiple children, all the children will update - the parent together, therefor this function should only be called once per group - of child nodes. The method :py:meth:`pyhgf.model.HGF.get_update_sequence` - ensures that this function is only called once all the children have been - updated. - - Then returns the structure of the new parameters. + """Update the expected mean and precision of a continuous node. Parameters ---------- @@ -150,8 +138,7 @@ def continuous_node_prediction( time_step : The interval between the previous time point and the current time point. node_idx : - Pointer to the node that needs to be updated. After continuous updates, the - parameters of value and volatility parents (if any) will be different. + Pointer to the node that will be updated. edges : The edges of the probabilistic nodes as a tuple of :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. @@ -173,45 +160,14 @@ def continuous_node_prediction( arXiv. 
https://doi.org/10.48550/ARXIV.2305.10937 """ - # list value and volatility parents - value_parents_idxs = edges[node_idx].value_parents - volatility_parents_idxs = edges[node_idx].volatility_parents - - # return here if no parents node are provided - if (value_parents_idxs is None) and (volatility_parents_idxs is None): - return attributes - - ######################## - # Update value parents # - ######################## - if value_parents_idxs is not None: - for value_parent_idx in value_parents_idxs: - muhat_value_parent = predict_mean( - attributes, edges, time_step, value_parent_idx - ) - pihat_value_parent = predict_precision( - attributes, edges, time_step, value_parent_idx - ) - - # Update this parent's parameters - attributes[value_parent_idx]["pihat"] = pihat_value_parent - attributes[value_parent_idx]["muhat"] = muhat_value_parent + # Get the new expected mean + muhat = predict_mean(attributes, edges, time_step, node_idx) + # Get the new expected precision + pihat = predict_precision(attributes, edges, time_step, node_idx) - ############################# - # Update volatility parents # - ############################# - if volatility_parents_idxs is not None: - for volatility_parent_idx in volatility_parents_idxs: - muhat_volatility_parent = predict_mean( - attributes, edges, time_step, volatility_parent_idx - ) - pihat_volatility_parent = predict_precision( - attributes, edges, time_step, volatility_parent_idx - ) - - # Update this parent's parameters - attributes[volatility_parent_idx]["pihat"] = pihat_volatility_parent - attributes[volatility_parent_idx]["muhat"] = muhat_volatility_parent + # Update this node's parameters + attributes[node_idx]["pihat"] = pihat + attributes[node_idx]["muhat"] = muhat return attributes @@ -294,80 +250,3 @@ def continuous_input_prediction_error( attributes[value_parent_idx]["mu"] = mu_value_parent return attributes - - -@partial(jit, static_argnames=("edges", "node_idx")) -def continuous_input_prediction( - attributes: 
Dict, - time_step: float, - node_idx: int, - edges: Edges, - value: float, -) -> Dict: - """Update the input node structure. - - This function is the entry-level of the structure updates. It updates the parent - of the input node. - - Parameters - ---------- - value : - The new observed value. - time_step : - The interval between the previous time point and the current time point. - attributes : - The attributes of the probabilistic nodes. - .. note:: - The parameter structure also incorporate the value and volatility coupling - strenght with children and parents (i.e. `"psis_parents"`, `"psis_children"`, - `"kappas_parents"`, `"kappas_children"`). - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - node_idx : - Pointer to the node that needs to be updated. After continuous updates, the - parameters of value and volatility parents (if any) will be different. - - Returns - ------- - attributes : - The updated attributes of the probabilistic nodes. - - See Also - -------- - continuous_node_update, update_binary_input_parents - - References - ---------- - .. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., & - Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1). - arXiv. 
https://doi.org/10.48550/ARXIV.2305.10937 - - """ - # store timestep in the node's parameters - attributes[node_idx]["time_step"] = time_step - - # list value and volatility parents - value_parents_idxs = edges[node_idx].value_parents - - ######################## - # Update value parents # - ######################## - if value_parents_idxs is not None: - for value_parent_idx in value_parents_idxs: - # if this child is the last one relative to this parent's family, all the - # children will update the parent at once, otherwise just pass and wait - if edges[value_parent_idx].value_children[-1] == node_idx: - muhat_value_parent = predict_mean( - attributes, edges, time_step, value_parent_idx - ) - pihat_value_parent = predict_precision( - attributes, edges, time_step, value_parent_idx - ) - - # update input node's parameters - attributes[value_parent_idx]["pihat"] = pihat_value_parent - attributes[value_parent_idx]["muhat"] = muhat_value_parent - - return attributes diff --git a/src/pyhgf/updates/prediction/binary.py b/src/pyhgf/updates/prediction/binary.py index e8cff0c15..7c6c87808 100644 --- a/src/pyhgf/updates/prediction/binary.py +++ b/src/pyhgf/updates/prediction/binary.py @@ -9,14 +9,14 @@ from pyhgf.typing import Edges -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def predict_input_value_parent( +@partial(jit, static_argnames=("edges", "node_idx")) +def predict_binary_state_node( attributes: Dict, edges: Edges, time_step: float, - value_parent_idx: int, + node_idx: int, ) -> Tuple[Array, ...]: - r"""Prediction step for the value parent of a binary input node. + r"""Get the new expected mean and precision of a binary state node. Parameters ---------- @@ -39,33 +39,28 @@ def predict_input_value_parent( The mean (:math:`\\mu`) of the value parent. 
""" # List the (unique) value parent of the value parent - value_parent_value_parent_idxs = edges[value_parent_idx].value_parents[0] + value_parent_idx = edges[node_idx].value_parents[0] # Get the drift rate from the value parent of the value parent - driftrate = attributes[value_parent_value_parent_idxs]["rho"] + driftrate = attributes[value_parent_idx]["rho"] # Look at the (optional) value parent's value parents # and update the drift rate accordingly - if edges[value_parent_value_parent_idxs].value_parents is not None: + if edges[value_parent_idx].value_parents is not None: for ( - value_parent_value_parent_value_parent_idx, - psi_parent_parent, + value_parent_value_parent_idx, + psi_parent, ) in zip( - edges[value_parent_value_parent_idxs].value_parents, - attributes[value_parent_value_parent_idxs]["psis_parents"], + edges[value_parent_idx].value_parents, + attributes[value_parent_idx]["psis_parents"], ): - driftrate += ( - psi_parent_parent - * attributes[value_parent_value_parent_value_parent_idx]["mu"] - ) + driftrate += psi_parent * attributes[value_parent_value_parent_idx]["mu"] # Estimate the new expected mean of the value parent and apply the sigmoid transform - muhat_value_parent = ( - attributes[value_parent_value_parent_idxs]["mu"] + time_step * driftrate - ) - muhat_value_parent = sigmoid(muhat_value_parent) + muhat = attributes[value_parent_idx]["mu"] + time_step * driftrate + muhat = sigmoid(muhat) # Estimate the new expected precision of the value parent - pihat_value_parent = 1 / (muhat_value_parent * (1 - muhat_value_parent)) + pihat = 1 / (muhat * (1 - muhat)) - return pihat_value_parent, muhat_value_parent + return pihat, muhat diff --git a/tests/test_binary.py b/tests/test_binary.py index 1becba54d..9e71b61cd 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -12,7 +12,6 @@ from pyhgf.networks import beliefs_propagation from pyhgf.typing import Indexes from pyhgf.updates.binary import ( - binary_input_prediction, 
binary_input_prediction_error, binary_node_prediction, binary_node_prediction_error, @@ -107,9 +106,9 @@ def test_update_binary_input_parents(self): ) # create update sequence - sequence1 = 0, binary_input_prediction - sequence2 = 1, binary_node_prediction - sequence3 = 2, continuous_node_prediction + sequence1 = 3, continuous_node_prediction + sequence2 = 2, continuous_node_prediction + sequence3 = 1, binary_node_prediction sequence4 = 0, binary_input_prediction_error sequence5 = 1, binary_node_prediction_error sequence6 = 2, continuous_node_prediction_error diff --git a/tests/test_continuous.py b/tests/test_continuous.py index 39d03bb9e..11cb93788 100644 --- a/tests/test_continuous.py +++ b/tests/test_continuous.py @@ -12,7 +12,6 @@ from pyhgf.networks import beliefs_propagation from pyhgf.typing import Indexes from pyhgf.updates.continuous import ( - continuous_input_prediction, continuous_input_prediction_error, continuous_node_prediction, continuous_node_prediction_error, @@ -69,9 +68,8 @@ def test_continuous_node_update(self): ########################################### # No value parent - no volatility parents # ########################################### - sequence1 = 0, continuous_input_prediction - sequence2 = 0, continuous_input_prediction_error - update_sequence = (sequence1, sequence2) + sequence1 = 0, continuous_input_prediction_error + update_sequence = (sequence1,) new_attributes, _ = beliefs_propagation( attributes=attributes, edges=edges, @@ -139,8 +137,8 @@ def test_continuous_input_update(self): ) # create update sequence - sequence1 = 0, continuous_input_prediction - sequence2 = 1, continuous_node_prediction + sequence1 = 1, continuous_node_prediction + sequence2 = 2, continuous_node_prediction sequence3 = 0, continuous_input_prediction_error sequence4 = 1, continuous_node_prediction_error update_sequence = (sequence1, sequence2, sequence3, sequence4) @@ -221,8 +219,8 @@ def test_scan_loop(self): ) # create update sequence - sequence1 = 0, 
continuous_input_prediction - sequence2 = 1, continuous_node_prediction + sequence1 = 1, continuous_node_prediction + sequence2 = 2, continuous_node_prediction sequence3 = 0, continuous_input_prediction_error sequence4 = 1, continuous_node_prediction_error update_sequence = (sequence1, sequence2, sequence3, sequence4)