add equation for input nodes' noise parents
LegrandNico committed Nov 7, 2023
1 parent 55653fc commit 83fc5c7
Showing 5 changed files with 123 additions and 97 deletions.


@@ -206,3 +206,7 @@ slideshow:
%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
```

```{code-cell} ipython3
```
120 changes: 89 additions & 31 deletions src/pyhgf/updates/prediction_error/inputs/continuous.py
@@ -14,7 +14,7 @@
def prediction_error_input_precision_value_parent(
attributes: Dict, edges: Edges, value_parent_idx: int
) -> Array:
r"""Send prediction-error and update the precision of a continuous value parent.
r"""Update the precision of the value parent of a continuous input node.
Parameters
----------
@@ -81,7 +81,37 @@ def prediction_error_input_precision_value_parent(
def prediction_error_input_precision_volatility_parent(
attributes: Dict, edges: Edges, time_step: float, volatility_parent_idx: int
) -> Array:
"""Send prediction-error and update the precision of the volatility parent.
r"""Update the precision of the volatility parent.
The new precision of the volatility parent :math:`a` of an input node at time
:math:`k` is given by:

.. math::

    \pi_a^{(k)} = \hat{\pi}_a^{(k)} + \sum_{j=1}^{N_{children}}
    \frac{1}{2} \kappa_j^2 \left( 1 + \epsilon_j^{(k)} \right)

where :math:`\kappa_j` is the volatility coupling strength between the volatility
parent and the volatility child :math:`j`, and :math:`\epsilon_j^{(k)}` is the
noise prediction error given by:

.. math::

    \epsilon_j^{(k)} = \frac{\hat{\pi}_j^{(k)}}{\pi_{vapa}^{(k)}} +
    \hat{\pi}_j^{(k)} \left( u^{(k)} - \mu_{vapa}^{(k)} \right)^2 - 1

Note that, because we are working with continuous input nodes,
:math:`\epsilon_j^{(k)}` is not a function of the value prediction error but uses
the posterior of the value parent(s).

The expected precision of the input combines the tonic and phasic volatility and
is given by:

.. math::

    \hat{\pi}_j^{(k)} = \frac{1}{\zeta} \cdot \frac{1}{e^{\kappa_j \mu_a}}

where :math:`\zeta` is the continuous input precision (in real space).
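To make these equations concrete, here is a minimal, self-contained sketch of the
update for a single child node. Every value and variable name below is
hypothetical; only the formulas come from the docstring above.

```python
import jax.numpy as jnp

# Hypothetical values for one input node j and its volatility parent a
zeta = 1.0          # continuous input precision (real space), assumed
kappa_j = 1.0       # volatility coupling strength, assumed
mu_a_hat = 0.0      # expected mean of the volatility parent, assumed
pi_a_hat = 1.0      # expected precision of the volatility parent, assumed
u = 1.5             # observed input value, assumed
mu_vapa, pi_vapa = 1.0, 2.0  # posterior of the value parent, assumed

# Expected precision of the input: tonic (1/zeta) scaled by the phasic term
pi_j_hat = (1.0 / zeta) * (1.0 / jnp.exp(kappa_j * mu_a_hat))

# Noise prediction error (epsilon_j in the docstring)
epsilon_j = pi_j_hat / pi_vapa + pi_j_hat * (u - mu_vapa) ** 2 - 1.0

# Precision update, summing over a single child here
pi_a = pi_a_hat + 0.5 * kappa_j**2 * (1.0 + epsilon_j)
print(pi_a)  # 1.375 with the numbers above
```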
Parameters
----------
@@ -99,7 +129,7 @@ def prediction_error_input_precision_volatility_parent(
Returns
-------
precision_volatility_parent :
The new precision of the value parent.
The new precision of the volatility parent.
See Also
--------
@@ -125,22 +155,21 @@ def prediction_error_input_precision_volatility_parent(
edges[volatility_parent_idx].volatility_children, # type: ignore
attributes[volatility_parent_idx]["volatility_coupling_children"],
):
# retrieve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]

# compute the expected precision from the input node
expected_precision_child = attributes[child_idx]["expected_precision"]

# add the precision from the volatility parent if any
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
attributes[volatility_parent_idx]["expected_mean"] * volatility_coupling
)

# retrieve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]

# compute the volatility prediction error for this input node
child_volatility_prediction_error = (
expected_precision_child / attributes[this_value_parent_idx]["precision"]
(expected_precision_child / attributes[this_value_parent_idx]["precision"])
+ expected_precision_child
* (
attributes[child_idx]["value"]
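As a quick sanity check on the loop above, plugging hypothetical numbers into the
noise prediction error shows how an observation that deviates less than the
expected noise yields a negative epsilon:

```python
# Hypothetical numbers: expected input precision 1.0, value-parent posterior
# precision 2.0 and mean 1.0, observed input 1.5
expected_precision_child = 1.0
child_volatility_prediction_error = (
    expected_precision_child / 2.0
    + expected_precision_child * (1.5 - 1.0) ** 2
    - 1.0
)
print(child_volatility_prediction_error)  # -0.25: less noise than expected
```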
@@ -173,7 +202,37 @@ def prediction_error_input_mean_volatility_parent(
volatility_parent_idx: int,
precision_volatility_parent: ArrayLike,
) -> Array:
r"""Send prediction-error and update the mean of the volatility parent.
r"""Update the mean of the volatility parent of a continuous input node.
The new mean of the volatility parent :math:`a` of an input node at time
:math:`k` is given by:

.. math::

    \mu_a^{(k)} = \hat{\mu}_a^{(k)} + \frac{1}{2\pi_a}
    \sum_{j=1}^{N_{children}} \kappa_j \epsilon_j^{(k)}

where :math:`\kappa_j` is the volatility coupling strength between the volatility
parent and the volatility child :math:`j`, and :math:`\epsilon_j^{(k)}` is the
noise prediction error given by:

.. math::

    \epsilon_j^{(k)} = \frac{\hat{\pi}_j^{(k)}}{\pi_{vapa}^{(k)}} +
    \hat{\pi}_j^{(k)} \left( u^{(k)} - \mu_{vapa}^{(k)} \right)^2 - 1

Note that, because we are working with continuous input nodes,
:math:`\epsilon_j^{(k)}` is not a function of the value prediction error but uses
the posterior of the value parent(s).

The expected precision of the input combines the tonic and phasic volatility and
is given by:

.. math::

    \hat{\pi}_j^{(k)} = \frac{1}{\zeta} \cdot \frac{1}{e^{\kappa_j \mu_a}}

where :math:`\zeta` is the continuous input precision (in real space).
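A matching sketch for this mean update, reusing the hypothetical quantities from
the precision example earlier in this diff; `pi_a` is the precision of the
volatility parent updated in the previous step.

```python
kappa_j = 1.0      # volatility coupling strength, assumed
mu_a_hat = 0.0     # expected mean of the volatility parent, assumed
epsilon_j = -0.25  # noise prediction error from the precision sketch
pi_a = 1.375       # freshly updated precision of the volatility parent

# Sum the coupled noise prediction errors over children,
# then scale by 1 / (2 * pi_a)
mu_a = mu_a_hat + (kappa_j * epsilon_j) / (2 * pi_a)
print(mu_a)  # about -0.0909
```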
Parameters
----------
@@ -193,11 +252,11 @@ def prediction_error_input_mean_volatility_parent(
Returns
-------
mean_volatility_parent :
The updated value for the mean of the value parent (:math:`\\mu`).
The new mean of the volatility parent.
See Also
--------
prediction_error_volatility_volatility_parent
prediction_error_input_precision_volatility_parent
References
----------
@@ -206,31 +265,29 @@ def prediction_error_input_mean_volatility_parent(
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# Get the current expected mean for the volatility parent
# The prediction sequence was triggered by the new observation so this value is
# already in the node attributes
# Get the current expected mean for the volatility parent - this assumes that the
# value was computed in the prediction step before
expected_mean_volatility_parent = attributes[volatility_parent_idx]["expected_mean"]

# Gather volatility prediction errors from the child nodes
children_volatility_prediction_error = 0.0
children_noise_prediction_error = 0.0
for child_idx, volatility_coupling in zip(
edges[volatility_parent_idx].volatility_children, # type: ignore
attributes[volatility_parent_idx]["volatility_coupling_children"],
):
# retrieve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]

# compute the expected precision from the input node
# compute the total volatility at the level of the child node
expected_precision_child = attributes[child_idx]["expected_precision"]

# add the precision from the volatility parent if anyu
# add the precision from the volatility parent if any
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
attributes[volatility_parent_idx]["expected_mean"] * volatility_coupling
)

# compute the volatility prediction error for this input node
# retrieve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]

# compute the noise prediction error for this input node (ε)
child_volatility_prediction_error = (
expected_precision_child / attributes[this_value_parent_idx]["precision"]
+ expected_precision_child
@@ -243,15 +300,16 @@ def prediction_error_input_mean_volatility_parent(
)

# sum over all input nodes
children_volatility_prediction_error += (
0.5
* child_volatility_prediction_error
* (volatility_coupling / attributes[volatility_parent_idx]["precision"])
children_noise_prediction_error += (
child_volatility_prediction_error * volatility_coupling
)

# scale using the updated precision of the volatility parent
children_noise_prediction_error *= 1 / (2 * precision_volatility_parent)

# Estimate the new mean of the volatility parent
mean_volatility_parent = (
expected_mean_volatility_parent + children_volatility_prediction_error
expected_mean_volatility_parent + children_noise_prediction_error
)

return mean_volatility_parent
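One remark on this refactor: because the 1/(2 * pi_a) factor is constant across
children, accumulating kappa_j * epsilon_j in the loop and scaling once at the end
is algebraically identical to the old per-child scaling, provided both forms read
the same parent precision (the new code uses the freshly computed
`precision_volatility_parent` argument instead of the value stored in
`attributes`). A schematic check with hypothetical numbers:

```python
kappas = [1.0, 0.5]       # hypothetical volatility couplings
epsilons = [-0.25, 0.1]   # hypothetical noise prediction errors
pi_a = 1.375              # hypothetical precision of the volatility parent

# old form: scale each child's contribution inside the loop
old_style = sum(0.5 * e * (k / pi_a) for k, e in zip(kappas, epsilons))
# new form: accumulate kappa * epsilon, then scale once
new_style = sum(k * e for k, e in zip(kappas, epsilons)) * (1 / (2 * pi_a))
assert abs(old_style - new_style) < 1e-9
```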
2 changes: 1 addition & 1 deletion src/pyhgf/updates/prediction_error/nodes/continuous.py
@@ -265,7 +265,7 @@ def prediction_error_mean_volatility_parent(
.. math::
\mu_a^{(k)} = \hat{\mu}_a^{(k)} + \frac{1}{2\pi_a} \\
\sum_{j=1}^{N_{children}} \kappa_j^2 \gamma_j^{(k)} \Delta_j^{(k)}
\sum_{j=1}^{N_{children}} \kappa_j \gamma_j^{(k)} \Delta_j^{(k)}
where :math:`\kappa_j` is the volatility coupling strength between the volatility
parent and the volatility children :math:`j` and :math:`\Delta_j^{(k)}` is the
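The one-character fix above turns the documented coupling term from second order
(kappa_j squared) to first order (kappa_j) in the mean update of a state node's
volatility parent. A minimal numeric sketch of the corrected formula; every value
is hypothetical, and gamma_j is defined in the part of the docstring not shown in
this hunk, so a placeholder value is used here:

```python
kappa_j = 0.5    # volatility coupling strength, assumed
gamma_j = 0.8    # gamma_j (definition truncated above), placeholder value
delta_j = 0.2    # volatility prediction error, assumed
mu_a_hat, pi_a = 1.0, 2.0  # expected mean and precision of the parent, assumed

# mu_a = mu_a_hat + (1 / (2 * pi_a)) * sum_j kappa_j * gamma_j * delta_j
mu_a = mu_a_hat + (kappa_j * gamma_j * delta_j) / (2 * pi_a)
print(mu_a)  # 1.02
```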
14 changes: 9 additions & 5 deletions tests/test_binary.py
@@ -68,6 +68,7 @@ def test_update_binary_input_parents(self):
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
"temp": {"predicted_volatility": 0.0},
}
node_parameters_2 = {
"expected_precision": 1.0,
@@ -82,6 +83,7 @@ def test_update_binary_input_parents(self):
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
"temp": {"predicted_volatility": 0.0},
}
node_parameters_3 = {
"expected_precision": 1.0,
@@ -96,6 +98,7 @@ def test_update_binary_input_parents(self):
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
"temp": {"predicted_volatility": 0.0},
}
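The repeated addition of a "temp" entry in these fixtures suggests the update
functions now expect a scratch buffer on each state node, written during the
prediction step. A hedged sketch of the assumed parameter layout: the field names
are taken from this diff, the values are only illustrative.

```python
# Assumed minimal state-node parameters for these tests after this commit;
# field names come from the diff, values are illustrative
node_parameters = {
    "expected_precision": 1.0,
    "mean": 1.0,
    "tonic_volatility": 1.0,
    "tonic_drift": 0.0,
    # new in this commit: scratch space filled in by the prediction step
    "temp": {"predicted_volatility": 0.0},
}
```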

edges = (
@@ -148,7 +151,7 @@ def test_update_binary_input_parents(self):
assert jnp.isclose(new_attributes[2][idx], val)
for idx, val in zip(
["mean", "expected_mean", "precision", "expected_precision"],
[0.5611493, 1.0, 0.5380009, 0.26894143],
[0.5050575, 1.0, 0.47702926, 0.26894143],
):
assert jnp.isclose(new_attributes[3][idx], val)

@@ -168,19 +171,20 @@ def test_update_binary_input_parents(self):

# Run the entire for loop
last, _ = scan(scan_fn, attributes, data)
for idx, val in zip(["surprise", "value"], [3.1497865, 0.0]):
for idx, val in zip(["surprise", "value"], [3.1274157, 0.0]):
assert jnp.isclose(last[0][idx], val)
for idx, val in zip(
["mean", "expected_mean", "expected_precision"], [0.0, 0.9571387, 24.37586]
["mean", "expected_mean", "expected_precision"],
[0.0, 0.95616907, 23.860779],
):
assert jnp.isclose(last[1][idx], val)
for idx, val in zip(
["mean", "expected_mean", "precision", "expected_precision"],
[-2.358439, 3.1059794, 0.17515838, 0.13413419],
[-2.1582031, 3.0825963, 0.18244718, 0.1405374],
):
assert jnp.isclose(last[2][idx], val)
for idx, val in zip(
["expected_mean", "expected_precision"], [-0.22977911, 0.14781797]
["expected_mean", "expected_precision"], [-0.30260748, 0.14332297]
):
assert jnp.isclose(last[3][idx], val)

