-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e119974
commit 4f5e35e
Showing
9 changed files
with
984 additions
and
745 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from .add_edges import add_edges | ||
|
||
from .beliefs_propagation import beliefs_propagation | ||
|
||
from .fill_categorical_state_node import fill_categorical_state_node | ||
|
||
from .get_input_idxs import get_input_idxs | ||
|
||
from .get_update_sequence import get_update_sequence | ||
|
||
from .list_branches import list_branches | ||
|
||
from .to_pandas import to_pandas | ||
|
||
__all__ = [ | ||
"add_edges", | ||
"beliefs_propagation", | ||
"fill_categorical_state_node", | ||
"get_input_idxs", | ||
"get_update_sequence", | ||
"list_branches", | ||
"to_pandas", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
||
|
||
def add_edges( | ||
attributes: Dict, | ||
edges: Edges, | ||
kind="value", | ||
parent_idxs=Union[int, List[int]], | ||
children_idxs=Union[int, List[int]], | ||
coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0, | ||
coupling_fn: Tuple[Optional[Callable], ...] = (None,), | ||
) -> Tuple: | ||
"""Add a value or volatility coupling link between a set of nodes. | ||
Parameters | ||
---------- | ||
attributes : | ||
Attributes of the neural network. | ||
edges : | ||
Edges of the neural network. | ||
kind : | ||
The kind of coupling can be `"value"` or `"volatility"`. | ||
parent_idxs : | ||
The index(es) of the parent node(s). | ||
children_idxs : | ||
The index(es) of the children node(s). | ||
coupling_strengths : | ||
The coupling strength between the parents and children. | ||
coupling_fn : | ||
Coupling function(s) between the current node and its value children. | ||
It has to be provided as a tuple. If multiple value children are specified, | ||
the coupling functions must be stated in the same order of the children. | ||
Note: if a node has multiple parents nodes with different coupling | ||
functions, a coupling function should be indicated for all the parent nodes. | ||
If no coupling function is stated, the relationship between nodes is assumed | ||
linear. | ||
""" | ||
if kind not in ["value", "volatility"]: | ||
raise ValueError( | ||
f"The kind of coupling should be value or volatility, got {kind}" | ||
) | ||
if isinstance(children_idxs, int): | ||
children_idxs = [children_idxs] | ||
assert isinstance(children_idxs, (list, tuple)) | ||
|
||
if isinstance(parent_idxs, int): | ||
parent_idxs = [parent_idxs] | ||
assert isinstance(parent_idxs, (list, tuple)) | ||
|
||
if isinstance(coupling_strengths, int): | ||
coupling_strengths = [float(coupling_strengths)] | ||
if isinstance(coupling_strengths, float): | ||
coupling_strengths = [coupling_strengths] | ||
|
||
assert isinstance(coupling_strengths, (list, tuple)) | ||
|
||
edges_as_list = list(edges) | ||
# update the parent nodes | ||
# ----------------------- | ||
for parent_idx in parent_idxs: | ||
# unpack the parent's edges | ||
( | ||
node_type, | ||
value_parents, | ||
volatility_parents, | ||
value_children, | ||
volatility_children, | ||
this_coupling_fn, | ||
) = edges_as_list[parent_idx] | ||
|
||
if kind == "value": | ||
if value_children is None: | ||
value_children = tuple(children_idxs) | ||
attributes[parent_idx]["value_coupling_children"] = tuple( | ||
coupling_strengths | ||
) | ||
else: | ||
value_children = value_children + tuple(children_idxs) | ||
attributes[parent_idx]["value_coupling_children"] += tuple( | ||
coupling_strengths | ||
) | ||
this_coupling_fn = this_coupling_fn + coupling_fn | ||
elif kind == "volatility": | ||
if volatility_children is None: | ||
volatility_children = tuple(children_idxs) | ||
attributes[parent_idx]["volatility_coupling_children"] = tuple( | ||
coupling_strengths | ||
) | ||
else: | ||
volatility_children = volatility_children + tuple(children_idxs) | ||
attributes[parent_idx]["volatility_coupling_children"] += tuple( | ||
coupling_strengths | ||
) | ||
|
||
# save the updated edges back | ||
edges_as_list[parent_idx] = AdjacencyLists( | ||
node_type, | ||
value_parents, | ||
volatility_parents, | ||
value_children, | ||
volatility_children, | ||
this_coupling_fn, | ||
) | ||
|
||
# update the children nodes | ||
# ------------------------- | ||
for children_idx in children_idxs: | ||
# unpack this node's edges | ||
( | ||
node_type, | ||
value_parents, | ||
volatility_parents, | ||
value_children, | ||
volatility_children, | ||
coupling_fn, | ||
) = edges_as_list[children_idx] | ||
|
||
if kind == "value": | ||
if value_parents is None: | ||
value_parents = tuple(parent_idxs) | ||
attributes[children_idx]["value_coupling_parents"] = tuple( | ||
coupling_strengths | ||
) | ||
else: | ||
value_parents = value_parents + tuple(parent_idxs) | ||
attributes[children_idx]["value_coupling_parents"] += tuple( | ||
coupling_strengths | ||
) | ||
elif kind == "volatility": | ||
if volatility_parents is None: | ||
volatility_parents = tuple(parent_idxs) | ||
attributes[children_idx]["volatility_coupling_parents"] = tuple( | ||
coupling_strengths | ||
) | ||
else: | ||
volatility_parents = volatility_parents + tuple(parent_idxs) | ||
attributes[children_idx]["volatility_coupling_parents"] += tuple( | ||
coupling_strengths | ||
) | ||
|
||
# save the updated edges back | ||
edges_as_list[children_idx] = AdjacencyLists( | ||
node_type, | ||
value_parents, | ||
volatility_parents, | ||
value_children, | ||
volatility_children, | ||
coupling_fn, | ||
) | ||
|
||
# convert the list back to a tuple | ||
edges = tuple(edges_as_list) | ||
|
||
return attributes, edges | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
||
|
||
@partial(jit, static_argnames=("update_sequence", "edges", "input_idxs")) | ||
def beliefs_propagation( | ||
attributes: Attributes, | ||
inputs: Tuple[ArrayLike, ...], | ||
update_sequence: UpdateSequence, | ||
edges: Edges, | ||
input_idxs: Tuple[int], | ||
) -> Tuple[Dict, Dict]: | ||
"""Update the network's parameters after observing new data point(s). | ||
This function performs the beliefs propagation step. Belief propagation consists in: | ||
1. A prediction sequence, from the leaves of the graph to the roots. | ||
2. The assignation of new observations to target nodes (usually the roots of the | ||
network) | ||
3. An inference step alternating between prediction errors and posterior updates, | ||
starting from the roots of the network to the leaves. | ||
This function returns a tuple of two new `parameter_structure` (i.e. the carryover | ||
and the accumulated in the context of :py:func:`jax.lax.scan`). | ||
Parameters | ||
---------- | ||
attributes : | ||
The dictionaries of nodes' parameters. This variable is updated and returned | ||
after the beliefs propagation step. | ||
inputs : | ||
A tuple of n by time steps arrays containing the new observation(s), the time | ||
steps as well as a boolean mask for observed values. The new observations are a | ||
tuple of array, with length equal to the number of input nodes. Each input node | ||
can receive observations The time steps are the last | ||
column of the array, the default is unit incrementation. | ||
update_sequence : | ||
The sequence of updates that will be applied to the node structure. | ||
edges : | ||
Information on the network's edges. | ||
input_idxs : | ||
List input indexes. | ||
Returns | ||
------- | ||
attributes, attributes : | ||
A tuple of parameters structure (carryover and accumulated). | ||
""" | ||
prediction_steps, update_steps = update_sequence | ||
|
||
# unpack input data - input_values is a tuple of n x time steps arrays | ||
(*input_data, time_step) = inputs | ||
|
||
attributes[-1]["time_step"] = time_step | ||
|
||
# Prediction sequence | ||
# ------------------- | ||
for step in prediction_steps: | ||
|
||
node_idx, update_fn = step | ||
|
||
attributes = update_fn( | ||
attributes=attributes, | ||
node_idx=node_idx, | ||
edges=edges, | ||
) | ||
|
||
# Observations | ||
# ------------ | ||
for values, observed, node_idx in zip( | ||
input_data[::2], input_data[1::2], input_idxs | ||
): | ||
|
||
attributes = set_observation( | ||
attributes=attributes, | ||
node_idx=node_idx, | ||
values=values, | ||
observed=observed, | ||
) | ||
|
||
# Update sequence | ||
# --------------- | ||
for step in update_steps: | ||
|
||
node_idx, update_fn = step | ||
|
||
attributes = update_fn( | ||
attributes=attributes, | ||
node_idx=node_idx, | ||
edges=edges, | ||
) | ||
|
||
return ( | ||
attributes, | ||
attributes, | ||
) # ("carryover", "accumulated") | ||
|
Oops, something went wrong.