Skip to content

Commit

Permalink
add support for volatility coupling with input nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 30, 2023
1 parent 0aa0e6a commit 1cbc780
Show file tree
Hide file tree
Showing 12 changed files with 1,735 additions and 156 deletions.
179 changes: 111 additions & 68 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

1,204 changes: 1,204 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

(example_1)=
# Example 2: Volatility coupling with an input node

```{code-cell} ipython3
%%capture
import sys
if 'google.colab' in sys.modules:
! pip install pyhgf
```

```{code-cell} ipython3
from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
```

Where the standard continuous HGF assumes a known precision in the input node (usually set to something high), this assumption can be relaxed and the filter can also try to estimate this quantity from the data.

```{code-cell} ipython3
input_data = np.random.normal(size=1000)
```

```{code-cell} ipython3
jget_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous")
.add_value_parent(children_idxs=[0])
.add_volatility_parent(children_idxs=[0])
.add_volatility_parent(children_idxs=[1])
.init()
)
jget_hgf.plot_network()
```

```{code-cell} ipython3
jget_hgf.attributes
```

```{code-cell} ipython3
jget_hgf.input_data(input_data[:30])
```

```{code-cell} ipython3
jget_hgf.plot_trajectories()
```

```{code-cell} ipython3
jget_hgf.to_pandas()
```

```{code-cell} ipython3
```
1 change: 1 addition & 0 deletions docs/source/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ glob:
| Notebook | Colab |
| --- | ---|
| {ref}`example_1` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_1_Heart_rate_variability.ipynb)
| {ref}`example_2` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb)

## Exercises

Expand Down
2 changes: 1 addition & 1 deletion src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def add_volatility_parent(
self,
children_idxs: Union[List, int],
volatility_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mean: Union[float, np.ndarray, ArrayLike] = 0.0,
mean: Union[float, np.ndarray, ArrayLike] = 1.0,
precision: Union[float, np.ndarray, ArrayLike] = 1.0,
tonic_volatility: Union[float, np.ndarray, ArrayLike] = -4.0,
tonic_drift: Union[float, np.ndarray, ArrayLike] = 0.0,
Expand Down
2 changes: 1 addition & 1 deletion src/pyhgf/updates/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pyhgf.typing import Edges
from pyhgf.updates.prediction.binary import predict_binary_state_node
from pyhgf.updates.prediction_error.binary import (
from pyhgf.updates.prediction_error.nodes.binary import (
prediction_error_input_value_parent,
prediction_error_value_parent,
)
Expand Down
21 changes: 13 additions & 8 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

from pyhgf.typing import Edges
from pyhgf.updates.prediction.continuous import predict_mean, predict_precision
from pyhgf.updates.prediction_error.continuous import (
from pyhgf.updates.prediction_error.inputs.continuous import (
prediction_error_input_mean_value_parent,
prediction_error_input_mean_volatility_parent,
prediction_error_input_precision_value_parent,
prediction_error_input_precision_volatility_parent,
)
from pyhgf.updates.prediction_error.nodes.continuous import (
prediction_error_mean_value_parent,
prediction_error_mean_volatility_parent,
prediction_error_precision_value_parent,
Expand Down Expand Up @@ -361,17 +366,17 @@ def continuous_input_prediction_error(
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
# Estimate the new precision of the value parent
pi_value_parent = prediction_error_precision_value_parent(
precision_value_parent = prediction_error_input_precision_value_parent(
attributes, edges, value_parent_idx
)
# Estimate the new mean of the value parent
mu_value_parent = prediction_error_input_mean_value_parent(
attributes, edges, value_parent_idx, pi_value_parent
mean_value_parent = prediction_error_input_mean_value_parent(
attributes, edges, value_parent_idx, precision_value_parent
)

# update input node's parameters
attributes[value_parent_idx]["precision"] = pi_value_parent
attributes[value_parent_idx]["mean"] = mu_value_parent
attributes[value_parent_idx]["precision"] = precision_value_parent
attributes[value_parent_idx]["mean"] = mean_value_parent

#############################
# Update volatility parents #
Expand All @@ -388,7 +393,7 @@ def continuous_input_prediction_error(
]

# Estimate the new mean of the volatility parent
mean_volatility_parent = prediction_error_mean_volatility_parent(
mean_volatility_parent = prediction_error_input_mean_volatility_parent(
attributes,
edges,
time_step,
Expand All @@ -399,7 +404,7 @@ def continuous_input_prediction_error(

# Estimate the new precision of the volatility parent
precision_volatility_parent = (
prediction_error_precision_volatility_parent(
prediction_error_input_precision_volatility_parent(
attributes, edges, time_step, volatility_parent_idx
)
)
Expand Down
Empty file.
Loading

0 comments on commit 1cbc780

Please sign in to comment.