Skip to content

Commit

Permalink
Add eHGF update step (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico authored Oct 2, 2023
1 parent 48cb2c9 commit d333de8
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 9 deletions.
8 changes: 8 additions & 0 deletions src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class HGF(object):
network only has one input node.
model_type :
The model implemented (can be `"continuous"`, `"binary"` or `"custom"`).
update_type :
The type of update to perform for volatility coupling. Can be `"eHGF"`
(defaults) or `"standard"`. The eHGF update step was proposed as an alternative
to the original definition in that it starts by updating the mean and then the
precision of the parent node, which generally reduces the errors associated with
impossible parameter space and improves sampling.
n_levels :
The number of hierarchies in the model, including the input vector. Cannot be
less than 2.
Expand All @@ -59,6 +65,7 @@ def __init__(
self,
n_levels: Optional[int] = 2,
model_type: str = "continuous",
update_type: str = "eHGF",
initial_mean: Dict = {
"1": 0.0,
"2": 0.0,
Expand Down Expand Up @@ -135,6 +142,7 @@ def __init__(
"""
self.model_type = model_type
self.update_type = update_type
self.verbose = verbose
self.n_levels = n_levels
self.edges: Edges
Expand Down
7 changes: 6 additions & 1 deletion src/pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
continuous_input_prediction_error,
continuous_node_prediction,
continuous_node_prediction_error,
ehgf_continuous_node_prediction_error,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -333,7 +334,11 @@ def get_update_sequence(
# --------------------------

# case 1 - default to a continuous node
update_fn = continuous_node_prediction_error
# choose between the eHGF and standard update step
if hgf.update_type == "eHGF":
update_fn = ehgf_continuous_node_prediction_error
elif hgf.update_type == "standard":
update_fn = continuous_node_prediction_error
prediction_fn = continuous_node_prediction

# case 2 - this is an input node
Expand Down
125 changes: 125 additions & 0 deletions src/pyhgf/updates/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,131 @@ def continuous_node_prediction_error(
return attributes


@partial(jit, static_argnames=("edges", "node_idx"))
def ehgf_continuous_node_prediction_error(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
) -> Dict:
"""eHGF prediction-error step for value and volatility parents of a continuous node.
This update step uses a different order for the mean and precision as compared to
the standard HGF, respectively:
1. Update volatility parent(s).
2. Update value parent(s).
If a value/volatility parent has multiple children, all the children will update
the parent together, therefor this function should only be called once per group
of child nodes. The method :py:meth:`pyhgf.model.HGF.get_update_sequence`
ensures that this function is only called once all the children have been
updated.
Then returns the structure of the new parameters.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
time_step :
The interval between the previous time point and the current time point.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
Returns
-------
attributes :
The updated attributes of the probabilistic nodes.
See Also
--------
update_continuous_input_parents, update_binary_input_parents
References
----------
.. [1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., &
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# list value and volatility parents
value_parents_idxs = edges[node_idx].value_parents
volatility_parents_idxs = edges[node_idx].volatility_parents

# return here if no parents node are provided
if (value_parents_idxs is None) and (volatility_parents_idxs is None):
return attributes

########################
# Update value parents #
########################
if value_parents_idxs is not None:
for value_parent_idx in value_parents_idxs:
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[value_parent_idx].value_children[-1] == node_idx:
# in the eHGF update step, we use the expected precision here
# as we haven't computed it yet due to the reverse update order
precision_value_parent = attributes[value_parent_idx][
"expected_precision"
]

# Estimate the mean of the posterior distribution
mean_value_parent = prediction_error_mean_value_parent(
attributes, edges, value_parent_idx, precision_value_parent
)
# Update this parent's parameters
attributes[value_parent_idx]["mean"] = mean_value_parent

# Estimate the precision of the posterior distribution
precision_value_parent = prediction_error_precision_value_parent(
attributes, edges, value_parent_idx
)

# Update this parent's parameters
attributes[value_parent_idx]["precision"] = precision_value_parent

#############################
# Update volatility parents #
#############################
if volatility_parents_idxs is not None:
for volatility_parent_idx in volatility_parents_idxs:
# if this child is the last one relative to this parent's family, all the
# children will update the parent at once, otherwise just pass and wait
if edges[volatility_parent_idx].volatility_children[-1] == node_idx:
# in the eHGF update step, we use the expected precision here
# as we haven't computed it yet due to the reverse update order
precision_volatility_parent = attributes[volatility_parent_idx][
"expected_precision"
]

# Estimate the new mean of the volatility parent
mean_volatility_parent = prediction_error_mean_volatility_parent(
attributes,
edges,
time_step,
volatility_parent_idx,
precision_volatility_parent,
)
attributes[volatility_parent_idx]["mean"] = mean_volatility_parent

# Estimate the new precision of the volatility parent
precision_volatility_parent = (
prediction_error_precision_volatility_parent(
attributes, edges, time_step, volatility_parent_idx
)
)

# Update this parent's parameters
attributes[volatility_parent_idx][
"precision"
] = precision_volatility_parent

return attributes


@partial(jit, static_argnames=("edges", "node_idx"))
def continuous_node_prediction(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
Expand Down
10 changes: 5 additions & 5 deletions tests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_hgf_logp(self):
volatility_coupling_1=1.0,
volatility_coupling_2=jnp.nan,
)
assert jnp.isclose(logp, 1194.0072)
assert jnp.isclose(logp, 1130.5503)

##############
# Binary HGF #
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_grad_logp(self):
np.array(0.0),
)

assert jnp.isclose(tonic_volatility_1, -7.9176354)
assert jnp.isclose(tonic_volatility_1, -8.440489)

##############
# Binary HGF #
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_aesara_logp(self):
volatility_coupling_2=np.array(0.0),
).eval()

assert jnp.isclose(logp, 1194.00720215)
assert jnp.isclose(logp, 1130.55029297)

##############
# Binary HGF #
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_aesara_grad_logp(self):
volatility_coupling_1=1.0,
)[0].eval()

assert jnp.isclose(tonic_volatility_1, -7.9176354)
assert jnp.isclose(tonic_volatility_1, -8.440489)

##############
# Binary HGF #
Expand Down Expand Up @@ -382,7 +382,7 @@ def test_pymc_sampling(self):

pointslogs = model.point_logps(initial_point)
assert pointslogs["tonic_volatility_2"] == -1.39
assert pointslogs["hhgf_loglike"] == 1491.58
assert pointslogs["hhgf_loglike"] == 1442.85

with model:
idata = pm.sample(chains=2, cores=1, tune=1000)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_HGF(self):
surprise = (
two_level_continuous_hgf.surprise()
) # Sum the surprise for this model
assert jnp.isclose(surprise, -1194.0071)
assert jnp.isclose(surprise, -1130.5503)
assert len(two_level_continuous_hgf.node_trajectories[1]["mean"]) == 614

# three-level
Expand All @@ -73,11 +73,11 @@ def test_HGF(self):
)
three_level_continuous_hgf.input_data(input_data=timeserie)
surprise = three_level_continuous_hgf.surprise()
assert jnp.isclose(surprise, -976.2536)
assert jnp.isclose(surprise, -870.08887)

# test an alternative response function
sp = total_gaussian_surprise(three_level_continuous_hgf)
assert jnp.isclose(sp, 1065.8903)
assert jnp.isclose(sp, 1191.6648)

##########
# Binary #
Expand Down

0 comments on commit d333de8

Please sign in to comment.