diff --git a/docs/source/api.rst b/docs/source/api.rst index 1e8801271..69bca80d6 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,13 +9,13 @@ API +++ -Nodes updates -------------- +Updates functions +----------------- -Update functions used for prediction and prediction-error steps during the belief propagation over the network of probabilistic nodes. These functions call intenaly the more specific update function listed in the prediction and prediction-error sections below. +The `updates` module contains the update function used both for prediction and prediction-error steps during the belief propagation. These functions call intenaly the more specific update function listed in the prediction and prediction-error sub-modules. -Binary -====== +Updating binary nodes +===================== Core functionnalities to update *binary* nodes. @@ -31,8 +31,8 @@ Core functionnalities to update *binary* nodes. binary_surprise -Continuous -========== +Updating continuous nodes +========================= Core functionnalities to update *continuous* nodes. @@ -47,8 +47,8 @@ Core functionnalities to update *continuous* nodes. continuous_input_prediction gaussian_surprise -Categorical -=========== +Updating categorical nodes +========================== Core functionnalities to update *categorical* nodes. @@ -61,10 +61,12 @@ Core functionnalities to update *categorical* nodes. dirichlet_kullback_leibler Prediction error steps ----------------------- +====================== -Continuous -========== +Propagate prediction errors to the value and volatility parents of a given node. + +Continuous nodes +~~~~~~~~~~~~~~~~ .. currentmodule:: pyhgf.updates.prediction_error.continuous @@ -79,10 +81,12 @@ Continuous prediction_error_mean_volatility_parent Prediction steps ----------------- +================ + +Compute the expectation for future observation given the influence of parent nodes. -Continuous -========== +Continuous nodes +~~~~~~~~~~~~~~~~ .. currentmodule:: pyhgf.updates.prediction_error.continuous diff --git a/src/pyhgf/networks.py b/src/pyhgf/networks.py index f60df53f4..c707ae1b5 100644 --- a/src/pyhgf/networks.py +++ b/src/pyhgf/networks.py @@ -19,7 +19,7 @@ from pyhgf.updates.categorical import categorical_input_update from pyhgf.updates.continuous import ( continuous_input_prediction, - continuous_input_update, + continuous_input_prediction_error, continuous_node_prediction, continuous_node_prediction_error, gaussian_surprise, @@ -357,7 +357,7 @@ def get_update_sequence( update_fn = binary_input_update prediction_fn = binary_input_prediction elif model_kind == "continuous": - update_fn = continuous_input_update + update_fn = continuous_input_prediction_error prediction_fn = continuous_input_prediction elif model_kind == "categorical": continue diff --git a/src/pyhgf/updates/continuous.py b/src/pyhgf/updates/continuous.py index dea8edbcd..c060ce60e 100644 --- a/src/pyhgf/updates/continuous.py +++ b/src/pyhgf/updates/continuous.py @@ -9,6 +9,7 @@ from pyhgf.typing import Edges from pyhgf.updates.prediction.continuous import ( + prediction_input_value_parent, prediction_value_parent, prediction_volatility_parent, ) @@ -213,7 +214,7 @@ def continuous_node_prediction( @partial(jit, static_argnames=("edges", "node_idx")) -def continuous_input_update( +def continuous_input_prediction_error( attributes: Dict, time_step: float, node_idx: int, @@ -424,49 +425,8 @@ def continuous_input_prediction( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[value_parent_idx].value_children[-1] == node_idx: - # list value and volatility parents - value_parent_value_parents_idxs = edges[value_parent_idx].value_parents - value_parent_volatility_parents_idxs = edges[ - value_parent_idx - ].volatility_parents - - # Compute new value for nu and pihat - logvol = attributes[value_parent_idx]["omega"] - - # Look at the (optional) va_pa's volatility parents - # and update logvol accordingly - if value_parent_volatility_parents_idxs is not None: - for value_parent_volatility_parents_idx, k in zip( - value_parent_volatility_parents_idxs, - attributes[value_parent_idx]["kappas_parents"], - ): - logvol += ( - k * attributes[value_parent_volatility_parents_idx]["mu"] - ) - - # Estimate new_nu - nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - pihat_value_parent = 1 / ( - 1 / attributes[value_parent_idx]["pi"] + new_nu - ) - - # Compute new muhat - driftrate = attributes[value_parent_idx]["rho"] - - # Look at the (optional) va_pa's value parents - # and update drift rate accordingly - if value_parent_value_parents_idxs is not None: - for ( - value_parent_value_parents_idx - ) in value_parent_value_parents_idxs: - driftrate += ( - attributes[value_parent_idx]["psis_parents"][0] - * attributes[value_parent_value_parents_idx]["mu"] - ) - - muhat_value_parent = ( - attributes[value_parent_idx]["mu"] + time_step * driftrate + pihat_value_parent, muhat_value_parent = prediction_input_value_parent( + attributes, edges, time_step, value_parent_idx ) # update input node's parameters diff --git a/src/pyhgf/updates/prediction/continuous.py b/src/pyhgf/updates/prediction/continuous.py index fb566058c..e647ddb73 100644 --- a/src/pyhgf/updates/prediction/continuous.py +++ b/src/pyhgf/updates/prediction/continuous.py @@ -212,3 +212,77 @@ def prediction_volatility_parent( ) return pi_volatility_parent, mu_volatility_parent + + +@partial(jit, static_argnames=("edges", "value_parent_idx")) +def prediction_input_value_parent( + attributes: Dict, + edges: Edges, + time_step: float, + value_parent_idx: int, +) -> Array: + muhat_value_parent = prediction_input_mean_value_parent( + attributes, edges, time_step, value_parent_idx + ) + pihat_value_parent = prediction_input_precision_value_parent( + attributes, edges, time_step, value_parent_idx + ) + + return pihat_value_parent, muhat_value_parent + + +@partial(jit, static_argnames=("edges", "value_parent_idx")) +def prediction_input_mean_value_parent( + attributes: Dict, + edges: Edges, + time_step: float, + value_parent_idx: int, +) -> Array: + # list value parents + value_parent_value_parents_idxs = edges[value_parent_idx].value_parents + + # Compute new muhat + driftrate = attributes[value_parent_idx]["rho"] + + # Look at the (optional) va_pa's value parents + # and update drift rate accordingly + if value_parent_value_parents_idxs is not None: + for value_parent_value_parents_idx in value_parent_value_parents_idxs: + driftrate += ( + attributes[value_parent_idx]["psis_parents"][0] + * attributes[value_parent_value_parents_idx]["mu"] + ) + + muhat_value_parent = attributes[value_parent_idx]["mu"] + time_step * driftrate + + return muhat_value_parent + + +@partial(jit, static_argnames=("edges", "value_parent_idx")) +def prediction_input_precision_value_parent( + attributes: Dict, + edges: Edges, + time_step: float, + value_parent_idx: int, +) -> Array: + # list volatility parents + value_parent_volatility_parents_idxs = edges[value_parent_idx].volatility_parents + + # Compute new value for nu and pihat + logvol = attributes[value_parent_idx]["omega"] + + # Look at the (optional) va_pa's volatility parents + # and update logvol accordingly + if value_parent_volatility_parents_idxs is not None: + for value_parent_volatility_parents_idx, k in zip( + value_parent_volatility_parents_idxs, + attributes[value_parent_idx]["kappas_parents"], + ): + logvol += k * attributes[value_parent_volatility_parents_idx]["mu"] + + # Estimate new_nu + nu = time_step * jnp.exp(logvol) + new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) + pihat_value_parent = 1 / (1 / attributes[value_parent_idx]["pi"] + new_nu) + + return pihat_value_parent diff --git a/tests/test_continuous.py b/tests/test_continuous.py index 8d8e6ca2c..3ca5361a5 100644 --- a/tests/test_continuous.py +++ b/tests/test_continuous.py @@ -12,7 +12,7 @@ from pyhgf.typing import Indexes from pyhgf.updates.continuous import ( continuous_input_prediction, - continuous_input_update, + continuous_input_prediction_error, continuous_node_prediction, continuous_node_prediction_error, gaussian_surprise, @@ -72,7 +72,7 @@ def test_continuous_node_update(self): # No value parent - no volatility parents # ########################################### sequence1 = 0, continuous_input_prediction - sequence2 = 0, continuous_input_update + sequence2 = 0, continuous_input_prediction_error update_sequence = (sequence1, sequence2) new_attributes, _ = beliefs_propagation( attributes=attributes, @@ -145,7 +145,7 @@ def test_continuous_input_update(self): # create update sequence sequence1 = 0, continuous_input_prediction sequence2 = 1, continuous_node_prediction - sequence3 = 0, continuous_input_update + sequence3 = 0, continuous_input_prediction_error sequence4 = 1, continuous_node_prediction_error update_sequence = (sequence1, sequence2, sequence3, sequence4) data = jnp.array([0.2, 1.0]) @@ -229,7 +229,7 @@ def test_scan_loop(self): # create update sequence sequence1 = 0, continuous_input_prediction sequence2 = 1, continuous_node_prediction - sequence3 = 0, continuous_input_update + sequence3 = 0, continuous_input_prediction_error sequence4 = 1, continuous_node_prediction_error update_sequence = (sequence1, sequence2, sequence3, sequence4) diff --git a/tests/test_structure.py b/tests/test_structure.py index 005d54e1f..e3ba509e7 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -8,7 +8,7 @@ from pyhgf.networks import beliefs_propagation, list_branches, trim_sequence from pyhgf.typing import Indexes from pyhgf.updates.continuous import ( - continuous_input_update, + continuous_input_prediction_error, continuous_node_prediction_error, ) @@ -66,7 +66,7 @@ def test_beliefs_propagation(self): ) # create update sequence - sequence1 = 0, continuous_input_update + sequence1 = 0, continuous_input_prediction_error sequence2 = 1, continuous_node_prediction_error update_sequence = (sequence1, sequence2) @@ -106,7 +106,7 @@ def test_trim_sequence(self): Indexes(None, None, (3,), None), ) update_sequence = ( - (0, continuous_input_update), + (0, continuous_input_prediction_error), (1, continuous_node_prediction_error), (2, continuous_node_prediction_error), (3, continuous_node_prediction_error),