From ded0e2331dbf9acdc4ea6861c2d9cc8110eeefc9 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Fri, 15 Sep 2023 15:24:52 +0200 Subject: [PATCH] refactor update for continuous nodes (volatility parents) --- src/pyhgf/updates/continuous.py | 132 ++----------- src/pyhgf/updates/posterior/continuous.py | 218 ++++++++++++++++++++-- 2 files changed, 224 insertions(+), 126 deletions(-) diff --git a/src/pyhgf/updates/continuous.py b/src/pyhgf/updates/continuous.py index d9e419c79..b399b68ec 100644 --- a/src/pyhgf/updates/continuous.py +++ b/src/pyhgf/updates/continuous.py @@ -8,14 +8,17 @@ from jax.typing import ArrayLike from pyhgf.typing import Edges -from pyhgf.updates.posterior.continuous import continuous_node_update_value_parent +from pyhgf.updates.posterior.continuous import ( + update_value_parent, + update_volatility_parent, +) @partial(jit, static_argnames=("edges", "node_idx")) def continuous_node_update( attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args ) -> Dict: - """Update the value and volatility parent(s) of a continuous node. + """Update the posterior of the value and volatility parent(s) of a continuous node. Updating the node's parents is a two-step process: 1. Update value parent(s). @@ -32,26 +35,21 @@ def continuous_node_update( Parameters ---------- attributes : - The structure of nodes' parameters. Each parameter is a dictionary with the - following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for - continuous nodes. - .. note:: - The parameter structure also incorporate the value and volatility coupling - strenght with children and parents (i.e. `"psis_parents"`, `"psis_children"`, - `"kappas_parents"`, `"kappas_children"`). + The nodes' parameters. time_step : The interval between the previous time point and the current time point. node_idx : Pointer to the node that needs to be updated. After continuous updates, the parameters of value and volatility parents (if any) will be different. 
edges : - Tuple of :py:class:`pyhgf.typing.Indexes` with the same length as node number. - For each node, the index list value and volatility parents. + The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with + the same length as node number. For each node, the index list value and + volatility parents. Returns ------- attributes : - The updated parameters structure. + The updated nodes' parameters. See Also -------- @@ -76,10 +74,7 @@ def continuous_node_update( # Update value parents # ######################## if value_parents_idxs is not None: - # the strength of the value coupling between the base node and the parents nodes - psis = attributes[node_idx]["psis_parents"] - - for value_parent_idx, psi in zip(value_parents_idxs, psis): + for value_parent_idx in value_parents_idxs: # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[value_parent_idx].value_children[-1] == node_idx: @@ -87,9 +82,7 @@ def continuous_node_update( pi_value_parent, mu_value_parent, nu_value_parent, - ) = continuous_node_update_value_parent( - attributes, edges, time_step, value_parent_idx, psi - ) + ) = update_value_parent(attributes, edges, time_step, value_parent_idx) # Update this parent's parameters attributes[value_parent_idx]["pi"] = pi_value_parent @@ -104,101 +97,12 @@ def continuous_node_update( # if this child is the last one relative to this parent's family, all the # children will update the parent at once, otherwise just pass and wait if edges[volatility_parent_idx].volatility_children[-1] == node_idx: - # list the value and volatility parents - volatility_parent_value_parents_idx = edges[ - volatility_parent_idx - ].value_parents - volatility_parent_volatility_parents_idx = edges[ - volatility_parent_idx - ].volatility_parents - - # Compute new value for nu and pihat - logvol = attributes[volatility_parent_idx]["omega"] - - # Look at the (optional) 
vo_pa's volatility parents - # and update logvol accordingly - if volatility_parent_volatility_parents_idx is not None: - for vo_pa_vo_pa, k in zip( - volatility_parent_volatility_parents_idx, - attributes[volatility_parent_idx]["kappas_parents"], - ): - logvol += k * attributes[vo_pa_vo_pa]["mu"] - - # Estimate new_nu - new_nu = time_step * jnp.exp(logvol) - new_nu = jnp.where(new_nu > 1e-128, new_nu, jnp.nan) - - pihat_volatility_parent, nu_volatility_parent = [ - 1 / (1 / attributes[volatility_parent_idx]["pi"] + new_nu), - new_nu, - ] - - # gather volatility precisions from the child nodes - children_volatility_precision = 0.0 - for child_idx, kappas_children in zip( - edges[volatility_parent_idx].volatility_children, - attributes[volatility_parent_idx]["kappas_children"], - ): - nu_children = attributes[child_idx]["nu"] - pihat_children = attributes[child_idx]["pihat"] - pi_children = attributes[child_idx]["pi"] - vope_children = ( - 1 / attributes[child_idx]["pi"] - + (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"]) - ** 2 - ) * attributes[child_idx]["pihat"] - 1 - - children_volatility_precision += ( - 0.5 - * (kappas_children * nu_children * pihat_children) ** 2 - * (1 + (1 - 1 / (nu_children * pi_children)) * vope_children) - ) - - pi_volatility_parent = ( - pihat_volatility_parent + children_volatility_precision - ) - - pi_volatility_parent = jnp.where( - pi_volatility_parent <= 0, jnp.nan, pi_volatility_parent - ) - - # drift rate of the GRW - driftrate = attributes[volatility_parent_idx]["rho"] - - # Look at the (optional) va_pa's value parents - # and update drift rate accordingly - if volatility_parent_value_parents_idx is not None: - for vo_pa_va_pa in volatility_parent_value_parents_idx: - driftrate += psi * attributes[vo_pa_va_pa]["mu"] - - muhat_volatility_parent = ( - attributes[volatility_parent_idx]["mu"] + time_step * driftrate - ) - - # gather volatility prediction errors from the child nodes - 
children_volatility_prediction_error = 0.0 - for child_idx, kappas_children in zip( - edges[volatility_parent_idx].volatility_children, - attributes[volatility_parent_idx]["kappas_children"], - ): - nu_children = attributes[child_idx]["nu"] - pihat_children = attributes[child_idx]["pihat"] - vope_children = ( - 1 / attributes[child_idx]["pi"] - + (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"]) - ** 2 - ) * attributes[child_idx]["pihat"] - 1 - children_volatility_prediction_error += ( - 0.5 - * kappas_children - * nu_children - * pihat_children - / pi_volatility_parent - * vope_children - ) - - mu_volatility_parent = ( - muhat_volatility_parent + children_volatility_prediction_error + ( + pi_volatility_parent, + mu_volatility_parent, + nu_volatility_parent, + ) = update_volatility_parent( + attributes, edges, time_step, volatility_parent_idx ) # Update this parent's parameters diff --git a/src/pyhgf/updates/posterior/continuous.py b/src/pyhgf/updates/posterior/continuous.py index 926a49fc8..a8f7f7fc1 100644 --- a/src/pyhgf/updates/posterior/continuous.py +++ b/src/pyhgf/updates/posterior/continuous.py @@ -10,8 +10,13 @@ from pyhgf.typing import Edges -def continuous_node_update_mean_value_parent( - attributes, edges, time_step, value_parent_idx, pi_value_parent, psi +@partial(jit, static_argnames=("edges", "value_parent_idx")) +def update_mean_value_parent( + attributes: Dict, + edges: Edges, + time_step: float, + value_parent_idx: int, + pi_value_parent: ArrayLike, ) -> Array: # list the value and volatility parents value_parent_value_parents_idxs = edges[value_parent_idx].value_parents @@ -19,10 +24,12 @@ def continuous_node_update_mean_value_parent( # Compute new muhat driftrate = attributes[value_parent_idx]["rho"] - # Look at the (optional) va_pa's value parents - # and update drift rate accordingly + # Look at the (optional) value parents of the value parents and update drift rate if value_parent_value_parents_idxs is not None: - for va_pa_va_pa
in value_parent_value_parents_idxs: + for va_pa_va_pa, psi in zip( + value_parent_value_parents_idxs, + attributes[value_parent_idx]["psis_parents"], + ): driftrate += psi * attributes[va_pa_va_pa]["mu"] muhat_value_parent = attributes[value_parent_idx]["mu"] + time_step * driftrate @@ -44,8 +51,9 @@ def continuous_node_update_mean_value_parent( return mu_value_parent -def continuous_node_update_precision_value_parent( - attributes, edges, time_step, value_parent_idx +@partial(jit, static_argnames=("edges", "value_parent_idx")) +def update_precision_value_parent( + attributes: Dict, edges: Edges, time_step: float, value_parent_idx: int ) -> Array: # list the value and volatility parents value_parent_volatility_parents_idxs = edges[value_parent_idx].volatility_parents @@ -88,18 +96,204 @@ def continuous_node_update_precision_value_parent( @partial(jit, static_argnames=("edges", "value_parent_idx")) -def continuous_node_update_value_parent( +def update_value_parent( attributes: Dict, edges: Edges, time_step: float, value_parent_idx: int, - psi: ArrayLike, ) -> Tuple[Array, ...]: - pi_value_parent, nu_value_parent = continuous_node_update_precision_value_parent( + """Update the mean and precision of the value parent of a continuous node. + + Updating the posterior distribution of the value parent is a two-step process: + 1. Update the posterior precision using + :py:func:`update_precision_value_parent`. + 2. Update the posterior mean value using + :py:func:`update_mean_value_parent`. + + Parameters + ---------- + attributes : + The nodes' parameters. + edges : + The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with + the same length as node number. For each node, the index list value and + volatility parents. + time_step : + The interval between the previous time point and the current time point. + value_parent_idx : + Pointer to the value parent node.
+ + Returns + ------- + pi_value_parent : + The precision (:math:`\\pi`) of the value parent. + mu_value_parent : + The mean (:math:`\\mu`) of the value parent. + nu_value_parent : + + """ + pi_value_parent, nu_value_parent = update_precision_value_parent( attributes, edges, time_step, value_parent_idx ) - mu_value_parent = continuous_node_update_mean_value_parent( - attributes, edges, time_step, value_parent_idx, pi_value_parent, psi + mu_value_parent = update_mean_value_parent( + attributes, edges, time_step, value_parent_idx, pi_value_parent ) return pi_value_parent, mu_value_parent, nu_value_parent + + +@partial(jit, static_argnames=("edges", "volatility_parent_idx")) +def update_volatility_parent( + attributes: Dict, + edges: Edges, + time_step: float, + volatility_parent_idx: int, +) -> Tuple[Array, ...]: + """Update the mean and precision of the volatility parent of a continuous node. + + Updating the posterior distribution of the volatility parent is a two-step process: + 1. Update the posterior precision using + :py:func:`update_precision_volatility_parent`. + 2. Update the posterior mean value using + :py:func:`update_mean_volatility_parent`. + + Parameters + ---------- + attributes : + The nodes' parameters. + edges : + The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with + the same length as node number. For each node, the index list value and + volatility parents. + time_step : + The interval between the previous time point and the current time point. + volatility_parent_idx : + Pointer to the volatility parent node. + + Returns + ------- + pi_volatility_parent : + The precision (:math:`\\pi`) of the volatility parent. + mu_volatility_parent : + The mean (:math:`\\mu`) of the volatility parent.
+ nu_volatility_parent : + + """ + pi_volatility_parent, nu_volatility_parent = update_precision_volatility_parent( + attributes, edges, time_step, volatility_parent_idx + ) + mu_volatility_parent = update_mean_volatility_parent( + attributes, edges, time_step, volatility_parent_idx, pi_volatility_parent + ) + + return pi_volatility_parent, mu_volatility_parent, nu_volatility_parent + + +@partial(jit, static_argnames=("edges", "volatility_parent_idx")) +def update_precision_volatility_parent( + attributes: Dict, edges: Edges, time_step: float, volatility_parent_idx: int +) -> Array: + # list the volatility parents of the volatility parent + volatility_parent_volatility_parents_idx = edges[ + volatility_parent_idx + ].volatility_parents + + # Compute new value for nu and pihat + logvol = attributes[volatility_parent_idx]["omega"] + + # Look at the (optional) vo_pa's volatility parents + # and update logvol accordingly + if volatility_parent_volatility_parents_idx is not None: + for vo_pa_vo_pa, k in zip( + volatility_parent_volatility_parents_idx, + attributes[volatility_parent_idx]["kappas_parents"], + ): + logvol += k * attributes[vo_pa_vo_pa]["mu"] + + # Estimate new_nu + new_nu = time_step * jnp.exp(logvol) + new_nu = jnp.where(new_nu > 1e-128, new_nu, jnp.nan) + + pihat_volatility_parent, nu_volatility_parent = [ + 1 / (1 / attributes[volatility_parent_idx]["pi"] + new_nu), + new_nu, + ] + + # gather volatility precisions from the child nodes + children_volatility_precision = 0.0 + for child_idx, kappas_children in zip( + edges[volatility_parent_idx].volatility_children, + attributes[volatility_parent_idx]["kappas_children"], + ): + nu_children = attributes[child_idx]["nu"] + pihat_children = attributes[child_idx]["pihat"] + pi_children = attributes[child_idx]["pi"] + vope_children = ( + 1 / attributes[child_idx]["pi"] + + (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"]) ** 2 + ) * attributes[child_idx]["pihat"] - 1 + + children_volatility_precision += ( +
0.5 + * (kappas_children * nu_children * pihat_children) ** 2 + * (1 + (1 - 1 / (nu_children * pi_children)) * vope_children) + ) + + pi_volatility_parent = pihat_volatility_parent + children_volatility_precision + + pi_volatility_parent = jnp.where( + pi_volatility_parent <= 0, jnp.nan, pi_volatility_parent + ) + + return pi_volatility_parent, nu_volatility_parent + + +@partial(jit, static_argnames=("edges", "volatility_parent_idx")) +def update_mean_volatility_parent( + attributes, edges, time_step, volatility_parent_idx, pi_volatility_parent: ArrayLike +) -> Array: + # list the value parents of the volatility parent + volatility_parent_value_parents_idx = edges[volatility_parent_idx].value_parents + + # drift rate of the GRW + driftrate = attributes[volatility_parent_idx]["rho"] + + # Look at the (optional) vo_pa's value parents + # and update drift rate accordingly -- NOTE(review): the "psi_parents" key below looks like a typo for "psis_parents" (the key used by update_mean_value_parent); confirm + if volatility_parent_value_parents_idx is not None: + for vo_pa_va_pa, psi in zip( + volatility_parent_value_parents_idx, + attributes[volatility_parent_idx]["psi_parents"], + ): + driftrate += psi * attributes[vo_pa_va_pa]["mu"] + + muhat_volatility_parent = ( + attributes[volatility_parent_idx]["mu"] + time_step * driftrate + ) + + # gather volatility prediction errors from the child nodes + children_volatility_prediction_error = 0.0 + for child_idx, kappas_children in zip( + edges[volatility_parent_idx].volatility_children, + attributes[volatility_parent_idx]["kappas_children"], + ): + nu_children = attributes[child_idx]["nu"] + pihat_children = attributes[child_idx]["pihat"] + vope_children = ( + 1 / attributes[child_idx]["pi"] + + (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"]) ** 2 + ) * attributes[child_idx]["pihat"] - 1 + children_volatility_prediction_error += ( + 0.5 + * kappas_children + * nu_children + * pihat_children + / pi_volatility_parent + * vope_children + ) + + mu_volatility_parent = ( + muhat_volatility_parent + children_volatility_prediction_error + ) + +
return mu_volatility_parent