Skip to content

Commit

Permalink
Add support for AR1 processes (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico authored Oct 13, 2023
1 parent 16978e3 commit 50a6c69
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 712 deletions.
571 changes: 195 additions & 376 deletions docs/source/notebooks/0-Theory.ipynb

Large diffs are not rendered by default.

441 changes: 112 additions & 329 deletions docs/source/notebooks/0-Theory.md

Large diffs are not rendered by default.

34 changes: 31 additions & 3 deletions src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,8 @@ def add_value_parent(
precision: Union[float, np.ndarray, ArrayLike] = 1.0,
tonic_volatility: Union[float, np.ndarray, ArrayLike] = -4.0,
tonic_drift: Union[float, np.ndarray, ArrayLike] = 0.0,
autoregressive_coefficient: float = 0.0,
autoregressive_intercept: float = 0.0,
additional_parameters: Optional[Dict] = None,
):
"""Add a value parent to a given set of nodes.
Expand All @@ -621,8 +623,9 @@ def add_value_parent(
children_idxs :
The child(s) node index(es).
value_coupling :
The value_coupling between the child and parent node. This is will be
appended to the `psis` parameters in the parent and child node(s).
The value coupling between the child and parent node. This is will be
appended to the `value_coupling_children` parameters in the parent node,
and to the `value_coupling_parents` in the child node(s).
mean :
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
Expand All @@ -634,6 +637,15 @@ def add_value_parent(
volatility parent(s)).
tonic_drift :
The drift of the random walk. Defaults to `0.0` (no drift).
autoregressive_coefficient :
The autoregressive coefficient is only used to parametrize the AR1 process
and represents the autoregressive coefficient. If
:math:`-1 \\le \\phi \\le 1`, the process is stationary and will revert to
the autoregressive intercept.
autoregressive_intercept :
The parameter `m` is only used to parametrize the AR1 process and represents
the autoregressive intercept. If :math:`-1 \\le \\phi \\le 1`, this is the
value to which the process will revert to.
additional_parameters :
Add more custom parameters to the node.
Expand Down Expand Up @@ -664,6 +676,8 @@ def add_value_parent(
"value_coupling_parents": None,
"tonic_volatility": tonic_volatility,
"tonic_drift": tonic_drift,
"autoregressive_coefficient": autoregressive_coefficient,
"autoregressive_intercept": autoregressive_intercept,
}

# add more parameters (optional)
Expand Down Expand Up @@ -717,6 +731,8 @@ def add_volatility_parent(
precision: Union[float, np.ndarray, ArrayLike] = 1.0,
tonic_volatility: Union[float, np.ndarray, ArrayLike] = -4.0,
tonic_drift: Union[float, np.ndarray, ArrayLike] = 0.0,
autoregressive_coefficient: float = 0.0,
autoregressive_intercept: float = 0.0,
additional_parameters: Optional[Dict] = None,
):
"""Add a volatility parent to a given set of nodes.
Expand All @@ -727,7 +743,8 @@ def add_volatility_parent(
The child(s) node index(es).
volatility_coupling :
The volatility coupling between the child and parent node. This is will be
appended to the `kappas` parameters in the parent and child node(s).
appended to the `volatility_coupling_children` parameters in the parent
node, and to the `volatility_coupling_parents` in the child node(s).
mean :
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
Expand All @@ -739,6 +756,15 @@ def add_volatility_parent(
volatility parent(s)).
tonic_drift :
The drift of the random walk. Defaults to `0.0` (no drift).
autoregressive_coefficient :
The autoregressive coefficient is only used to parametrize the AR1 process
and represents the autoregressive coefficient. If
:math:`-1 \\le \\phi \\le 1`, the process is stationary and will revert to
the autoregressive intercept.
autoregressive_intercept :
The parameter `m` is only used to parametrize the AR1 process and represents
the autoregressive intercept. If :math:`-1 \\le \\phi \\le 1`, this is the
value to which the process will revert to.
additional_parameters :
Add more custom parameters to the node.
Expand Down Expand Up @@ -769,6 +795,8 @@ def add_volatility_parent(
"value_coupling_parents": None,
"tonic_volatility": tonic_volatility,
"tonic_drift": tonic_drift,
"autoregressive_coefficient": autoregressive_coefficient,
"autoregressive_intercept": autoregressive_intercept,
}

# add more parameters (optional)
Expand Down
17 changes: 15 additions & 2 deletions src/pyhgf/updates/prediction/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,21 @@ def predict_mean(
):
driftrate += psi * attributes[value_parent_idx]["mean"]

# Compute the new expected mean this node
expected_mean = attributes[node_idx]["mean"] + time_step * driftrate
# New expected mean from the previous value
expected_mean = attributes[node_idx]["mean"]

# Take the drift into account
expected_mean += time_step * driftrate

# Add quatities that come from the autoregressive process if not zero
expected_mean += (
time_step
* attributes[node_idx]["autoregressive_coefficient"]
* (
attributes[node_idx]["autoregressive_intercept"]
- attributes[node_idx]["mean"]
)
)

return expected_mean

Expand Down
6 changes: 6 additions & 0 deletions tests/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def test_update_binary_input_parents(self):
"value_coupling_parents": (1.0,),
"volatility_coupling_parents": None,
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand All @@ -75,6 +77,8 @@ def test_update_binary_input_parents(self):
"value_coupling_parents": None,
"volatility_coupling_parents": (1.0,),
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand All @@ -87,6 +91,8 @@ def test_update_binary_input_parents(self):
"value_coupling_parents": None,
"volatility_coupling_parents": None,
"volatility_coupling_children": (1.0,),
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_continuous_node_update(self):
"value_coupling_parents": None,
"volatility_coupling_parents": None,
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand All @@ -49,6 +51,8 @@ def test_continuous_node_update(self):
"value_coupling_parents": None,
"volatility_coupling_parents": None,
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand Down Expand Up @@ -108,6 +112,8 @@ def test_continuous_input_update(self):
"value_coupling_parents": None,
"volatility_coupling_parents": (1.0,),
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand All @@ -120,6 +126,8 @@ def test_continuous_input_update(self):
"value_coupling_parents": None,
"volatility_coupling_parents": None,
"volatility_coupling_children": (1.0,),
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": 1.0,
"tonic_drift": 0.0,
Expand Down Expand Up @@ -190,6 +198,8 @@ def test_scan_loop(self):
"value_coupling_parents": None,
"volatility_coupling_parents": (1.0,),
"volatility_coupling_children": None,
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": -3.0,
"tonic_drift": 0.0,
Expand All @@ -202,6 +212,8 @@ def test_scan_loop(self):
"value_coupling_parents": None,
"volatility_coupling_parents": None,
"volatility_coupling_children": (1.0,),
"autoregressive_coefficient": 0.0,
"autoregressive_intercept": 0.0,
"mean": 1.0,
"tonic_volatility": -3.0,
"tonic_drift": 0.0,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_grad_logp(self):
np.array(0.0),
)

assert jnp.isclose(tonic_volatility_1, -8.440489)
assert jnp.isclose(tonic_volatility_1, -8.440622)

##############
# Binary HGF #
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_aesara_grad_logp(self):
volatility_coupling_1=1.0,
)[0].eval()

assert jnp.isclose(tonic_volatility_1, -8.440489)
assert jnp.isclose(tonic_volatility_1, -8.440622)

##############
# Binary HGF #
Expand Down

0 comments on commit 50a6c69

Please sign in to comment.