Skip to content

Commit

Permalink
prediction update apply to the target node only (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico authored Sep 22, 2023
1 parent 8a5537b commit e2a2e7f
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 280 deletions.
2 changes: 0 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ Core functionnalities to update *binary* nodes.
binary_node_prediction_error
binary_node_prediction
binary_input_prediction_error
binary_input_prediction

Updating continuous nodes
=========================
Expand All @@ -42,7 +41,6 @@ Core functionnalities to update *continuous* nodes.
continuous_node_prediction_error
continuous_node_prediction
continuous_input_prediction_error
continuous_input_prediction

Updating categorical nodes
==========================
Expand Down
22 changes: 8 additions & 14 deletions src/pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
from pyhgf.math import gaussian_surprise
from pyhgf.typing import Indexes, UpdateSequence
from pyhgf.updates.binary import (
binary_input_prediction,
binary_input_prediction_error,
binary_node_prediction,
binary_node_prediction_error,
)
from pyhgf.updates.categorical import categorical_input_update
from pyhgf.updates.continuous import (
continuous_input_prediction,
continuous_input_prediction_error,
continuous_node_prediction,
continuous_node_prediction_error,
Expand Down Expand Up @@ -314,15 +312,9 @@ def get_update_sequence(
)

for node_idx in node_idxs:
# if the node has no parent, exit here
if (hgf.edges[node_idx].value_parents is None) & (
hgf.edges[node_idx].volatility_parents is None
):
continue

# if this node is part of a familly (the parent has multiple children)
# if this node is part of a family (the parent has multiple children)
# and this is not the last of the children, exit here
# we apply this principle for every value / volatility parent
# we apply this principle for every value / volatility families
is_youngest = False
if node_idx in hgf.input_nodes_idx.idx:
is_youngest = True # always update input nodes
Expand Down Expand Up @@ -355,10 +347,8 @@ def get_update_sequence(
][0]
if model_kind == "binary":
update_fn = binary_input_prediction_error
prediction_fn = binary_input_prediction
elif model_kind == "continuous":
update_fn = continuous_input_prediction_error
prediction_fn = continuous_input_prediction
elif model_kind == "categorical":
continue

Expand All @@ -380,12 +370,16 @@ def get_update_sequence(

# create a new update and prediction sequence step and add it to the list
# only the youngest of the family is updated, but all nodes get predictions
# ensure that the node is the youngest of the family (also implicitely ensure
# that it is not orphan, otherwise skip this the node will only have prediction)
if is_youngest:
new__update_sequence = node_idx, update_fn
update_sequence.append(new__update_sequence)

new_prediction_sequence = node_idx, prediction_fn
prediction_sequence.append(new_prediction_sequence)
# no prediction step for an input node
if node_idx not in hgf.input_nodes_idx.idx:
new_prediction_sequence = node_idx, prediction_fn
prediction_sequence.append(new_prediction_sequence)

# search recursively for the next update steps - make sure that all the
# children have been updated before updating the parent(s)
Expand Down
109 changes: 8 additions & 101 deletions src/pyhgf/updates/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from jax import jit

from pyhgf.typing import Edges
from pyhgf.updates.continuous import predict_mean, predict_precision
from pyhgf.updates.prediction.binary import predict_input_value_parent
from pyhgf.updates.prediction.binary import predict_binary_state_node
from pyhgf.updates.prediction_error.binary import (
prediction_error_input_value_parent,
prediction_error_value_parent,
Expand Down Expand Up @@ -84,11 +83,7 @@ def binary_node_prediction_error(
def binary_node_prediction(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
) -> Dict:
"""Update the value parent(s) of a binary node.
In a three-level HGF, this step will update the node :math:`x_2`.
Then returns the new node tuple `(parameters, value_parents, volatility_parents)`.
"""Update the expected mean and precision of a binary state node.
Parameters
----------
Expand All @@ -111,7 +106,7 @@ def binary_node_prediction(
Returns
-------
attributes :
The updated node structure.
The new node structure with updated mean and expected precision.
References
----------
Expand All @@ -120,28 +115,12 @@ def binary_node_prediction(
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# using the current node index, unwrap parameters and parents
value_parent_idxs = edges[node_idx].value_parents

# Return here if no parents node are found
if value_parent_idxs is None:
return attributes

################################################################
# Update the predictions of the continuous value parents (x-2) #
################################################################
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)
# Get the new expected value for the mean and precision
pihat, muhat = predict_binary_state_node(attributes, edges, time_step, node_idx)

# update the parent nodes' parameters
attributes[value_parent_idx]["pihat"] = pihat_value_parent
attributes[value_parent_idx]["muhat"] = muhat_value_parent
# Update the node's attributes
attributes[node_idx]["pihat"] = pihat
attributes[node_idx]["muhat"] = muhat

return attributes

Expand Down Expand Up @@ -224,75 +203,3 @@ def binary_input_prediction_error(
attributes[node_idx]["surprise"] = surprise

return attributes


@partial(jit, static_argnames=("edges", "node_idx"))
def binary_input_prediction(
attributes: Dict,
time_step: float,
node_idx: int,
edges: Edges,
value: float,
) -> Dict:
"""Update the input node structure given one binary observation.
This function is the entry-level of the binary node. It updates the parents of
the input node (:math:`x_1`).
Parameters
----------
value :
The new observed value.
time_step :
The interval between the previous time point and the current time point.
attributes :
The attributes of the probabilistic nodes.
.. note::
`"psis"` is the value coupling strength. It should have the same length as the
volatility parents' indexes. `"kappas"` is the volatility coupling strength.
It should have the same length as the volatility parents' indexes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
Returns
-------
attributes :
The updated parameters structure.
See Also
--------
update_continuous_parents, update_continuous_input_parents
References
----------
.. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., &
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# list value and volatility parents
value_parent_idxs = edges[node_idx].value_parents
volatility_parent_idxs = edges[node_idx].volatility_parents

if (value_parent_idxs is None) and (volatility_parent_idxs is None):
return attributes

#######################################################
# Update the value parent(s) of the binary input node #
#######################################################
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
pihat_value_parent, muhat_value_parent = predict_input_value_parent(
attributes, edges, time_step, value_parent_idx
)

# Update value parent's parameters
attributes[value_parent_idx]["pihat"] = pihat_value_parent
attributes[value_parent_idx]["muhat"] = muhat_value_parent

return attributes
139 changes: 9 additions & 130 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,7 @@ def continuous_node_prediction_error(
def continuous_node_prediction(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
) -> Dict:
"""Prediction step for the value and volatility parents of a continuous node.
Updating the node's parents is a two-step process:
1. Update value parent(s).
2. Update volatility parent(s).
If a value/volatility parent has multiple children, all the children will update
the parent together, therefor this function should only be called once per group
of child nodes. The method :py:meth:`pyhgf.model.HGF.get_update_sequence`
ensures that this function is only called once all the children have been
updated.
Then returns the structure of the new parameters.
"""Update the expected mean and precision of a continuous node.
Parameters
----------
Expand All @@ -150,8 +138,7 @@ def continuous_node_prediction(
time_step :
The interval between the previous time point and the current time point.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
Pointer to the node that will be updated.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
Expand All @@ -173,45 +160,14 @@ def continuous_node_prediction(
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# list value and volatility parents
value_parents_idxs = edges[node_idx].value_parents
volatility_parents_idxs = edges[node_idx].volatility_parents

# return here if no parents node are provided
if (value_parents_idxs is None) and (volatility_parents_idxs is None):
return attributes

########################
# Update value parents #
########################
if value_parents_idxs is not None:
for value_parent_idx in value_parents_idxs:
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)

# Update this parent's parameters
attributes[value_parent_idx]["pihat"] = pihat_value_parent
attributes[value_parent_idx]["muhat"] = muhat_value_parent
# Get the new expected mean
muhat = predict_mean(attributes, edges, time_step, node_idx)
# Get the new expected precision
pihat = predict_precision(attributes, edges, time_step, node_idx)

#############################
# Update volatility parents #
#############################
if volatility_parents_idxs is not None:
for volatility_parent_idx in volatility_parents_idxs:
muhat_volatility_parent = predict_mean(
attributes, edges, time_step, volatility_parent_idx
)
pihat_volatility_parent = predict_precision(
attributes, edges, time_step, volatility_parent_idx
)

# Update this parent's parameters
attributes[volatility_parent_idx]["pihat"] = pihat_volatility_parent
attributes[volatility_parent_idx]["muhat"] = muhat_volatility_parent
# Update this node's parameters
attributes[node_idx]["pihat"] = pihat
attributes[node_idx]["muhat"] = muhat

return attributes

Expand Down Expand Up @@ -294,80 +250,3 @@ def continuous_input_prediction_error(
attributes[value_parent_idx]["mu"] = mu_value_parent

return attributes


@partial(jit, static_argnames=("edges", "node_idx"))
def continuous_input_prediction(
attributes: Dict,
time_step: float,
node_idx: int,
edges: Edges,
value: float,
) -> Dict:
"""Update the input node structure.
This function is the entry-level of the structure updates. It updates the parent
of the input node.
Parameters
----------
value :
The new observed value.
time_step :
The interval between the previous time point and the current time point.
attributes :
The attributes of the probabilistic nodes.
.. note::
The parameter structure also incorporate the value and volatility coupling
strenght with children and parents (i.e. `"psis_parents"`, `"psis_children"`,
`"kappas_parents"`, `"kappas_children"`).
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
Returns
-------
attributes :
The updated attributes of the probabilistic nodes.
See Also
--------
continuous_node_update, update_binary_input_parents
References
----------
.. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., &
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# store timestep in the node's parameters
attributes[node_idx]["time_step"] = time_step

# list value and volatility parents
value_parents_idxs = edges[node_idx].value_parents

########################
# Update value parents #
########################
if value_parents_idxs is not None:
for value_parent_idx in value_parents_idxs:
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)

# update input node's parameters
attributes[value_parent_idx]["pihat"] = pihat_value_parent
attributes[value_parent_idx]["muhat"] = muhat_value_parent

return attributes
Loading

0 comments on commit e2a2e7f

Please sign in to comment.