Skip to content

Commit

Permalink
split posterior update of precision into two branches
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 18, 2024
1 parent 80ab968 commit c2a1204
Showing 1 changed file with 94 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import jax.numpy as jnp
from jax import grad, jit
from jax.lax import cond
from jax.tree_util import Partial

from pyhgf.typing import Edges

Expand All @@ -31,7 +33,7 @@ def posterior_update_precision_continuous_node(
Where :math:`\kappa_j` is the volatility coupling strength between the child node
and the state node and :math:`\delta_j^{(k)}` is the value prediction error that
was computed before hand by
was computed beforehand by
:py:func:`pyhgf.updates.prediction_errors.continuous.continuous_node_value_prediction_error`.
For non-linear value coupling:
Expand Down Expand Up @@ -80,8 +82,9 @@ def posterior_update_precision_continuous_node(
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
Expand All @@ -108,6 +111,60 @@ def posterior_update_precision_continuous_node(
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# ----------------------------------------------------------------------------------
# Decide which update to use depending on the presence of observed value in the
# children nodes. If no values were observed, the precision should increase
# as a function of time using the function precision_missing_values(). Otherwise,
# we use regular HGF updates for value and volatility couplings.
# ----------------------------------------------------------------------------------

# For all children, get the `observed` flag - if all these values are 0.0, the node
# has not received any observations and we should call precision_missing_values()
observations = []
if edges[node_idx].value_children is not None:
for children_idx in edges[node_idx].value_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
if edges[node_idx].volatility_children is not None:
for children_idx in edges[node_idx].volatility_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
observations = jnp.any(jnp.array(observations))

posterior_precision = cond(
observations,
Partial(precision_update, edges=edges, node_idx=node_idx),
Partial(precision_update_missing_values, edges=edges, node_idx=node_idx),
attributes,
)

return posterior_precision


@partial(jit, static_argnames=("edges", "node_idx"))
def precision_update(attributes: Dict, edges: Edges, node_idx: int) -> float:
"""Compute new precision in the case of observed values.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
The time elapsed between this observation and the previous one.
Returns
-------
posterior_precision :
The new posterior precision when at least one of the children has
observed a new value. We then use the regular HGF update for volatility
coupling.
"""
# sum the prediction errors from both value and volatility coupling
precision_weigthed_prediction_error = 0.0
Expand Down Expand Up @@ -177,13 +234,41 @@ def posterior_update_precision_continuous_node(
)

# ensure the new precision is greater than 0
observed_posterior_precision = jnp.where(
posterior_precision = jnp.where(
posterior_precision > 1e-128, posterior_precision, jnp.nan
)

# additionnal steps for unobserved values
# ---------------------------------------
return posterior_precision


@partial(jit, static_argnames=("edges", "node_idx"))
def precision_update_missing_values(
attributes: Dict, edges: Edges, node_idx: int
) -> float:
"""Compute new precision in the case of missing observations.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
The time elapsed between this observation and the previous one.
Returns
-------
posterior_precision_missing_values :
The new posterior precision in the case of missing values in all child nodes.
The new precision decreases proportionally to the time elapsed, accounting for
the influence of volatility parents.
"""
# List the node's volatility parents
volatility_parents_idxs = edges[node_idx].volatility_parents

Expand All @@ -201,29 +286,13 @@ def posterior_update_precision_continuous_node(
volatility_coupling * attributes[volatility_parents_idx]["mean"]
)

# compute the predicted_volatility from the total volatility
# compute the new predicted_volatility from the total volatility
time_step = attributes[-1]["time_step"]
predicted_volatility = time_step * jnp.exp(total_volatility)

# Estimate the new precision for the continuous state node
unobserved_posterior_precision = 1 / (
posterior_precision_missing_values = 1 / (
(1 / attributes[node_idx]["precision"]) + predicted_volatility
)

# for all children, look at the values of VAPE
# if all these values are NaNs, the node has not received observations
observations = []
if edges[node_idx].value_children is not None:
for children_idx in edges[node_idx].value_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
if edges[node_idx].volatility_children is not None:
for children_idx in edges[node_idx].volatility_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
observations = jnp.any(jnp.array(observations))

posterior_precision = (
unobserved_posterior_precision * (1 - observations) # type: ignore
+ observed_posterior_precision * observations
)

return posterior_precision
return posterior_precision_missing_values

0 comments on commit c2a1204

Please sign in to comment.