Commit

Add support for nonlinear value coupling (#215)
- add support for non-linear value function
- add a tutorial on how to use non-linear value functions
KoraTMontemagno authored Aug 22, 2024
1 parent f943078 commit f72ca61
Showing 15 changed files with 1,415 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
branch: gh-pages

- name: Deploy Dev 🚀
if: github.event_name == 'pull_request'
if: (github.event_name == 'pull_request') || (github.event_name == 'push')
uses: JamesIves/github-pages-deploy-action@v4
with:
folder: docs/build/html
Expand Down
Binary file added docs/source/images/non_linear_coupling.png
14 changes: 11 additions & 3 deletions docs/source/learn.md
Expand Up @@ -121,27 +121,35 @@ Advanced customisation of predictive coding neural networks and Bayesian modelli
::::{grid} 1 1 2 3
:gutter: 1

:::{grid-item-card} Using custom response functions
:::{grid-item-card} Using custom response functions
:link: custom_response_functions
:link-type: ref
:img-top: ./images/response_models.png

How to adapt any model to specific behaviours and experimental design by using custom response functions.
:::

:::{grid-item-card} Embedding the Hierarchical Gaussian Filter in a Bayesian network for multilevel inference
:::{grid-item-card} Embedding the Hierarchical Gaussian Filter in a Bayesian network for multilevel inference
:link: multilevel_hgf
:link-type: ref
:img-top: ./images/multilevel-hgf.png

How to use any model as a distribution to perform hierarchical inference at the group level.
:::

:::{grid-item-card} Parameter recovery, prior and posterior predictive sampling
:::{grid-item-card} Parameter recovery, prior and posterior predictive sampling
:link: parameters_recovery
:link-type: ref
:img-top: ./images/parameter_recovery.png

Recovering parameters from the generative model and using the sampling functionalities to estimate prior and posterior uncertainties.
:::

:::{grid-item-card} Non-linear value coupling
:link: non_linear_coupling
:link-type: ref
:img-top: ./images/non_linear_coupling.png

How to use non-linear coupling functions between a value parent and its value children.
:::
::::
Expand Down
90 changes: 53 additions & 37 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

1,114 changes: 1,114 additions & 0 deletions docs/source/notebooks/5-Non_linear_value_coupling.ipynb

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions docs/source/refs.bib
Expand Up @@ -213,12 +213,14 @@ @InProceedings{mathys:2020
}

@article{Vehtari:2015,
doi = {10.48550/ARXIV.1507.04544},
url = {https://arxiv.org/abs/1507.04544},
author = {Vehtari, Aki and Gelman, Andrew and Gabry, Jonah},
keywords = {Computation (stat.CO), Methodology (stat.ME), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC},
publisher = {arXiv},
year = {2015},
copyright = {arXiv.org perpetual, non-exclusive license}
}
title={Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC},
volume={27},
ISSN={1573-1375},
url={http://dx.doi.org/10.1007/s11222-016-9696-4},
DOI={10.1007/s11222-016-9696-4},
number={5},
journal={Statistics and Computing},
publisher={Springer Science and Business Media LLC},
author={Vehtari, Aki and Gelman, Andrew and Gabry, Jonah},
year={2016},
month=aug, pages={1413–1432} }
45 changes: 40 additions & 5 deletions src/pyhgf/model/network.py
Expand Up @@ -334,6 +334,7 @@ def add_nodes(
value_parents: Optional[Union[List, Tuple, int]] = None,
volatility_children: Optional[Union[List, Tuple, int]] = None,
volatility_parents: Optional[Union[List, Tuple, int]] = None,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
**additional_parameters,
):
"""Add new input/state node(s) to the neural network.
Expand Down Expand Up @@ -397,6 +398,14 @@ def add_nodes(
integer or a list of integers, in case of multiple children. The coupling
strength can be controlled by passing a tuple, where the first item is the
list of indexes, and the second item is the list of coupling strengths.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order as the children.
Note: if a node has multiple parent nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
to be linear.
**kwargs :
Additional keyword parameters will be passed and overwrite the node
attributes.
Expand All @@ -420,6 +429,16 @@ def add_nodes(
)
)

# assess children number
# this is required to ensure the coupling functions match
children_number = 1
if value_children is None:
children_number = 0
elif isinstance(value_children, int):
children_number = 1
elif isinstance(value_children, list):
children_number = len(value_children)

# transform coupling parameter into tuple of indexes and strengths
couplings = []
for indexes in [
Expand Down Expand Up @@ -638,14 +657,19 @@ def add_nodes(

node_idx = len(self.attributes) # the index of the new node

# for multiple value children, set a default tuple with corresponding length
if children_number != len(coupling_fn):
if coupling_fn == (None,):
coupling_fn = children_number * coupling_fn
else:
raise ValueError(
"The number of coupling fn and value children do not match"
)

# add a new edge
edges_as_list.append(
AdjacencyLists(
node_type,
None,
None,
None,
None,
node_type, None, None, None, None, coupling_fn=coupling_fn
)
)

Expand Down Expand Up @@ -684,6 +708,7 @@ def add_nodes(
parent_idxs=node_idx,
children_idxs=value_children[0],
coupling_strengths=value_children[1], # type: ignore
coupling_fn=coupling_fn,
)
if volatility_children[0] is not None:
self.add_edges(
Expand Down Expand Up @@ -788,6 +813,7 @@ def add_edges(
parent_idxs=Union[int, List[int]],
children_idxs=Union[int, List[int]],
coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> "Network":
"""Add a value or volatility coupling link between a set of nodes.
Expand All @@ -801,6 +827,14 @@ def add_edges(
The index(es) of the children node(s).
coupling_strengths :
The coupling strength between the parents and children.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order as the children.
Note: if a node has multiple parent nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
to be linear.
"""
attributes, edges = add_edges(
Expand All @@ -810,6 +844,7 @@ def add_edges(
parent_idxs=parent_idxs,
children_idxs=children_idxs,
coupling_strengths=coupling_strengths,
coupling_fn=coupling_fn,
)

self.attributes = attributes
Expand Down
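The length check that `add_nodes` applies to `coupling_fn` can be sketched in isolation as follows. This is only an illustration under stated assumptions: `resolve_coupling_fn` is a hypothetical helper name that does not exist in pyhgf, and it only mirrors the branches visible in the hunks above (in particular, any `value_children` input that is neither `None`, an `int`, nor a `list` is counted as a single child, as in the diff).

```python
from typing import Callable, List, Optional, Tuple, Union


def resolve_coupling_fn(
    value_children: Optional[Union[List[int], int]],
    coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> Tuple[Optional[Callable], ...]:
    """Hypothetical helper: return one coupling function per value child."""
    # count the value children, following the branches shown in the diff
    if value_children is None:
        children_number = 0
    elif isinstance(value_children, int):
        children_number = 1
    elif isinstance(value_children, list):
        children_number = len(value_children)
    else:
        children_number = 1  # fallback used by the diff for other input types

    if children_number != len(coupling_fn):
        if coupling_fn == (None,):
            # default case: linear coupling (None) repeated for every child
            return children_number * coupling_fn
        raise ValueError(
            "The number of coupling fn and value children do not match"
        )
    return coupling_fn


# three children and no explicit function: three linear couplings
print(resolve_coupling_fn([1, 2, 3]))
```

Under this sketch, a node without value children ends up with an empty tuple, and passing one function for two children raises a `ValueError`.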
5 changes: 5 additions & 0 deletions src/pyhgf/plots.py
Expand Up @@ -310,9 +310,14 @@ def plot_network(network: "Network") -> "Source":

if value_parents is not None:
for value_parents_idx in value_parents:

# get the coupling function from the value parent
child_idx = network.edges[value_parents_idx].value_children.index(i)
coupling_fn = network.edges[value_parents_idx].coupling_fn[child_idx]
graphviz_structure.edge(
f"x_{value_parents_idx}",
f"x_{i}",
color="black" if coupling_fn is None else "black:invis:black",
)

# connect volatility parents
Expand Down
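The lookup used in `plot_network` relies on the parent's `coupling_fn` tuple being aligned with its `value_children` tuple. A minimal sketch of that lookup, assuming a reduced stand-in for `AdjacencyLists` (the `Node` class and the `edge_color` function below are hypothetical names introduced only for illustration):

```python
from typing import Callable, NamedTuple, Optional, Tuple


class Node(NamedTuple):
    """Hypothetical, reduced stand-in for AdjacencyLists."""

    value_children: Optional[Tuple[int, ...]]
    coupling_fn: Tuple[Optional[Callable], ...]


def edge_color(edges: Tuple[Node, ...], parent_idx: int, child_idx: int) -> str:
    """Return the Graphviz color string for the edge parent -> child."""
    # position of the child in the parent's list of value children
    position = edges[parent_idx].value_children.index(child_idx)
    # the coupling function stored at the same position
    fn = edges[parent_idx].coupling_fn[position]
    # None means linear coupling (single stroke); otherwise use a color list
    return "black" if fn is None else "black:invis:black"


# node 2 is the value parent of nodes 0 (linear) and 1 (non-linear, abs as a placeholder)
edges = (Node(None, ()), Node(None, ()), Node((0, 1), (None, abs)))
print(edge_color(edges, 2, 0), edge_color(edges, 2, 1))
```

The color list `black:invis:black` should make Graphviz draw the edge as two parallel strokes, which is how non-linear couplings are told apart from linear ones in the plot.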
4 changes: 4 additions & 0 deletions src/pyhgf/typing.py
Expand Up @@ -14,13 +14,17 @@ class AdjacencyLists(NamedTuple):
mean and unknown variance.
* 4: Dirichlet Process state node.
The variable `coupling_fn` lists the coupling functions between this node and its
children nodes. If `None` is provided, a linear coupling is assumed.
"""

node_type: int
value_parents: Optional[Tuple]
volatility_parents: Optional[Tuple]
value_children: Optional[Tuple]
volatility_children: Optional[Tuple]
coupling_fn: Tuple[Optional[Callable], ...]


class Inputs(NamedTuple):
Expand Down
73 changes: 64 additions & 9 deletions src/pyhgf/updates/posterior/continuous.py
Expand Up @@ -4,14 +4,17 @@
from typing import Dict

import jax.numpy as jnp
from jax import jit
from jax import grad, jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_mean_continuous_node(
attributes: Dict, edges: Edges, node_idx: int, node_precision: float
attributes: Dict,
edges: Edges,
node_idx: int,
node_precision: float,
) -> float:
r"""Update the mean of a state node using the value prediction errors.
Expand All @@ -20,6 +23,8 @@ def posterior_update_mean_continuous_node(
The new mean of a state node :math:`b` value coupled with other input and/or state
nodes :math:`j` at time :math:`k` is given by:
For linear value coupling:
.. math::
\mu_b^{(k)} = \hat{\mu}_b^{(k)} + \sum_{j=1}^{N_{children}}
\frac{\kappa_j \hat{\pi}_j^{(k)}}{\pi_b} \delta_j^{(k)}
Expand All @@ -32,6 +37,14 @@ def posterior_update_mean_continuous_node(
If the child node is a state node, this value was computed by
:py:func:`pyhgf.updates.prediction_errors.nodes.continuous.continuous_node_value_prediction_error`.
For non-linear value coupling:
.. math::
\mu_b^{(k)} = \hat{\mu}_b^{(k)} + \sum_{j=1}^{N_{children}}
\frac{\kappa_j g'_{j,b}({\mu}_b^{(k-1)}) \hat{\pi}_j^{(k)}}{\pi_b}
\delta_j^{(k)}
2. Mean update from volatility coupling.
The new mean of a state node :math:`b` volatility coupled with other input and/or
Expand Down Expand Up @@ -115,9 +128,10 @@ def posterior_update_mean_continuous_node(
# Value coupling updates - update the mean of a value parent
# ----------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
for value_child_idx, value_coupling, coupling_fn in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
edges[node_idx].coupling_fn,
):
# get the value prediction error (VAPE)
# if this is jnp.nan (no observation) set the VAPE to 0.0
Expand All @@ -128,11 +142,22 @@ def posterior_update_mean_continuous_node(
# cancel the prediction error if the child value was not observed
value_prediction_error *= attributes[value_child_idx]["observed"]

# get differential of coupling function with value children
if coupling_fn is None: # linear coupling
coupling_fn_prime = 1
else: # non-linear coupling
# Compute the derivative of the coupling function
coupling_fn_prime = grad(coupling_fn)(attributes[node_idx]["mean"])

# expected precisions from the value children
# sum the precision-weighted prediction errors over all children
value_precision_weigthed_prediction_error += (
(
(value_coupling * attributes[value_child_idx]["expected_precision"])
(
value_coupling
* attributes[value_child_idx]["expected_precision"]
* coupling_fn_prime
)
/ node_precision
)
) * value_prediction_error
Expand All @@ -149,7 +174,8 @@ def posterior_update_mean_continuous_node(
"volatility_prediction_error"
]

# retrieve the effective precision (γ) computed during the prediction step
# retrieve the effective precision (γ)
# computed during the prediction step
effective_precision = attributes[volatility_child_idx]["temp"][
"effective_precision"
]
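The non-linear mean update described in the docstring of `posterior_update_mean_continuous_node` can be illustrated with a dependency-free sketch. The assumptions are stated plainly: the source differentiates the coupling function with `jax.grad`, whereas this sketch takes a hand-written derivative instead; `posterior_mean`, `tanh_prime` and the `Child` tuple layout are hypothetical names, not part of pyhgf.

```python
import math
from typing import Callable, Optional, Sequence, Tuple

# (kappa_j, expected precision of child j, prediction error delta_j, g' or None)
Child = Tuple[float, float, float, Optional[Callable[[float], float]]]


def tanh_prime(x: float) -> float:
    """Analytic derivative of tanh, standing in for grad(jnp.tanh)."""
    return 1.0 - math.tanh(x) ** 2


def posterior_mean(
    expected_mean: float,
    previous_mean: float,
    node_precision: float,
    children: Sequence[Child],
) -> float:
    """Mean update of a value parent from its value children."""
    weighted_error = 0.0
    for kappa, child_expected_precision, delta, g_prime in children:
        # linear coupling (None) has slope 1, otherwise evaluate g' at the mean
        slope = 1.0 if g_prime is None else g_prime(previous_mean)
        weighted_error += (
            kappa * slope * child_expected_precision / node_precision
        ) * delta
    return expected_mean + weighted_error


# the same child, first with a linear and then with a tanh coupling
print(posterior_mean(0.0, 1.0, 2.0, [(1.0, 1.0, 0.5, None)]))
print(posterior_mean(0.0, 1.0, 2.0, [(1.0, 1.0, 0.5, tanh_prime)]))
```

With a saturating function such as `tanh`, the slope shrinks as the mean moves away from zero, so the same prediction error moves the parent's mean less than it would under linear coupling.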
Expand Down Expand Up @@ -197,6 +223,8 @@ def posterior_update_precision_continuous_node(
The new precision of a state node :math:`b` value coupled with other input and/or
state nodes :math:`j` at time :math:`k` is given by:
For linear value coupling (default):
.. math::
\pi_b^{(k)} = \hat{\pi}_b^{(k)} + \sum_{j=1}^{N_{children}}
Expand All @@ -210,6 +238,13 @@ def posterior_update_precision_continuous_node(
If the child node is a state node, this value was computed by
:py:func:`pyhgf.updates.prediction_errors.nodes.continuous.continuous_node_value_prediction_error`.
For non-linear value coupling:
.. math::
\pi_b^{(k)} = \hat{\pi}_b^{(k)} + \sum_{j=1}^{N_{children}}
\hat{\pi}_j^{(k)} \left( \kappa_j^2 g'_{j,b}(\mu_b^{(k-1)})^2 -
g''_{j,b}(\mu_b^{(k-1)}) \delta_j^{(k)} \right)
#. Precision update from volatility coupling.
Expand Down Expand Up @@ -284,13 +319,30 @@ def posterior_update_precision_continuous_node(
# Value coupling updates - update the precision of a value parent
# ---------------------------------------------------------------
if edges[node_idx].value_children is not None:
for value_child_idx, value_coupling in zip(
for value_child_idx, value_coupling, coupling_fn in zip(
edges[node_idx].value_children, # type: ignore
attributes[node_idx]["value_coupling_children"],
edges[node_idx].coupling_fn,
):
if coupling_fn is None: # linear coupling
coupling_fn_prime = 1
coupling_fn_second = 0
else: # non-linear coupling
coupling_fn_prime = grad(coupling_fn)(attributes[node_idx]["mean"]) ** 2
value_prediction_error = attributes[value_child_idx]["temp"][
"value_prediction_error"
]
coupling_fn_second = (
grad(grad(coupling_fn))(attributes[node_idx]["mean"])
* value_prediction_error
)

# cancel the prediction error if the child value was not observed
precision_weigthed_prediction_error += (
value_coupling**2 * attributes[value_child_idx]["expected_precision"]
value_coupling**2
* attributes[value_child_idx]["expected_precision"]
* coupling_fn_prime
- coupling_fn_second
) * attributes[value_child_idx]["observed"]

# Volatility coupling updates - update the precision of a volatility parent
Expand Down Expand Up @@ -334,7 +386,7 @@ def posterior_update_precision_continuous_node(
)

# additional steps for unobserved values
# ----------------------------------------------------------------------------------
# ---------------------------------------

# List the node's volatility parents
volatility_parents_idxs = edges[node_idx].volatility_parents
Expand Down Expand Up @@ -493,7 +545,10 @@ def continuous_node_update_ehgf(
attributes[node_idx]["mean"] = posterior_mean

posterior_precision = posterior_update_precision_continuous_node(
attributes, edges, node_idx, time_step
attributes,
edges,
node_idx,
time_step,
)
attributes[node_idx]["precision"] = posterior_precision

Expand Down
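The non-linear precision update can be sketched in the same dependency-free way. This sketch follows the formula given in the docstring of `posterior_update_precision_continuous_node`, where the child's expected precision multiplies both terms; the code in the hunk above subtracts the second-derivative term without that factor, so the two only coincide when the child's expected precision equals 1. Hand-written derivatives replace `jax.grad`, and `posterior_precision` and the tuple layout are hypothetical names.

```python
from typing import Callable, Optional, Sequence, Tuple

Derivative = Optional[Callable[[float], float]]
# (kappa_j, expected precision of child j, prediction error delta_j, g', g'')
Child = Tuple[float, float, float, Derivative, Derivative]


def posterior_precision(
    expected_precision: float,
    previous_mean: float,
    children: Sequence[Child],
) -> float:
    """Precision update of a value parent, following the docstring formula."""
    precision = expected_precision
    for kappa, child_expected_precision, delta, g_prime, g_second in children:
        if g_prime is None or g_second is None:
            # linear coupling: unit slope, no curvature
            slope_squared, curvature = 1.0, 0.0
        else:
            slope_squared = g_prime(previous_mean) ** 2
            curvature = g_second(previous_mean) * delta
        precision += child_expected_precision * (
            kappa**2 * slope_squared - curvature
        )
    return precision


# quadratic coupling g(x) = x**2, so g'(x) = 2x and g''(x) = 2
child = (1.0, 0.5, 0.3, lambda x: 2.0 * x, lambda x: 2.0)
print(posterior_precision(1.0, 1.0, [child]))
```

Because the curvature term is weighted by the prediction error, a large positive error under a convex coupling function can lower the posterior precision, which has no counterpart in the linear case.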