Skip to content

Commit

Permalink
refactor update for continuous nodes (volatility parents)
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 15, 2023
1 parent 0bf0296 commit ded0e23
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 126 deletions.
132 changes: 18 additions & 114 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
from jax.typing import ArrayLike

from pyhgf.typing import Edges
from pyhgf.updates.posterior.continuous import continuous_node_update_value_parent
from pyhgf.updates.posterior.continuous import (
update_value_parent,
update_volatility_parent,
)


@partial(jit, static_argnames=("edges", "node_idx"))
def continuous_node_update(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
) -> Dict:
"""Update the value and volatility parent(s) of a continuous node.
"""Update the posterior of the value and volatility parent(s) of a continuous node.
Updating the node's parents is a two-step process:
1. Update value parent(s).
Expand All @@ -32,26 +35,21 @@ def continuous_node_update(
Parameters
----------
attributes :
The structure of nodes' parameters. Each parameter is a dictionary with the
following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
continuous nodes.
.. note::
The parameter structure also incorporate the value and volatility coupling
strenght with children and parents (i.e. `"psis_parents"`, `"psis_children"`,
`"kappas_parents"`, `"kappas_children"`).
The nodes' parameters.
time_step :
The interval between the previous time point and the current time point.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
edges :
Tuple of :py:class:`pyhgf.typing.Indexes` with the same length as node number.
For each node, the index list value and volatility parents.
The edges of the network as a tuple of :py:class:`pyhgf.typing.Indexes` with
the same length as node number. For each node, the index list value and
volatility parents.
Returns
-------
attributes :
The updated parameters structure.
The updated nodes' parameters.
See Also
--------
Expand All @@ -76,20 +74,15 @@ def continuous_node_update(
# Update value parents #
########################
if value_parents_idxs is not None:
# the strength of the value coupling between the base node and the parents nodes
psis = attributes[node_idx]["psis_parents"]

for value_parent_idx, psi in zip(value_parents_idxs, psis):
for value_parent_idx in value_parents_idxs:
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
(
pi_value_parent,
mu_value_parent,
nu_value_parent,
) = continuous_node_update_value_parent(
attributes, edges, time_step, value_parent_idx, psi
)
) = update_value_parent(attributes, edges, time_step, value_parent_idx)

# Update this parent's parameters
attributes[value_parent_idx]["pi"] = pi_value_parent
Expand All @@ -104,101 +97,12 @@ def continuous_node_update(
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[volatility_parent_idx].volatility_children[-1] == node_idx:
# list the value and volatility parents
volatility_parent_value_parents_idx = edges[
volatility_parent_idx
].value_parents
volatility_parent_volatility_parents_idx = edges[
volatility_parent_idx
].volatility_parents

# Compute new value for nu and pihat
logvol = attributes[volatility_parent_idx]["omega"]

# Look at the (optional) vo_pa's volatility parents
# and update logvol accordingly
if volatility_parent_volatility_parents_idx is not None:
for vo_pa_vo_pa, k in zip(
volatility_parent_volatility_parents_idx,
attributes[volatility_parent_idx]["kappas_parents"],
):
logvol += k * attributes[vo_pa_vo_pa]["mu"]

# Estimate new_nu
new_nu = time_step * jnp.exp(logvol)
new_nu = jnp.where(new_nu > 1e-128, new_nu, jnp.nan)

pihat_volatility_parent, nu_volatility_parent = [
1 / (1 / attributes[volatility_parent_idx]["pi"] + new_nu),
new_nu,
]

# gather volatility precisions from the child nodes
children_volatility_precision = 0.0
for child_idx, kappas_children in zip(
edges[volatility_parent_idx].volatility_children,
attributes[volatility_parent_idx]["kappas_children"],
):
nu_children = attributes[child_idx]["nu"]
pihat_children = attributes[child_idx]["pihat"]
pi_children = attributes[child_idx]["pi"]
vope_children = (
1 / attributes[child_idx]["pi"]
+ (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"])
** 2
) * attributes[child_idx]["pihat"] - 1

children_volatility_precision += (
0.5
* (kappas_children * nu_children * pihat_children) ** 2
* (1 + (1 - 1 / (nu_children * pi_children)) * vope_children)
)

pi_volatility_parent = (
pihat_volatility_parent + children_volatility_precision
)

pi_volatility_parent = jnp.where(
pi_volatility_parent <= 0, jnp.nan, pi_volatility_parent
)

# drift rate of the GRW
driftrate = attributes[volatility_parent_idx]["rho"]

# Look at the (optional) va_pa's value parents
# and update drift rate accordingly
if volatility_parent_value_parents_idx is not None:
for vo_pa_va_pa in volatility_parent_value_parents_idx:
driftrate += psi * attributes[vo_pa_va_pa]["mu"]

muhat_volatility_parent = (
attributes[volatility_parent_idx]["mu"] + time_step * driftrate
)

# gather volatility prediction errors from the child nodes
children_volatility_prediction_error = 0.0
for child_idx, kappas_children in zip(
edges[volatility_parent_idx].volatility_children,
attributes[volatility_parent_idx]["kappas_children"],
):
nu_children = attributes[child_idx]["nu"]
pihat_children = attributes[child_idx]["pihat"]
vope_children = (
1 / attributes[child_idx]["pi"]
+ (attributes[child_idx]["mu"] - attributes[child_idx]["muhat"])
** 2
) * attributes[child_idx]["pihat"] - 1
children_volatility_prediction_error += (
0.5
* kappas_children
* nu_children
* pihat_children
/ pi_volatility_parent
* vope_children
)

mu_volatility_parent = (
muhat_volatility_parent + children_volatility_prediction_error
(
pi_volatility_parent,
mu_volatility_parent,
nu_volatility_parent,
) = update_volatility_parent(
attributes, edges, time_step, volatility_parent_idx
)

# Update this parent's parameters
Expand Down
Loading

0 comments on commit ded0e23

Please sign in to comment.