Skip to content

Commit

Permalink
simplify update functions
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 21, 2023
1 parent 6845d85 commit bc4cbbc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 629 deletions.
58 changes: 25 additions & 33 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ Core functionnalities to update *binary* nodes.
.. autosummary::
:toctree: generated/pyhgf.updates.binary

binary_node_prediction
binary_input_prediction
binary_node_prediction_error
binary_node_prediction
binary_input_prediction_error
binary_input_prediction

Updating continuous nodes
=========================
Expand Down Expand Up @@ -61,24 +61,6 @@ Prediction error steps

Propagate prediction errors to the value and volatility parents of a given node.

Continuous nodes
~~~~~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction_error.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.continuous

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent
prediction_error_volatility_parent
prediction_error_precision_volatility_parent
prediction_error_mean_volatility_parent
prediction_error_input_value_parent
prediction_error_input_mean_value_parent
prediction_error_input_precision_value_parent

Binary nodes
~~~~~~~~~~~~

Expand All @@ -94,24 +76,24 @@ Binary nodes
input_surprise_inf
input_surprise_reg

Prediction steps
================

Compute the expectation for future observation given the influence of parent nodes.

Continuous nodes
~~~~~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction.continuous
.. currentmodule:: pyhgf.updates.prediction_error.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction.continuous
:toctree: generated/pyhgf.updates.prediction_error.continuous

prediction_mean_value_parent
prediction_precision_value_parent
prediction_value_parent
prediction_precision_volatility_parent
prediction_mean_volatility_parent
prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_precision_volatility_parent
prediction_error_mean_volatility_parent
prediction_error_input_mean_value_parent

Prediction steps
================

Compute the expectation for future observation given the influence of parent nodes.

Binary nodes
~~~~~~~~~~~~
Expand All @@ -121,8 +103,18 @@ Binary nodes
.. autosummary::
:toctree: generated/pyhgf.updates.prediction.binary

prediction_input_value_parent
predict_input_value_parent

Continuous nodes
~~~~~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction.continuous

predict_mean
predict_precision

Distribution
------------
Expand Down
11 changes: 7 additions & 4 deletions src/pyhgf/updates/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from jax import jit

from pyhgf.typing import Edges
from pyhgf.updates.continuous import prediction_value_parent
from pyhgf.updates.prediction.binary import prediction_input_value_parent
from pyhgf.updates.continuous import predict_mean, predict_precision
from pyhgf.updates.prediction.binary import predict_input_value_parent
from pyhgf.updates.prediction_error.binary import (
prediction_error_input_value_parent,
prediction_error_value_parent,
Expand Down Expand Up @@ -132,7 +132,10 @@ def binary_node_prediction(
################################################################
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
pihat_value_parent, muhat_value_parent = prediction_value_parent(
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)

Expand Down Expand Up @@ -284,7 +287,7 @@ def binary_input_prediction(
#######################################################
if value_parent_idxs is not None:
for value_parent_idx in value_parent_idxs:
pihat_value_parent, muhat_value_parent = prediction_input_value_parent(
pihat_value_parent, muhat_value_parent = predict_input_value_parent(
attributes, edges, time_step, value_parent_idx
)

Expand Down
63 changes: 40 additions & 23 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from jax import jit

from pyhgf.typing import Edges
from pyhgf.updates.prediction.continuous import (
prediction_input_value_parent,
prediction_value_parent,
prediction_volatility_parent,
)
from pyhgf.updates.prediction.continuous import predict_mean, predict_precision
from pyhgf.updates.prediction_error.continuous import (
prediction_error_input_value_parent,
prediction_error_value_parent,
prediction_error_volatility_parent,
prediction_error_input_mean_value_parent,
prediction_error_mean_value_parent,
prediction_error_mean_volatility_parent,
prediction_error_precision_value_parent,
prediction_error_precision_volatility_parent,
)


Expand Down Expand Up @@ -82,9 +80,14 @@ def continuous_node_prediction_error(
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
(pi_value_parent, mu_value_parent) = prediction_error_value_parent(
# Estimate the precision of the posterior distribution
pi_value_parent = prediction_error_precision_value_parent(
attributes, edges, value_parent_idx
)
# Estimate the mean of the posterior distribution
mu_value_parent = prediction_error_mean_value_parent(
attributes, edges, value_parent_idx, pi_value_parent
)

# Update this parent's parameters
attributes[value_parent_idx]["pi"] = pi_value_parent
Expand All @@ -98,12 +101,18 @@ def continuous_node_prediction_error(
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[volatility_parent_idx].volatility_children[-1] == node_idx:
(
pi_volatility_parent,
mu_volatility_parent,
) = prediction_error_volatility_parent(
# Estimate the new precision of the volatility parent
pi_volatility_parent = prediction_error_precision_volatility_parent(
attributes, edges, time_step, volatility_parent_idx
)
# Estimate the new mean of the volatility parent
mu_volatility_parent = prediction_error_mean_volatility_parent(
attributes,
edges,
time_step,
volatility_parent_idx,
pi_volatility_parent,
)

# Update this parent's parameters
attributes[volatility_parent_idx]["pi"] = pi_volatility_parent
Expand Down Expand Up @@ -177,7 +186,10 @@ def continuous_node_prediction(
########################
if value_parents_idxs is not None:
for value_parent_idx in value_parents_idxs:
(pihat_value_parent, muhat_value_parent) = prediction_value_parent(
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)

Expand All @@ -190,10 +202,10 @@ def continuous_node_prediction(
#############################
if volatility_parents_idxs is not None:
for volatility_parent_idx in volatility_parents_idxs:
(
pihat_volatility_parent,
muhat_volatility_parent,
) = prediction_volatility_parent(
muhat_volatility_parent = predict_mean(
attributes, edges, time_step, volatility_parent_idx
)
pihat_volatility_parent = predict_precision(
attributes, edges, time_step, volatility_parent_idx
)

Expand Down Expand Up @@ -268,12 +280,14 @@ def continuous_input_prediction_error(
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
(
pi_value_parent,
mu_value_parent,
) = prediction_error_input_value_parent(
# Estimate the new precision of the value parent
pi_value_parent = prediction_error_precision_value_parent(
attributes, edges, value_parent_idx
)
# Estimate the new mean of the value parent
mu_value_parent = prediction_error_input_mean_value_parent(
attributes, edges, value_parent_idx, pi_value_parent
)

# update input node's parameters
attributes[value_parent_idx]["pi"] = pi_value_parent
Expand Down Expand Up @@ -345,7 +359,10 @@ def continuous_input_prediction(
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
pihat_value_parent, muhat_value_parent = prediction_input_value_parent(
muhat_value_parent = predict_mean(
attributes, edges, time_step, value_parent_idx
)
pihat_value_parent = predict_precision(
attributes, edges, time_step, value_parent_idx
)

Expand Down
2 changes: 1 addition & 1 deletion src/pyhgf/updates/prediction/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@partial(jit, static_argnames=("edges", "value_parent_idx"))
def prediction_input_value_parent(
def predict_input_value_parent(
attributes: Dict,
edges: Edges,
time_step: float,
Expand Down
Loading

0 comments on commit bc4cbbc

Please sign in to comment.