From c3ee951c81b699dfa9106d672646f29e6922b1d0 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 7 Nov 2023 09:27:48 +0100 Subject: [PATCH] updates functions and docstrings --- src/pyhgf/model.py | 4 +- src/pyhgf/updates/continuous.py | 5 +- src/pyhgf/updates/prediction/continuous.py | 83 +++++++++++++------ .../prediction_error/inputs/continuous.py | 2 +- .../prediction_error/nodes/continuous.py | 56 ++++++------- 5 files changed, 88 insertions(+), 62 deletions(-) diff --git a/src/pyhgf/model.py b/src/pyhgf/model.py index a9ec9fb87..0707155f8 100644 --- a/src/pyhgf/model.py +++ b/src/pyhgf/model.py @@ -686,6 +686,7 @@ def add_value_parent( "tonic_drift": tonic_drift, "autoregressive_coefficient": autoregressive_coefficient, "autoregressive_intercept": autoregressive_intercept, + "temp": {"predicted_volatility": 0.0}, } # add more parameters (optional) @@ -735,7 +736,7 @@ def add_volatility_parent( self, children_idxs: Union[List, int], volatility_coupling: Union[float, np.ndarray, ArrayLike] = 1.0, - mean: Union[float, np.ndarray, ArrayLike] = 1.0, + mean: Union[float, np.ndarray, ArrayLike] = 0.0, precision: Union[float, np.ndarray, ArrayLike] = 1.0, tonic_volatility: Union[float, np.ndarray, ArrayLike] = -4.0, tonic_drift: Union[float, np.ndarray, ArrayLike] = 0.0, @@ -805,6 +806,7 @@ def add_volatility_parent( "tonic_drift": tonic_drift, "autoregressive_coefficient": autoregressive_coefficient, "autoregressive_intercept": autoregressive_intercept, + "temp": {"predicted_volatility": 0.0}, } # add more parameters (optional) diff --git a/src/pyhgf/updates/continuous.py b/src/pyhgf/updates/continuous.py index fa47e6933..6afdcf45e 100644 --- a/src/pyhgf/updates/continuous.py +++ b/src/pyhgf/updates/continuous.py @@ -290,10 +290,13 @@ def continuous_node_prediction( # Get the new expected mean expected_mean = predict_mean(attributes, edges, time_step, node_idx) # Get the new expected precision - expected_precision = predict_precision(attributes, edges, time_step, node_idx) + expected_precision, predicted_volatility = predict_precision( + attributes, edges, time_step, node_idx + ) # Update this node's parameters attributes[node_idx]["expected_precision"] = expected_precision + attributes[node_idx]["temp"]["predicted_volatility"] = predicted_volatility attributes[node_idx]["expected_mean"] = expected_mean return attributes diff --git a/src/pyhgf/updates/prediction/continuous.py b/src/pyhgf/updates/prediction/continuous.py index 8c64db127..caa009909 100644 --- a/src/pyhgf/updates/prediction/continuous.py +++ b/src/pyhgf/updates/prediction/continuous.py @@ -16,25 +16,26 @@ def predict_mean( time_step: float, node_idx: int, ) -> Array: - r"""Compute the expected mean of a probabilistic node. + r"""Compute the expected mean of a continuous state node. Parameters ---------- attributes : - The attributes of the probabilistic nodes. + The attributes of the probabilistic network that contains the continuous state + node. edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. + The edges of the probabilistic network as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number of + nodes. For each node, the index list value/volatility - parents/children. time_step : - The interval between the previous time point and the current time point. + The time interval between the previous time point and the current time point. node_idx : - Pointer to the node that will be updated. + Index of the node that should be updated. Returns ------- expected_mean : - The expected value for the mean of the value parent (:math:`\\hat{\\mu}`). + The new expected mean of the value parent. """ # List the node's value parents @@ -75,47 +76,75 @@ def predict_mean( def predict_precision( attributes: Dict, edges: Edges, time_step: float, node_idx: int ) -> Array: - r"""Compute the expected precision of a probabilistic node. + r"""Compute the expected precision of a continuous state node node. + + The expected precision at time :math:`k` for a state node :math:`a` is given by: + + .. math:: + + \hat{\pi}_a^{(k)} = \frac{1}{\frac{1}{\pi_a^{(k-1)}} + \Omega_a^{(k)}} + + where :math:`\Omega_a^{(k)}` is the *total predicted volatility*. This term is the + sum of the tonic (endogenous) and phasic (exogenous) volatility, such as: + + .. math:: + + \Omega_a^{(k)} = t^{(k)} \\ + \exp{ \left( \omega_a + \sum_{j=1}^{N_{vopa}} \kappa_j \mu_a^{(k-1)} \right) } + + + with :math:`\kappa_j` the volatility coupling strength with the volatility parent + :math:`j`. Parameters ---------- attributes : - The attributes of the probabilistic nodes. + The attributes of the probabilistic network that contains the continuous state + node. edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. + The edges of the probabilistic network as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number of + nodes. For each node, the index list value/volatility - parents/children. time_step : - The interval between the previous time point and the current time point. + The time interval between the previous time point and the current time point. node_idx : - Pointer to the node that will be updated. + Index of the node that should be updated. Returns ------- expected_precision : - The expected value for the mean of the value parent (:math:`\\hat{\\pi}`). + The new expected precision of the value parent. + predicted volatility : + The predicted volatility :math:`\Omega_a^{(k)}`. This value is stored in the + node for latter use in the prediction-error steps. """ # List the node's volatility parents volatility_parents_idxs = edges[node_idx].volatility_parents - # Get the log volatility from the node - logvol = attributes[node_idx]["tonic_volatility"] + # Get the tonic volatility from the node + total_volatility = attributes[node_idx]["tonic_volatility"] - # Look at the (optional) volatility parents - # and update the log volatility accordingly + # Look at the (optional) volatility parents and add their value to the tonic + # volatility to get the total volatility if volatility_parents_idxs is not None: for volatility_parents_idx, volatility_coupling in zip( volatility_parents_idxs, attributes[node_idx]["volatility_coupling_parents"], ): - logvol += volatility_coupling * attributes[volatility_parents_idx]["mean"] - - # Estimate new nu - nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) + total_volatility += ( + volatility_coupling * attributes[volatility_parents_idx]["mean"] + ) + + # compute the predicted_volatility from the total volatility + predicted_volatility = time_step * jnp.exp(total_volatility) + predicted_volatility = jnp.where( + predicted_volatility > 1e-128, predicted_volatility, jnp.nan + ) # Estimate the new expected precision for the node - expected_precision = 1 / (1 / attributes[node_idx]["precision"] + new_nu) + expected_precision = 1 / ( + (1 / attributes[node_idx]["precision"]) + predicted_volatility + ) - return expected_precision + return expected_precision, predicted_volatility diff --git a/src/pyhgf/updates/prediction_error/inputs/continuous.py b/src/pyhgf/updates/prediction_error/inputs/continuous.py index c3a221386..59aec10dd 100644 --- a/src/pyhgf/updates/prediction_error/inputs/continuous.py +++ b/src/pyhgf/updates/prediction_error/inputs/continuous.py @@ -14,7 +14,7 @@ def prediction_error_input_precision_value_parent( attributes: Dict, edges: Edges, value_parent_idx: int ) -> Array: - r"""Send prediction-error and update the precision of a value parent (continuous). + r"""Send prediction-error and update the precision of a continuous value parent. Parameters ---------- diff --git a/src/pyhgf/updates/prediction_error/nodes/continuous.py b/src/pyhgf/updates/prediction_error/nodes/continuous.py index ef376e29d..c6251c17a 100644 --- a/src/pyhgf/updates/prediction_error/nodes/continuous.py +++ b/src/pyhgf/updates/prediction_error/nodes/continuous.py @@ -172,45 +172,37 @@ def prediction_error_precision_volatility_parent( # gather volatility precisions from the child nodes children_volatility_precision = 0.0 - for child_idx, kappas_children in zip( + for child_idx, volatility_coupling in zip( edges[volatility_parent_idx].volatility_children, # type: ignore attributes[volatility_parent_idx]["volatility_coupling_children"], ): - # Look at the (optional) volatility parents and update logvol accordingly - logvol = attributes[child_idx]["tonic_volatility"] - if edges[child_idx].volatility_parents is not None: - for children_volatility_parents, volatility_coupling in zip( - edges[child_idx].volatility_parents, - attributes[child_idx]["volatility_coupling_parents"], - ): - logvol += ( - volatility_coupling - * attributes[children_volatility_parents]["mean"] - ) + # retrieve the predicted volatility that was computed in the prediction step + predicted_volatility = attributes[child_idx]["temp"]["predicted_volatility"] - # Compute new value for nu - nu_children = time_step * jnp.exp(logvol) - nu_children = jnp.where(nu_children > 1e-128, nu_children, jnp.nan) + # compute the volatility weigthed precision + volatility_weigthed_precision = ( + predicted_volatility * attributes[child_idx]["expected_precision"] + ) + # compute the volatility prediction error (VOPE) vope_children = ( - 1 / attributes[child_idx]["precision"] - + (attributes[child_idx]["mean"] - attributes[child_idx]["expected_mean"]) + ( + attributes[child_idx]["expected_precision"] + / attributes[child_idx]["precision"] + ) + + attributes[child_idx]["expected_precision"] + * (attributes[child_idx]["mean"] - attributes[child_idx]["expected_mean"]) ** 2 - ) * attributes[child_idx]["expected_precision"] - 1 + - 1 + ) children_volatility_precision += ( - 0.5 - * ( - kappas_children - * nu_children - * attributes[child_idx]["expected_precision"] - ) - ** 2 - * ( - 1 - + (1 - 1 / (nu_children * attributes[child_idx]["precision"])) - * vope_children - ) + 0.5 * (volatility_coupling * volatility_weigthed_precision) ** 2 + + (volatility_coupling * volatility_weigthed_precision) ** 2 * vope_children + - 0.5 + * volatility_coupling**2 + * volatility_weigthed_precision + * vope_children ) # Estimate the new precision of the volatility parent @@ -272,7 +264,7 @@ def prediction_error_mean_volatility_parent( # Gather volatility prediction errors from the child nodes children_volatility_prediction_error = 0.0 - for child_idx, kappas_children in zip( + for child_idx, volatility_coupling in zip( edges[volatility_parent_idx].volatility_children, # type: ignore attributes[volatility_parent_idx]["volatility_coupling_children"], ): @@ -299,7 +291,7 @@ def prediction_error_mean_volatility_parent( ) * attributes[child_idx]["expected_precision"] - 1 children_volatility_prediction_error += ( 0.5 - * kappas_children + * volatility_coupling * nu_children * attributes[child_idx]["expected_precision"] / precision_volatility_parent