diff --git a/docs/source/api.rst b/docs/source/api.rst index aed0479e6..053ae6dcf 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -24,10 +24,10 @@ Core functionnalities to update *binary* nodes. .. autosummary:: :toctree: generated/pyhgf.updates.binary - binary_node_prediction - binary_input_prediction binary_node_prediction_error + binary_node_prediction binary_input_prediction_error + binary_input_prediction Updating continuous nodes ========================= @@ -61,24 +61,6 @@ Prediction error steps Propagate prediction errors to the value and volatility parents of a given node. -Continuous nodes -~~~~~~~~~~~~~~~~ - -.. currentmodule:: pyhgf.updates.prediction_error.continuous - -.. autosummary:: - :toctree: generated/pyhgf.updates.prediction_error.continuous - - prediction_error_mean_value_parent - prediction_error_precision_value_parent - prediction_error_value_parent - prediction_error_volatility_parent - prediction_error_precision_volatility_parent - prediction_error_mean_volatility_parent - prediction_error_input_value_parent - prediction_error_input_mean_value_parent - prediction_error_input_precision_value_parent - Binary nodes ~~~~~~~~~~~~ @@ -94,24 +76,24 @@ Binary nodes input_surprise_inf input_surprise_reg -Prediction steps -================ - -Compute the expectation for future observation given the influence of parent nodes. - Continuous nodes ~~~~~~~~~~~~~~~~ -.. currentmodule:: pyhgf.updates.prediction.continuous +.. currentmodule:: pyhgf.updates.prediction_error.continuous .. autosummary:: - :toctree: generated/pyhgf.updates.prediction.continuous + :toctree: generated/pyhgf.updates.prediction_error.continuous - prediction_mean_value_parent - prediction_precision_value_parent - prediction_value_parent - prediction_precision_volatility_parent - prediction_mean_volatility_parent + prediction_error_mean_value_parent + prediction_error_precision_value_parent + prediction_error_precision_volatility_parent + prediction_error_mean_volatility_parent + prediction_error_input_mean_value_parent + +Prediction steps +================ + +Compute the expectation for future observation given the influence of parent nodes. Binary nodes ~~~~~~~~~~~~ @@ -121,8 +103,18 @@ Binary nodes .. autosummary:: :toctree: generated/pyhgf.updates.prediction.binary - prediction_input_value_parent + predict_input_value_parent + +Continuous nodes +~~~~~~~~~~~~~~~~ + +.. currentmodule:: pyhgf.updates.prediction.continuous + +.. autosummary:: + :toctree: generated/pyhgf.updates.prediction.continuous + predict_mean + predict_precision Distribution ------------ diff --git a/src/pyhgf/updates/binary.py b/src/pyhgf/updates/binary.py index 46e1777c6..d34054f11 100644 --- a/src/pyhgf/updates/binary.py +++ b/src/pyhgf/updates/binary.py @@ -6,8 +6,8 @@ from jax import jit from pyhgf.typing import Edges -from pyhgf.updates.continuous import prediction_value_parent -from pyhgf.updates.prediction.binary import prediction_input_value_parent +from pyhgf.updates.continuous import predict_mean, predict_precision +from pyhgf.updates.prediction.binary import predict_input_value_parent from pyhgf.updates.prediction_error.binary import ( prediction_error_input_value_parent, prediction_error_value_parent, @@ -132,7 +132,10 @@ def binary_node_prediction( ################################################################ if value_parent_idxs is not None: for value_parent_idx in value_parent_idxs: - pihat_value_parent, muhat_value_parent = prediction_value_parent( + pihat_value_parent = predict_precision( + attributes, edges, time_step, value_parent_idx + ) + muhat_value_parent = predict_mean( attributes, edges, time_step, value_parent_idx ) @@ -284,7 +287,7 @@ def binary_input_prediction( ####################################################### if value_parent_idxs is not None: for value_parent_idx in value_parent_idxs: - pihat_value_parent, muhat_value_parent = prediction_input_value_parent( + pihat_value_parent, muhat_value_parent = predict_input_value_parent( attributes, edges, time_step, value_parent_idx ) diff --git a/src/pyhgf/updates/continuous.py b/src/pyhgf/updates/continuous.py index ac3e278a1..6b48ab1f5 100644 --- a/src/pyhgf/updates/continuous.py +++ b/src/pyhgf/updates/continuous.py @@ -6,15 +6,13 @@ from jax import jit from pyhgf.typing import Edges -from pyhgf.updates.prediction.continuous import ( - prediction_input_value_parent, - prediction_value_parent, - prediction_volatility_parent, -) +from pyhgf.updates.prediction.continuous import predict_mean, predict_precision from pyhgf.updates.prediction_error.continuous import ( - prediction_error_input_value_parent, - prediction_error_value_parent, - prediction_error_volatility_parent, + prediction_error_input_mean_value_parent, + prediction_error_mean_value_parent, + prediction_error_mean_volatility_parent, + prediction_error_precision_value_parent, + prediction_error_precision_volatility_parent, ) @@ -82,9 +80,14 @@ def continuous_node_prediction_error( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[value_parent_idx].value_children[-1] == node_idx: - (pi_value_parent, mu_value_parent) = prediction_error_value_parent( + # Estimate the precision of the posterior distribution + pi_value_parent = prediction_error_precision_value_parent( attributes, edges, value_parent_idx ) + # Estimate the mean of the posterior distribution + mu_value_parent = prediction_error_mean_value_parent( + attributes, edges, value_parent_idx, pi_value_parent + ) # Update this parent's parameters attributes[value_parent_idx]["pi"] = pi_value_parent @@ -98,12 +101,18 @@ def continuous_node_prediction_error( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[volatility_parent_idx].volatility_children[-1] == node_idx: - ( - pi_volatility_parent, - mu_volatility_parent, - ) = prediction_error_volatility_parent( + # Estimate the new precision of the volatility parent + pi_volatility_parent = prediction_error_precision_volatility_parent( attributes, edges, time_step, volatility_parent_idx ) + # Estimate the new mean of the volatility parent + mu_volatility_parent = prediction_error_mean_volatility_parent( + attributes, + edges, + time_step, + volatility_parent_idx, + pi_volatility_parent, + ) # Update this parent's parameters attributes[volatility_parent_idx]["pi"] = pi_volatility_parent @@ -177,7 +186,10 @@ def continuous_node_prediction( ######################## if value_parents_idxs is not None: for value_parent_idx in value_parents_idxs: - (pihat_value_parent, muhat_value_parent) = prediction_value_parent( + muhat_value_parent = predict_mean( + attributes, edges, time_step, value_parent_idx + ) + pihat_value_parent = predict_precision( attributes, edges, time_step, value_parent_idx ) @@ -190,10 +202,10 @@ def continuous_node_prediction( ############################# if volatility_parents_idxs is not None: for volatility_parent_idx in volatility_parents_idxs: - ( - pihat_volatility_parent, - muhat_volatility_parent, - ) = prediction_volatility_parent( + muhat_volatility_parent = predict_mean( + attributes, edges, time_step, volatility_parent_idx + ) + pihat_volatility_parent = predict_precision( attributes, edges, time_step, volatility_parent_idx ) @@ -268,12 +280,14 @@ def continuous_input_prediction_error( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[value_parent_idx].value_children[-1] == node_idx: - ( - pi_value_parent, - mu_value_parent, - ) = prediction_error_input_value_parent( + # Estimate the new precision of the value parent + pi_value_parent = prediction_error_precision_value_parent( attributes, edges, value_parent_idx ) + # Estimate the new mean of the value parent + mu_value_parent = prediction_error_input_mean_value_parent( + attributes, edges, value_parent_idx, pi_value_parent + ) # update input node's parameters attributes[value_parent_idx]["pi"] = pi_value_parent @@ -345,7 +359,10 @@ def continuous_input_prediction( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[value_parent_idx].value_children[-1] == node_idx: - pihat_value_parent, muhat_value_parent = prediction_input_value_parent( + muhat_value_parent = predict_mean( + attributes, edges, time_step, value_parent_idx + ) + pihat_value_parent = predict_precision( attributes, edges, time_step, value_parent_idx ) diff --git a/src/pyhgf/updates/prediction/binary.py b/src/pyhgf/updates/prediction/binary.py index 3c2c9ff4a..e8cff0c15 100644 --- a/src/pyhgf/updates/prediction/binary.py +++ b/src/pyhgf/updates/prediction/binary.py @@ -10,7 +10,7 @@ @partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_input_value_parent( +def predict_input_value_parent( attributes: Dict, edges: Edges, time_step: float, diff --git a/src/pyhgf/updates/prediction/continuous.py b/src/pyhgf/updates/prediction/continuous.py index cde7efe8a..58a289bad 100644 --- a/src/pyhgf/updates/prediction/continuous.py +++ b/src/pyhgf/updates/prediction/continuous.py @@ -1,7 +1,7 @@ # Author: Nicolas Legrand from functools import partial -from typing import Dict, Tuple +from typing import Dict import jax.numpy as jnp from jax import Array, jit @@ -9,14 +9,14 @@ from pyhgf.typing import Edges -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_mean_value_parent( +@partial(jit, static_argnames=("edges", "node_idx")) +def predict_mean( attributes: Dict, edges: Edges, time_step: float, - value_parent_idx: int, + node_idx: int, ) -> Array: - r"""Expected value for the mean of the value parent. + r"""Expected value for the mean of a probabilistic node. Parameters ---------- @@ -28,39 +28,39 @@ def prediction_mean_value_parent( For each node, the index list value and volatility parents and children. time_step : The interval between the previous time point and the current time point. - value_parent_idx : + node_idx : Pointer to the node that will be updated. Returns ------- - muhat_value_parent : + muhat : The expected value for the mean of the value parent (:math:`\\hat{\\mu}`). """ - # List the value and volatility parents of the value parent - value_parent_value_parents_idxs = edges[value_parent_idx].value_parents + # List the node's value parents + value_parents_idxs = edges[node_idx].value_parents - # Get the drift rate of the value parent - driftrate = attributes[value_parent_idx]["rho"] + # Get the drift rate from the node + driftrate = attributes[node_idx]["rho"] - # Look at the (optional) value parents of the value parent + # Look at the (optional) value parents for this node # and update the drift rate accordingly - if value_parent_value_parents_idxs is not None: - for value_parent_value_parent_idx, psi in zip( - value_parent_value_parents_idxs, - attributes[value_parent_idx]["psis_parents"], + if value_parents_idxs is not None: + for value_parent_idx, psi in zip( + value_parents_idxs, + attributes[node_idx]["psis_parents"], ): - driftrate += psi * attributes[value_parent_value_parent_idx]["mu"] + driftrate += psi * attributes[value_parent_idx]["mu"] - # Compute the new expected mean for the value parent - muhat_value_parent = attributes[value_parent_idx]["mu"] + time_step * driftrate + # Compute the new expected mean this node + muhat = attributes[node_idx]["mu"] + time_step * driftrate - return muhat_value_parent + return muhat -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_precision_value_parent( - attributes: Dict, edges: Edges, time_step: float, value_parent_idx: int +@partial(jit, static_argnames=("edges", "node_idx")) +def predict_precision( + attributes: Dict, edges: Edges, time_step: float, node_idx: int ) -> Array: r"""Expected value for the precision of the value parent. @@ -74,380 +74,35 @@ def prediction_precision_value_parent( For each node, the index list value and volatility parents and children. time_step : The interval between the previous time point and the current time point. - value_parent_idx : + node_idx : Pointer to the node that will be updated. Returns ------- - pihat_value_parent : + pihat : The expected value for the mean of the value parent (:math:`\\hat{\\pi}`). """ - # List the value parent's volatility parents - value_parent_volatility_parents_idxs = edges[value_parent_idx].volatility_parents + # List the node's volatility parents + volatility_parents_idxs = edges[node_idx].volatility_parents - # Get the log volatility from the value parent - logvol = attributes[value_parent_idx]["omega"] + # Get the log volatility from the node + logvol = attributes[node_idx]["omega"] - # Look at the (optional) value parent's volatility parents + # Look at the (optional) volatility parents # and update the log volatility accordingly - if value_parent_volatility_parents_idxs is not None: - for value_parent_volatility_parents_idx, k in zip( - value_parent_volatility_parents_idxs, - attributes[value_parent_idx]["kappas_parents"], + if volatility_parents_idxs is not None: + for volatility_parents_idx, k in zip( + volatility_parents_idxs, + attributes[node_idx]["kappas_parents"], ): - logvol += k * attributes[value_parent_volatility_parents_idx]["mu"] + logvol += k * attributes[volatility_parents_idx]["mu"] # Estimate new nu nu = time_step * jnp.exp(logvol) new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - # Estimate the new expected precision for the value parent - pihat_value_parent = 1 / (1 / attributes[value_parent_idx]["pi"] + new_nu) + # Estimate the new expected precision for the node + pihat = 1 / (1 / attributes[node_idx]["pi"] + new_nu) - return pihat_value_parent - - -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_value_parent( - attributes: Dict, - edges: Edges, - time_step: float, - value_parent_idx: int, -) -> Tuple[Array, ...]: - """Prediction step for the value parent(s) of a continuous node. - - Updating the posterior distribution of the value parent is a two-step process: - #. Update the posterior precision using - :py:func:`pyhgf.updates.prediction.continuous.prediction_precision_value_parent`. - #. Update the posterior mean using - :py:func:`pyhgf.updates.prediction.continuous.prediction_mean_value_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - value_parent_idx : - Pointer to the value parent node. - - Returns - ------- - pi_value_parent : - The precision (:math:`\\pi`) of the value parent. - mu_value_parent : - The mean (:math:`\\mu`) of the value parent. - - """ - # Get the new estimation for the value parent's precision - pi_value_parent = prediction_precision_value_parent( - attributes, edges, time_step, value_parent_idx - ) - # Get the new estimation for the value parent's mean - mu_value_parent = prediction_mean_value_parent( - attributes, edges, time_step, value_parent_idx - ) - - return pi_value_parent, mu_value_parent - - -@partial(jit, static_argnames=("edges", "volatility_parent_idx")) -def prediction_precision_volatility_parent( - attributes: Dict, edges: Edges, time_step: float, volatility_parent_idx: int -) -> Array: - r"""Expected value for the precision of the volatility parent of a contiuous node. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - volatility_parent_idx : - Pointer to the volatility parent node that will be updated. - - Returns - ------- - pihat_volatility_parent : - The new expected value for the mean of the volatility parent - (:math:`\\hat{\\pi}`). - - """ - # List the volatility parent's volatility parents - volatility_parent_volatility_parents_idx = edges[ - volatility_parent_idx - ].volatility_parents - - # Get the log volatility from the volatility parent - logvol = attributes[volatility_parent_idx]["omega"] - - # Look at the (optional) volatility parent's volatility parents - # and update the log volatility accordingly - if volatility_parent_volatility_parents_idx is not None: - for vo_pa_vo_pa, k in zip( - volatility_parent_volatility_parents_idx, - attributes[volatility_parent_idx]["kappas_parents"], - ): - logvol += k * attributes[vo_pa_vo_pa]["mu"] - - # Estimate new_nu - new_nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(new_nu > 1e-128, new_nu, jnp.nan) - - # Estimate the new expected precision for the volatility parent - pihat_volatility_parent = 1 / (1 / attributes[volatility_parent_idx]["pi"] + new_nu) - - return pihat_volatility_parent - - -@partial(jit, static_argnames=("edges", "volatility_parent_idx")) -def prediction_mean_volatility_parent( - attributes: Dict, edges: Edges, time_step: float, volatility_parent_idx: int -) -> Array: - r"""Expected value for the mean of the volatility parent of a contiuous node. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - volatility_parent_idx : - Pointer to the volatility parent node that will be updated. - - Returns - ------- - muhat_volatility_parent : - The new expected value for the mean of the volatility parent - (:math:`\\hat{\\mu}`). - - """ - # List the volatility parent's value parents - volatility_parent_value_parents_idx = edges[volatility_parent_idx].value_parents - - # Get the drift rate from the volatility parent - driftrate = attributes[volatility_parent_idx]["rho"] - - # Look at the (optional) volatility parent's value parents - # and update the drift rate accordingly - if volatility_parent_value_parents_idx is not None: - for vo_pa_va_pa, psi in zip( - volatility_parent_value_parents_idx, - attributes[volatility_parent_idx]["psi_parents"], - ): - driftrate += psi * attributes[vo_pa_va_pa]["mu"] - - # Estimate the new expected mean for the volatility parent - muhat_volatility_parent = ( - attributes[volatility_parent_idx]["mu"] + time_step * driftrate - ) - - return muhat_volatility_parent - - -@partial(jit, static_argnames=("edges", "volatility_parent_idx")) -def prediction_volatility_parent( - attributes: Dict, - edges: Edges, - time_step: float, - volatility_parent_idx: int, -) -> Tuple[Array, ...]: - r"""Prediction step for the volatility parent(s) of a continuous node. - - Updating the posterior distribution of the volatility parent is a two-step process: - #. Update the posterior precision using - :py:fun:`update_precision_volatility_parent`. - #. Update the posterior mean value using - :py:fun:`update_mean_volatility_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - volatility_parent_idx : - Pointer to the value parent node. - - Returns - ------- - pi_volatility_parent : - The precision (:math:`\\pi`) of the volatility parent. - mu_volatility_parent : - The mean (:math:`\\mu`) of the volatility parent. - - """ - pi_volatility_parent = prediction_precision_volatility_parent( - attributes, edges, time_step, volatility_parent_idx - ) - mu_volatility_parent = prediction_mean_volatility_parent( - attributes, edges, time_step, volatility_parent_idx - ) - - return pi_volatility_parent, mu_volatility_parent - - -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_input_value_parent( - attributes: Dict, - edges: Edges, - time_step: float, - value_parent_idx: int, -) -> Array: - """Prediction step for the value parent of a continuous input node. - - Updating the posterior distribution of the value parent of a continuous input node - is a two-step process: - #. Update the parent's expected mean using - :py:func:`pyhgf.updates.prediction.continuous.prediction_input_mean_value_parent`. - #. Update the parent's expected precision using - :py:func:`pyhgf.updates.prediction.continuous.prediction_input_precision_value_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - value_parent_idx : - Pointer to the value parent node. - - Returns - ------- - pihat_value_parent : - The precision (:math:`\\hat{\\pi}`) of the value parent. - muhat_value_parent : - The mean (:math:`\\hat{\\mu}`) of the value parent. - - """ - muhat_value_parent = prediction_input_mean_value_parent( - attributes, edges, time_step, value_parent_idx - ) - pihat_value_parent = prediction_input_precision_value_parent( - attributes, edges, time_step, value_parent_idx - ) - - return pihat_value_parent, muhat_value_parent - - -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_input_mean_value_parent( - attributes: Dict, - edges: Edges, - time_step: float, - value_parent_idx: int, -) -> Array: - r"""Expected value for the mean of the input's value parent. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - value_parent_idx : - Pointer to the node that will be updated. - - Returns - ------- - muhat_value_parent : - The expected value for the mean of the value parent (:math:`\\hat{\\mu}`). - - """ - # list the value parent's value parents - value_parent_value_parents_idxs = edges[value_parent_idx].value_parents - - # Get the value parent's log volatility - driftrate = attributes[value_parent_idx]["rho"] - - # Look at the (optional) value parent's value parents - # and update the drift rate accordingly - if value_parent_value_parents_idxs is not None: - for value_parent_value_parents_idx in value_parent_value_parents_idxs: - driftrate += ( - attributes[value_parent_idx]["psis_parents"][0] - * attributes[value_parent_value_parents_idx]["mu"] - ) - - # Estimate the new expected mean for the value parent - muhat_value_parent = attributes[value_parent_idx]["mu"] + time_step * driftrate - - return muhat_value_parent - - -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_input_precision_value_parent( - attributes: Dict, - edges: Edges, - time_step: float, - value_parent_idx: int, -) -> Array: - r"""Expected value for the precision of the input's value parent. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - time_step : - The interval between the previous time point and the current time point. - value_parent_idx : - Pointer to the node that will be updated. - - Returns - ------- - pihat_value_parent : - The expected value for the mean of the value parent (:math:`\\hat{\\pi}`). - - """ - # List volatility parents - value_parent_volatility_parents_idxs = edges[value_parent_idx].volatility_parents - - # Get the log volatility from the value parent - logvol = attributes[value_parent_idx]["omega"] - - # Look at the (optional) value parent's volatility parents - # and update the log volatility accordingly - if value_parent_volatility_parents_idxs is not None: - for value_parent_volatility_parents_idx, k in zip( - value_parent_volatility_parents_idxs, - attributes[value_parent_idx]["kappas_parents"], - ): - logvol += k * attributes[value_parent_volatility_parents_idx]["mu"] - - # Estimate new nu - nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(nu > 1e-128, nu, jnp.nan) - - # Estimate the new expected precision for the value parent - pihat_value_parent = 1 / (1 / attributes[value_parent_idx]["pi"] + new_nu) - - return pihat_value_parent + return pihat diff --git a/src/pyhgf/updates/prediction_error/continuous.py b/src/pyhgf/updates/prediction_error/continuous.py index d97288398..90358b6b6 100644 --- a/src/pyhgf/updates/prediction_error/continuous.py +++ b/src/pyhgf/updates/prediction_error/continuous.py @@ -1,7 +1,7 @@ # Author: Nicolas Legrand from functools import partial -from typing import Dict, Tuple +from typing import Dict import jax.numpy as jnp from jax import Array, jit @@ -104,99 +104,6 @@ def prediction_error_precision_value_parent( return pi_value_parent -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_error_value_parent( - attributes: Dict, - edges: Edges, - value_parent_idx: int, -) -> Tuple[Array, ...]: - r"""Update the mean and precision of the value parent of a continuous node. - - Updating the posterior distribution of the value parent is a two-step process: - #. Update the posterior precision using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_precision_value_parent`. - #. Update the posterior mean value using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_mean_value_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with - the same length as node number. For each node, the index list value and - volatility parents. - value_parent_idx : - Pointer to the value parent node. - - Returns - ------- - pi_value_parent : - The precision (:math:`\\pi`) of the value parent. - mu_value_parent : - The mean (:math:`\\mu`) of the value parent. - - """ - # Estimate the precision of the posterior distribution - pi_value_parent = prediction_error_precision_value_parent( - attributes, edges, value_parent_idx - ) - # Estimate the mean of the posterior distribution - mu_value_parent = prediction_error_mean_value_parent( - attributes, edges, value_parent_idx, pi_value_parent - ) - - return pi_value_parent, mu_value_parent - - -@partial(jit, static_argnames=("edges", "volatility_parent_idx")) -def prediction_error_volatility_parent( - attributes: Dict, - edges: Edges, - time_step: float, - volatility_parent_idx: int, -) -> Tuple[Array, ...]: - """Update the mean and precision of the volatility parent of a continuous node. - - Updating the posterior distribution of the volatility parent is a two-step process: - #. Update the posterior precision using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_precision_volatility_parent`. - #. Update the posterior mean value using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_mean_volatility_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with - the same length as node number. For each node, the index list value and - volatility parents. - time_step : - The interval between the previous time point and the current time point. - volatility_parent_idx : - Pointer to the value parent node. - - Returns - ------- - pi_volatility_parent : - The precision (:math:`\\pi`) of the value parent. - mu_volatility_parent : - The mean (:math:`\\mu`) of the value parent. - - """ - # Estimate the new precision of the volatility parent - pi_volatility_parent = prediction_error_precision_volatility_parent( - attributes, edges, time_step, volatility_parent_idx - ) - # Estimate the new mean of the volatility parent - mu_volatility_parent = prediction_error_mean_volatility_parent( - attributes, edges, time_step, volatility_parent_idx, pi_volatility_parent - ) - - return pi_volatility_parent, mu_volatility_parent - - @partial(jit, static_argnames=("edges", "volatility_parent_idx")) def prediction_error_precision_volatility_parent( attributes: Dict, edges: Edges, time_step: float, volatility_parent_idx: int @@ -339,51 +246,6 @@ def prediction_error_mean_volatility_parent( return mu_volatility_parent -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_error_input_value_parent( - attributes: Dict, - edges: Edges, - value_parent_idx: int, -) -> Array: - r"""Update the mean and precision of the value parent of a continuous input node. - - Updating the posterior distribution of the value parent is a two-step process: - #. Update the posterior precision using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_input_precision_value_parent`. - #. Update the posterior mean value using - :py:func:`pyhgf.updates.prediction_error.continuous.prediction_error_input_mean_value_parent`. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - value_parent_idx : - Pointer to the node that will be updated. - - Returns - ------- - pi_value_parent : - The updated value for the mean of the value parent (:math:`\\pi`). - mu_value_parent : - The updated value for the mean of the value parent (:math:`\\mu`). - - """ - # Estimate the new precision of the value parent - pi_value_parent = prediction_error_input_precision_value_parent( - attributes, edges, value_parent_idx - ) - # Estimate the new mean of the value parent - mu_value_parent = prediction_error_input_mean_value_parent( - attributes, edges, value_parent_idx, pi_value_parent - ) - - return pi_value_parent, mu_value_parent - - @partial(jit, static_argnames=("edges", "value_parent_idx")) def prediction_error_input_mean_value_parent( attributes: Dict, @@ -434,49 +296,3 @@ def prediction_error_input_mean_value_parent( mu_value_parent = muhat_value_parent + pe_children return mu_value_parent - - -@partial(jit, static_argnames=("edges", "value_parent_idx")) -def prediction_error_input_precision_value_parent( - attributes: Dict, - edges: Edges, - value_parent_idx: int, -) -> Array: - r"""Send prediction-error to the precision of a continuous input's value parent. - - Parameters - ---------- - attributes : - The attributes of the probabilistic nodes. - edges : - The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. - value_parent_idx : - Pointer to the node that will be updated. - - Returns - ------- - pi_value_parent : - The updated value for the mean of the value parent (:math:`\\pi`). - - """ - # Get the current expected precision for the volatility parent - # The prediction sequence was triggered by the new observation so this value is - # already in the node attributes - pihat_value_parent = attributes[value_parent_idx]["pihat"] - - # gather precisions updates from other input nodes - # in the case of a multivariate descendency - pi_children = 0.0 - for child_idx, psi_child in zip( - edges[value_parent_idx].value_children, - attributes[value_parent_idx]["psis_children"], - ): - pihat_child = attributes[child_idx]["pihat"] - pi_children += psi_child**2 * pihat_child - - # Compute the new precision of the value parent - pi_value_parent = pihat_value_parent + pi_children - - return pi_value_parent