Skip to content

Commit

Permalink
fixing various continuous input cases
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 19, 2024
1 parent 0da432a commit 42385c2
Show file tree
Hide file tree
Showing 10 changed files with 528 additions and 873 deletions.
7 changes: 4 additions & 3 deletions docs/source/notebooks/0.1-Theory.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Fri Aug 23 2024\n",
"Last updated: Wed Sep 18 2024\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.3\n",
Expand All @@ -743,10 +743,11 @@
"jax : 0.4.31\n",
"jaxlib: 0.4.31\n",
"\n",
"matplotlib: 3.8.4\n",
"numpy : 1.26.0\n",
"seaborn : 0.13.2\n",
"sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n",
"IPython : 8.23.0\n",
"numpy : 1.26.0\n",
"matplotlib: 3.8.4\n",
"\n",
"Watermark: 2.4.3\n",
"\n"
Expand Down
621 changes: 288 additions & 333 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

62 changes: 31 additions & 31 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

547 changes: 110 additions & 437 deletions docs/source/notebooks/1.3-Continuous_HGF.ipynb

Large diffs are not rendered by default.

129 changes: 68 additions & 61 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def input_idxs(self):
for idx in input_idxs:
if self.edges[idx].node_type == 2:
self.attributes[idx]["autoconnection_strength"] = 0.0
self.attributes[idx]["tonic_volatility"] = 0.0
return input_idxs

@input_idxs.setter
Expand Down Expand Up @@ -466,7 +467,7 @@ def add_nodes(
elif kind == "binary-state":
default_parameters = {
"observed": 1,
"mean": 0.5,
"mean": 0,
"expected_mean": 0.5,
"precision": 1.0,
"expected_precision": 1.0,
Expand Down
5 changes: 0 additions & 5 deletions src/pyhgf/updates/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,4 @@ def set_observation(
attributes[node_idx]["mean"] = value
attributes[node_idx]["observed"] = observed

# this step is central as it cancel the diffusion of uncertainty and fix
# the expected precision to the precision of the input node

attributes[node_idx]["expected_precision"] = attributes[node_idx]["precision"]

return attributes
14 changes: 13 additions & 1 deletion src/pyhgf/updates/prediction/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,19 @@ def continuous_node_prediction(
)

# Update this node's parameters
attributes[node_idx]["expected_precision"] = expected_precision

# 1. input node without volatility parent
if (
(edges[node_idx].value_children is None)
and (edges[node_idx].volatility_children is None)
and (edges[node_idx].volatility_parents is None)
):
attributes[node_idx]["expected_precision"] = attributes[node_idx]["precision"]

# 2. regular continuous state node, or input with volatility parent
else:
attributes[node_idx]["expected_precision"] = expected_precision

attributes[node_idx]["temp"]["effective_precision"] = effective_precision
attributes[node_idx]["expected_mean"] = expected_mean

Expand Down
10 changes: 10 additions & 0 deletions src/pyhgf/updates/prediction_error/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def continuous_node_value_prediction_error(
attributes[node_idx]["mean"] - attributes[node_idx]["expected_mean"]
)

# divide by the number of value parents
if attributes[node_idx]["value_coupling_parents"] is not None:
value_prediction_error /= len(attributes[node_idx]["value_coupling_parents"])

# send to the value parent node for later use in the update step
attributes[node_idx]["temp"]["value_prediction_error"] = value_prediction_error

Expand Down Expand Up @@ -101,6 +105,12 @@ def continuous_node_volatility_prediction_error(
- 1
)

# divide by the number of volatility parents
if attributes[node_idx]["volatility_coupling_parents"] is not None:
volatility_prediction_error /= len(
attributes[node_idx]["volatility_coupling_parents"]
)

attributes[node_idx]["temp"][
"volatility_prediction_error"
] = volatility_prediction_error
Expand Down
3 changes: 2 additions & 1 deletion src/pyhgf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ def to_pandas(network: "Network") -> pd.DataFrame:
[
(f"x_{i}_{var}", network.node_trajectories[i][var])
for i in states_indexes
for var in ["mean", "precision", "expected_mean", "expected_precision"]
for var in network.node_trajectories[i].keys()
if (("mean" in var) or ("precision" in var))
]
)
)
Expand Down

0 comments on commit 42385c2

Please sign in to comment.