Skip to content

Commit

Permalink
volatility coupling for input nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 31, 2023
1 parent 1cbc780 commit a27d602
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 49 deletions.
1 change: 1 addition & 0 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def continuous_input_prediction_error(
precision_value_parent = prediction_error_input_precision_value_parent(
attributes, edges, value_parent_idx
)

# Estimate the new mean of the value parent
mean_value_parent = prediction_error_input_mean_value_parent(
attributes, edges, value_parent_idx, precision_value_parent
Expand Down
120 changes: 71 additions & 49 deletions src/pyhgf/updates/prediction_error/inputs/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def prediction_error_input_precision_value_parent(
Returns
-------
pi_value_parent :
The updated value for the precision of the value parent (:math:`\\pi`).
precision_value_parent :
The updated value for the precision of the value parent.
See Also
--------
Expand All @@ -50,23 +50,30 @@ def prediction_error_input_precision_value_parent(

# Gather precision updates from all child nodes if the parent has many children.
# This part corresponds to the sum over children for the multi-children situations.
children_precisions = 0.0
expected_precision_children = 0.0
for child_idx, value_coupling in zip(
edges[value_parent_idx].value_children, # type: ignore
attributes[value_parent_idx]["value_coupling_children"],
):
if edges[child_idx].volatility_parents is None:
expected_precision_child = attributes[child_idx]["expected_precision"]
else:
expected_precision_child = attributes[
edges[child_idx].volatility_parents[0]
]["expected_mean"]
# the expected precision in the input node
expected_precision_child = attributes[child_idx]["expected_precision"]

# add the precision from volatility parents if any
if edges[child_idx].volatility_parents is not None:
volatility_coupling = attributes[edges[child_idx].volatility_parents[0]][
"volatility_coupling_children"
][0]
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
)

children_precisions += value_coupling**2 * expected_precision_child
expected_precision_children += expected_precision_child * value_coupling**2

# Estimate new value for the precision of the value parent
precision_value_parent = expected_precision_value_parent + children_precisions

precision_value_parent = (
expected_precision_value_parent + expected_precision_children
)
return precision_value_parent


Expand Down Expand Up @@ -113,42 +120,43 @@ def prediction_error_input_precision_volatility_parent(
]

# gather volatility precisions from the child nodes
children_volatility_precision = 0.0
for child_idx, kappas_children in zip(
precision_volatility_children = 0.0
for child_idx, volatility_coupling in zip(
edges[volatility_parent_idx].volatility_children, # type: ignore
attributes[volatility_parent_idx]["volatility_coupling_children"],
):
# retireve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]
this_volatility_parent_idx = edges[child_idx].volatility_parents[0]

# compute the expected precision from the input node
expected_precision_child = attributes[child_idx]["expected_precision"]

# add the precision from the volatility parent if any
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
)

# compute the volatility prediction error for this input node
child_volatility_prediction_error = (
1 / attributes[this_volatility_parent_idx]["mean"]
+ (
expected_precision_child / attributes[this_value_parent_idx]["precision"]
+ expected_precision_child
* (
attributes[child_idx]["value"]
- attributes[this_value_parent_idx]["expected_mean"]
- attributes[this_value_parent_idx]["mean"]
)
** 2
) * attributes[this_volatility_parent_idx]["expected_mean"] - 1
- 1
)

children_volatility_precision += (
0.5
* (
kappas_children
* attributes[this_volatility_parent_idx]["expected_mean"]
)
** 2
* (
1
+ (1 - 1 / (attributes[this_volatility_parent_idx]["mean"]))
* child_volatility_prediction_error
)
precision_volatility_children += (
0.5 * volatility_coupling**2 * (1 + child_volatility_prediction_error)
)

# Estimate the new precision of the volatility parent
precision_volatility_parent = (
expected_precision_volatility_parent + children_volatility_precision
expected_precision_volatility_parent + precision_volatility_children
)
precision_volatility_parent = jnp.where(
precision_volatility_parent <= 0, jnp.nan, precision_volatility_parent
Expand Down Expand Up @@ -205,31 +213,40 @@ def prediction_error_input_mean_volatility_parent(

# Gather volatility prediction errors from the child nodes
children_volatility_prediction_error = 0.0
for child_idx, kappas_children in zip(
for child_idx, volatility_coupling in zip(
edges[volatility_parent_idx].volatility_children, # type: ignore
attributes[volatility_parent_idx]["volatility_coupling_children"],
):
# retireve the index of the value parent (assuming a unique value parent)
# we need this to compute the value PE, required for the volatility PE
this_value_parent_idx = edges[child_idx].value_parents[0]

# compute the volatility PE for this input node
# compute the expected precision from the input node
expected_precision_child = attributes[child_idx]["expected_precision"]

# add the precision from the volatility parent if anyu
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
)

# compute the volatility prediction error for this input node
child_volatility_prediction_error = (
1 / attributes[volatility_parent_idx]["mean"]
+ (
expected_precision_child / attributes[this_value_parent_idx]["precision"]
+ expected_precision_child
* (
attributes[child_idx]["value"]
- attributes[this_value_parent_idx]["expected_mean"]
- attributes[this_value_parent_idx]["mean"]
)
** 2
) * attributes[volatility_parent_idx]["expected_mean"] - 1
- 1
)

# sum over all input nodes
children_volatility_prediction_error += (
0.5
* kappas_children
* attributes[volatility_parent_idx]["expected_mean"]
/ precision_volatility_parent
* child_volatility_prediction_error
* (volatility_coupling / attributes[volatility_parent_idx]["precision"])
)

# Estimate the new mean of the volatility parent
Expand Down Expand Up @@ -300,16 +317,21 @@ def prediction_error_input_mean_value_parent(
- attributes[this_value_parent_idx]["expected_mean"]
)

if edges[child_idx].volatility_parents is None:
expected_precision_child = attributes[child_idx]["expected_precision"]
else:
expected_precision_child = attributes[
edges[child_idx].volatility_parents[0]
]["expected_mean"]
expected_precision_child = attributes[child_idx]["expected_precision"]
if edges[child_idx].volatility_parents is not None:
volatility_coupling = attributes[edges[child_idx].volatility_parents[0]][
"volatility_coupling_children"
][0]
expected_precision_child *= 1 / jnp.exp(
attributes[edges[child_idx].volatility_parents[0]]["expected_mean"]
* volatility_coupling
)

children_prediction_errors += (
value_coupling * expected_precision_child * child_value_prediction_error
) / precision_value_parent
(expected_precision_child / precision_value_parent)
* child_value_prediction_error
* value_coupling
)

# Compute the new mean of the value parent
mean_value_parent = (
Expand Down

0 comments on commit a27d602

Please sign in to comment.