diff --git a/docs/source/api.rst b/docs/source/api.rst
index 1a0bb63f1..64abe87ce 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -94,6 +94,16 @@ Continuous nodes
predict_precision
continuous_node_prediction
+Dirichlet processes
+-------------------
+
+.. currentmodule:: pyhgf.updates.prediction.dirichlet
+
+.. autosummary::
+ :toctree: generated/pyhgf.updates.prediction.dirichlet
+
+ dirichlet_node_prediction
+
Prediction error steps
======================
@@ -161,6 +171,21 @@ Continuous state nodes
continuous_node_volatility_prediction_error
continuous_node_prediction_error
+Dirichlet processes
+^^^^^^^^^^^^^^^^^^^
+
+.. currentmodule:: pyhgf.updates.prediction_error.nodes.dirichlet
+
+.. autosummary::
+ :toctree: generated/pyhgf.updates.prediction_error.nodes.dirichlet
+
+ dirichlet_node_prediction_error
+ update_cluster
+ create_cluster
+ get_candidate
+ likely_cluster_proposal
+ clusters_likelihood
+
Distribution
************
@@ -238,6 +263,8 @@ Utilities for manipulating neural networks.
list_branches
fill_categorical_state_node
get_update_sequence
+ concatenate_networks
+ add_edges
Math
****
diff --git a/docs/source/learn.md b/docs/source/learn.md
index 86553d570..ea4173432 100644
--- a/docs/source/learn.md
+++ b/docs/source/learn.md
@@ -178,6 +178,17 @@ A generalisation of the binary Hierarchical Gaussian Filter to multiarmed bandit
::::
+### Non-parametric predictive coding
+
+::::{grid} 1 1 2 3
+
+:::{grid-item-card} Self-organizing neural network using Dirichlet Process nodes
+:link: example_3
+:link-type: ref
+
+:::
+::::
+
## Exercises
Hand-on exercises to build intuition around the main components of the HGF and use an agent that optimizes its action under noisy observations.
diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb
index d089871c8..16c0b94fd 100644
--- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb
+++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb
@@ -28,7 +28,15 @@
},
"tags": []
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
+ ]
+ }
+ ],
"source": [
"import jax.numpy as jnp\n",
"import matplotlib.animation as animation\n",
@@ -140,13 +148,6 @@
"tags": []
},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
- ]
- },
{
"data": {
"image/png": "",
@@ -168,7 +169,7 @@
" linestyle=\"--\",\n",
")\n",
"for i, x_i in enumerate(xs):\n",
- " xi = xi + (1 / (1 + nu)) * (Normal.sufficient_statistics(x_i) - xi)\n",
+ " xi = xi + (1 / (1 + nu)) * (Normal().sufficient_statistics(x=x_i) - xi)\n",
" nu += 1\n",
"\n",
" if i in [2, 4, 8, 16, 32, 64, 128, 256, 512, 999]:\n",
@@ -299,33 +300,33 @@
"\n",
"\n",
- "\n"
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 15,
@@ -1022,7 +1021,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Last updated: Mon Jun 10 2024\n",
+ "Last updated: Tue Jun 11 2024\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.3\n",
@@ -1032,10 +1031,10 @@
"jax : 0.4.27\n",
"jaxlib: 0.4.27\n",
"\n",
- "jax : 0.4.27\n",
- "numpy : 1.26.0\n",
- "matplotlib: 3.8.4\n",
"seaborn : 0.13.2\n",
+ "matplotlib: 3.8.4\n",
+ "numpy : 1.26.0\n",
+ "jax : 0.4.27\n",
"\n",
"Watermark: 2.4.3\n",
"\n"
diff --git a/src/pyhgf/math.py b/src/pyhgf/math.py
index 676ee80da..42eab0590 100644
--- a/src/pyhgf/math.py
+++ b/src/pyhgf/math.py
@@ -1,6 +1,6 @@
# Author: Nicolas Legrand
-from typing import Union
+from typing import Tuple, Union
import jax.numpy as jnp
from jax import Array
@@ -17,11 +17,13 @@ class MultivariateNormal:
"""
- def sufficient_statistics(x):
+ @staticmethod
+ def sufficient_statistics(x: ArrayLike) -> Array:
"""Compute the sufficient statistics for the multivariate normal."""
return jnp.hstack([x, jnp.outer(x, x)[jnp.tril_indices(x.shape[0])]])
- def base_measure(k):
+ @staticmethod
+ def base_measure(k: int) -> float:
"""Compute the base measures for the multivariate normal."""
return (2 * jnp.pi) ** (-k / 2)
@@ -35,16 +37,30 @@ class Normal:
"""
- def sufficient_statistics(x):
- """Compute the sufficient statistics for the univariate normal."""
+ @staticmethod
+ def sufficient_statistics(x: float) -> Array:
+ """Sufficient statistics for the univariate normal."""
return jnp.array([x, x**2])
- def base_measure(k):
- """Compute the base measure for the univariate normal."""
+ @staticmethod
+ def expected_sufficient_statistics(mu: float, sigma) -> Array:
+ """Compute expected sufficient statistics from the mean and std."""
+ return jnp.array([mu, mu**2 + sigma**2])
+
+ @staticmethod
+ def base_measure() -> float:
+ """Compute the base measure of the univariate normal."""
return 1 / (jnp.sqrt(2 * jnp.pi))
+ @staticmethod
+ def parameters(xis: ArrayLike) -> Tuple[float, float]:
+ """Get parameters from the expected sufficient statistics."""
+ mean = xis[0]
+ variance = xis[1] - (mean**2)
+ return mean, variance
+
-def gaussian_predictive_distribution(x, xi, nu):
+def gaussian_predictive_distribution(x: float, xi: ArrayLike, nu: float) -> float:
r"""Density of the Gaussian-predictive distribution.
This distribution is parametrized by hyperparameters from the exponential family as:
@@ -178,7 +194,7 @@ def gaussian_surprise(
Examples
--------
- >>> from pyhgf.continuous import gaussian_surprise
+ >>> from pyhgf.math import gaussian_surprise
>>> gaussian_surprise(x=2.0, expected_mean=0.0, expected_precision=1.0)
`Array(2.9189386, dtype=float32, weak_type=True)`
@@ -237,7 +253,7 @@ def binary_surprise_finite_precision(
expected_mean: Union[ArrayLike, float],
expected_precision: Union[ArrayLike, float],
eta0: Union[ArrayLike, float] = 0.0,
- eta1: Union[ArrayLike, float] = 0.0,
+ eta1: Union[ArrayLike, float] = 1.0,
) -> Array:
r"""Compute the binary surprise with finite precision.
@@ -264,3 +280,8 @@ def binary_surprise_finite_precision(
expected_mean * gaussian_density(value, eta1, expected_precision)
+ (1 - expected_mean) * gaussian_density(value, eta0, expected_precision)
)
+
+
+def sigmoid_inverse_temperature(x, temperature):
+ """Compute the sigmoid response function with inverse temperature parameter."""
+ return (x**temperature) / (x**temperature + (1 - x) ** temperature)
diff --git a/src/pyhgf/model/network.py b/src/pyhgf/model/network.py
index 4d026e29b..35285fbbd 100644
--- a/src/pyhgf/model/network.py
+++ b/src/pyhgf/model/network.py
@@ -1,5 +1,6 @@
# Author: Nicolas Legrand
+from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
import jax.numpy as jnp
@@ -20,6 +21,7 @@
input_types,
)
from pyhgf.utils import (
+ add_edges,
beliefs_propagation,
fill_categorical_state_node,
get_update_sequence,
@@ -28,7 +30,7 @@
class Network:
- """A generalised HGF neural network for predictive coding applications.
+ """A predictive coding neural network.
This is the core class to define and manipulate neural networks, that consists in
1. attributes, 2. structure and 3. update sequences.
@@ -189,12 +191,13 @@ def input_data(
# this is where the model loops over the whole input time series
# at each time point, the node structure is traversed and beliefs are updated
# using precision-weighted prediction errors
- _, node_trajectories = scan(
+ last_attributes, node_trajectories = scan(
self.scan_fn, self.attributes, (input_data, time_steps, observed)
)
# trajectories of the network attributes a each time point
self.node_trajectories = node_trajectories
+ self.last_attributes = last_attributes
return self
@@ -399,13 +402,31 @@ def add_nodes(
attributes.
"""
- # extract the node coupling indexes and coupling strengths
+ if kind not in [
+ "continuous-input",
+ "binary-input",
+ "categorical-input",
+ "DP-state",
+ "ef-normal",
+ "generic-input",
+ "continuous-state",
+ "binary-state",
+ ]:
+ raise ValueError(
+ (
+ "Invalid node type. Should be one of the following: "
+ "'continuous-input', 'binary-input', 'categorical-input', "
+ "'DP-state', 'continuous-state', 'binary-state', 'ef-normal'."
+ )
+ )
+
+ # transform coupling parameter into tuple of indexes and strengths
couplings = []
for indexes in [
- value_children,
value_parents,
- volatility_children,
volatility_parents,
+ value_children,
+ volatility_children,
]:
if indexes is not None:
if isinstance(indexes, int):
@@ -420,6 +441,9 @@ def add_nodes(
else:
coupling_idxs, coupling_strengths = None, None
couplings.append((coupling_idxs, coupling_strengths))
+ value_parents, volatility_parents, value_children, volatility_children = (
+ couplings
+ )
# create the default parameters set according to the node type
if kind == "continuous-state":
@@ -428,10 +452,10 @@ def add_nodes(
"expected_mean": 0.0,
"precision": 1.0,
"expected_precision": 1.0,
- "volatility_coupling_children": couplings[2][1],
- "volatility_coupling_parents": couplings[3][1],
- "value_coupling_children": couplings[0][1],
- "value_coupling_parents": couplings[1][1],
+ "volatility_coupling_children": volatility_children[1],
+ "volatility_coupling_parents": volatility_parents[1],
+ "value_coupling_children": value_children[1],
+ "value_coupling_parents": value_parents[1],
"tonic_volatility": -4.0,
"tonic_drift": 0.0,
"autoconnection_strength": 1.0,
@@ -449,10 +473,10 @@ def add_nodes(
"expected_mean": 0.0,
"precision": 1.0,
"expected_precision": 1.0,
- "volatility_coupling_children": couplings[2][1],
- "volatility_coupling_parents": couplings[3][1],
- "value_coupling_children": couplings[0][1],
- "value_coupling_parents": couplings[1][1],
+ "volatility_coupling_children": volatility_children[1],
+ "volatility_coupling_parents": volatility_parents[1],
+ "value_coupling_children": value_children[1],
+ "value_coupling_parents": value_parents[1],
"tonic_volatility": 0.0,
"tonic_drift": 0.0,
"autoconnection_strength": 1.0,
@@ -536,11 +560,35 @@ def add_nodes(
}
elif "ef-normal" in kind:
default_parameters = {
- "nus": 0.0,
- "xis": jnp.array([0.0, 0.0]),
+ "nus": 3.0,
+ "xis": jnp.array([0.0, 1.0]),
"values": 0.0,
+ "observed": 1.0,
}
+ elif kind == "DP-state":
+ if "batch_size" in additional_parameters.keys():
+ batch_size = additional_parameters["batch_size"]
+ elif "batch_size" in node_parameters.keys():
+ batch_size = node_parameters["batch_size"]
+ else:
+ batch_size = 10
+
+ default_parameters = {
+ "batch_size": batch_size, # number of branches available in the network
+ "n": jnp.zeros(batch_size), # number of observations in each cluster
+ "n_total": 0, # the total number of observations in the node
+ "alpha": 1.0, # concentration parameter for the implied Dirichlet dist.
+ "expected_means": jnp.zeros(batch_size),
+ "expected_sigmas": jnp.ones(batch_size),
+ "sensory_precision": 1.0,
+ "activated": jnp.zeros(batch_size),
+ "value_coupling_children": (1.0,),
+ "values": 0.0,
+ "n_active_cluster": 0,
+ }
+
+ # Update the default node parameters using keyword args and dictionary
if bool(additional_parameters):
# ensure that all passed values are valid keys
invalid_keys = [
@@ -581,32 +629,37 @@ def add_nodes(
node_type = 2
elif "ef-normal" in kind:
node_type = 3
-
- # convert the structure to a list to modify it
- edges_as_list: List[AdjacencyLists] = list(self.edges)
+ elif "DP-state" in kind:
+ node_type = 4
for _ in range(n_nodes):
+ # convert the structure to a list to modify it
+ edges_as_list: List = list(self.edges)
+
node_idx = len(self.attributes) # the index of the new node
# add a new edge
edges_as_list.append(
AdjacencyLists(
node_type,
- couplings[1][0],
- couplings[3][0],
- couplings[0][0],
- couplings[2][0],
+ None,
+ None,
+ None,
+ None,
)
)
+ # convert the list back to a tuple
+ self.edges = tuple(edges_as_list)
+
if node_idx == 0:
# this is the first node, create the node structure
- self.attributes = {node_idx: node_parameters}
+ self.attributes = {node_idx: deepcopy(node_parameters)}
if input_type is not None:
self.inputs = Inputs((node_idx,), (input_type,))
else:
# update the node structure
- self.attributes[node_idx] = node_parameters
+ self.attributes[node_idx] = deepcopy(node_parameters)
if input_type is not None:
# add information about the new input node in the indexes
@@ -616,91 +669,40 @@ def add_nodes(
new_kind += (input_type,)
self.inputs = Inputs(new_idx, new_kind)
- # update the existing edge structure so it links to the new node as well
- for coupling, edge_type in zip(
- couplings,
- [
- "value_children",
- "value_parents",
- "volatility_children",
- "volatility_parents",
- ],
- ):
- if coupling[0] is not None:
- coupling_idxs, coupling_strengths = coupling
- for idx, coupling_strength in zip(
- coupling_idxs, coupling_strengths # type: ignore
- ):
- # unpack this node's edges
- (
- this_node_type,
- value_parents,
- volatility_parents,
- value_children,
- volatility_children,
- ) = edges_as_list[idx]
-
- # update the parents/children's edges depending on the coupling
- if edge_type == "value_parents":
- if value_children is None:
- value_children = (node_idx,)
- self.attributes[idx]["value_coupling_children"] = (
- coupling_strength,
- )
- else:
- value_children = value_children + (node_idx,)
- self.attributes[idx]["value_coupling_children"] += (
- coupling_strength,
- )
- elif edge_type == "volatility_parents":
- if volatility_children is None:
- volatility_children = (node_idx,)
- self.attributes[idx]["volatility_coupling_children"] = (
- coupling_strength,
- )
- else:
- volatility_children = volatility_children + (node_idx,)
- self.attributes[idx][
- "volatility_coupling_children"
- ] += (coupling_strength,)
- elif edge_type == "value_children":
- if value_parents is None:
- value_parents = (node_idx,)
- self.attributes[idx]["value_coupling_parents"] = (
- coupling_strength,
- )
- else:
- value_parents = value_parents + (node_idx,)
- self.attributes[idx]["value_coupling_parents"] += (
- coupling_strength,
- )
- elif edge_type == "volatility_children":
- if volatility_parents is None:
- volatility_parents = (node_idx,)
- self.attributes[idx]["volatility_coupling_parents"] = (
- coupling_strength,
- )
- else:
- volatility_parents = volatility_parents + (node_idx,)
- self.attributes[idx]["volatility_coupling_parents"] += (
- coupling_strength,
- )
-
- # save the updated edges back
- edges_as_list[idx] = AdjacencyLists(
- this_node_type,
- value_parents,
- volatility_parents,
- value_children,
- volatility_children,
- )
-
- # convert the list back to a tuple
- self.edges = tuple(edges_as_list)
-
- # if we are creating a categorical state or state-transition node
- # we have to generate the implied binary network(s) here
+ # Update the edges of the parents and children accordingly
+ # --------------------------------------------------------
+ if value_parents[0] is not None:
+ self.add_edges(
+ kind="value",
+ parent_idxs=value_parents[0],
+ children_idxs=node_idx,
+ coupling_strengths=value_parents[1], # type: ignore
+ )
+ if value_children[0] is not None:
+ self.add_edges(
+ kind="value",
+ parent_idxs=node_idx,
+ children_idxs=value_children[0],
+ coupling_strengths=value_children[1], # type: ignore
+ )
+ if volatility_children[0] is not None:
+ self.add_edges(
+ kind="volatility",
+ parent_idxs=node_idx,
+ children_idxs=volatility_children[0],
+ coupling_strengths=volatility_children[1], # type: ignore
+ )
+ if volatility_parents[0] is not None:
+ self.add_edges(
+ kind="volatility",
+ parent_idxs=volatility_parents[0],
+ children_idxs=node_idx,
+ coupling_strengths=volatility_parents[1], # type: ignore
+ )
+
if kind == "categorical-input":
+ # if we are creating a categorical state or state-transition node
+ # we have to generate the implied binary network(s) here
self = fill_categorical_state_node(
self,
node_idx=node_idx,
@@ -778,3 +780,39 @@ def surprise(
response_function_inputs=response_function_inputs,
response_function_parameters=response_function_parameters,
)
+ return self
+
+ def add_edges(
+ self,
+ kind="value",
+ parent_idxs=Union[int, List[int]],
+ children_idxs=Union[int, List[int]],
+ coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
+ ) -> "Network":
+ """Add a value or volatility coupling link between a set of nodes.
+
+ Parameters
+ ----------
+ kind :
+ The kind of coupling, can be `"value"` or `"volatility"`.
+ parent_idxs :
+ The index(es) of the parent node(s).
+ children_idxs :
+ The index(es) of the children node(s).
+ coupling_strengths :
+ The coupling strength between the parents and children.
+
+ """
+ attributes, edges = add_edges(
+ attributes=self.attributes,
+ edges=self.edges,
+ kind=kind,
+ parent_idxs=parent_idxs,
+ children_idxs=children_idxs,
+ coupling_strengths=coupling_strengths,
+ )
+
+ self.attributes = attributes
+ self.edges = edges
+
+ return self
diff --git a/src/pyhgf/plots.py b/src/pyhgf/plots.py
index 59fc6c925..c5fbe6c8d 100644
--- a/src/pyhgf/plots.py
+++ b/src/pyhgf/plots.py
@@ -274,10 +274,31 @@ def plot_network(network: "Network") -> "Source":
)
# create the rest of nodes
- for i in range(len(network.edges)):
- # only if node is not an input node
- if i not in network.inputs.idx:
- graphviz_structure.node(f"x_{i}", label=str(i), shape="circle")
+ for idx in range(len(network.edges)):
+
+ if network.edges[idx].node_type == 2:
+ # Continuous state node
+ graphviz_structure.node(f"x_{idx}", label=str(idx), shape="circle")
+
+ elif network.edges[idx].node_type == 3:
+ # Exponential family state node
+ graphviz_structure.node(
+ f"x_{idx}",
+ label=f"EF-{idx}",
+ style="filled",
+ shape="circle",
+ fillcolor="#ced6e4",
+ )
+
+ elif network.edges[idx].node_type == 4:
+ # Dirichlet Process state node
+ graphviz_structure.node(
+ f"x_{idx}",
+ label=f"DP-{idx}",
+ style="filled",
+ shape="doublecircle",
+ fillcolor="#e2d8c1",
+ )
# connect value parents
for i, index in enumerate(network.edges):
diff --git a/src/pyhgf/typing.py b/src/pyhgf/typing.py
index 751c81132..d7a17b702 100644
--- a/src/pyhgf/typing.py
+++ b/src/pyhgf/typing.py
@@ -12,6 +12,7 @@ class AdjacencyLists(NamedTuple):
* 2: continuous state node.
* 3: exponential family state node - univariate Gaussian distribution with unknown
mean and unknown variance.
+ * 4: Dirichlet Process state node.
"""
diff --git a/src/pyhgf/updates/posterior/exponential.py b/src/pyhgf/updates/posterior/exponential.py
index 57f5f28c3..ca0ee4579 100644
--- a/src/pyhgf/updates/posterior/exponential.py
+++ b/src/pyhgf/updates/posterior/exponential.py
@@ -3,6 +3,7 @@
from functools import partial
from typing import Callable, Dict
+import jax.numpy as jnp
from jax import jit
from pyhgf.typing import Attributes, Edges
@@ -49,11 +50,14 @@ def posterior_update_exponential_family(
"""
# update the hyperparameter vectors
- attributes[node_idx]["xis"] = attributes[node_idx]["xis"] + (
- 1 / (1 + attributes[node_idx]["nus"])
- ) * (
- sufficient_stats_fn(attributes[node_idx]["values"])
+ xis = attributes[node_idx]["xis"] + (1 / (1 + attributes[node_idx]["nus"])) * (
+ sufficient_stats_fn(x=attributes[node_idx]["values"])
- attributes[node_idx]["xis"]
)
+ # blank update in the case of unobserved value
+ attributes[node_idx]["xis"] = jnp.where(
+ attributes[node_idx]["observed"], xis, attributes[node_idx]["xis"]
+ )
+
return attributes
diff --git a/src/pyhgf/updates/prediction/dirichlet.py b/src/pyhgf/updates/prediction/dirichlet.py
new file mode 100644
index 000000000..03fb86e25
--- /dev/null
+++ b/src/pyhgf/updates/prediction/dirichlet.py
@@ -0,0 +1,55 @@
+# Author: Nicolas Legrand
+
+from typing import Dict
+
+import jax.numpy as jnp
+
+from pyhgf.math import Normal
+from pyhgf.typing import Attributes, Edges
+
+
+def dirichlet_node_prediction(
+ edges: Edges,
+ attributes: Dict,
+ node_idx: int,
+ **args,
+) -> Attributes:
+ """Prediction of a Dirichlet process node.
+
+ Parameters
+ ----------
+ edges :
+ The edges of the neural network as a tuple of
+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
+ For each node, the index lists the value/volatility parents/children.
+ attributes :
+ The attributes of the probabilistic nodes.
+ node_idx :
+ Pointer to the Dirichlet process input node.
+
+ Returns
+ -------
+ attributes :
+ The attributes of the probabilistic nodes.
+ The `expected_means` and `expected_sigmas` of the Dirichlet node are
+ refreshed from the `xis` of its value parents (when it has any).
+
+ Notes
+ -----
+ The edges are only read; they are not modified or returned.
+
+ """
+ # get the parameter (mean and variance) from the EF-normal parent nodes
+ value_parent_idxs = edges[node_idx].value_parents
+ if value_parent_idxs is not None:
+ parameters = jnp.array(
+ [
+ Normal().parameters(xis=attributes[parent_idx]["xis"])
+ for parent_idx in value_parent_idxs
+ ]
+ )
+
+ attributes[node_idx]["expected_means"] = parameters[:, 0]
+ attributes[node_idx]["expected_sigmas"] = jnp.sqrt(parameters[:, 1])
+
+ return attributes
diff --git a/src/pyhgf/updates/prediction_error/nodes/dirichlet.py b/src/pyhgf/updates/prediction_error/nodes/dirichlet.py
new file mode 100644
index 000000000..cc2a50d3b
--- /dev/null
+++ b/src/pyhgf/updates/prediction_error/nodes/dirichlet.py
@@ -0,0 +1,402 @@
+# Author: Nicolas Legrand
+
+from functools import partial
+from typing import Dict, Tuple
+
+import jax.numpy as jnp
+from jax import Array, jit, random
+from jax._src.typing import Array as KeyArray
+from jax.lax import cond
+from jax.scipy.stats.norm import pdf
+from jax.tree_util import Partial
+from jax.typing import ArrayLike
+
+from pyhgf.math import Normal
+from pyhgf.typing import Attributes, Edges
+
+
+@partial(jit, static_argnames=("edges", "node_idx"))
+def dirichlet_node_prediction_error(
+ edges: Edges,
+ attributes: Dict,
+ node_idx: int,
+ **args,
+) -> Attributes:
+ """Prediction error and update the child networks of a Dirichlet process node.
+
+ When receiving a new input, this node chooses to either:
+ 1. Allocate the value to a pre-existing cluster.
+ 2. Create a new cluster.
+
+ The network always contains a temporary branch as the new cluster candidate. This
+ branch is parametrized under the new observation to assess its likelihood and the
+ previous clusters' likelihood.
+
+ Parameters
+ ----------
+ edges :
+ The edges of the neural network as a tuple of
+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
+ For each node, the index lists the value/volatility parents/children.
+ attributes :
+ The attributes of the probabilistic nodes.
+ node_idx :
+ Pointer to the Dirichlet process input node.
+
+ Returns
+ -------
+ attributes :
+ The attributes of the probabilistic nodes.
+
+ """
+ values = attributes[node_idx]["values"] # the input value
+ alpha = attributes[node_idx]["alpha"] # the concentration parameter
+ n_total = attributes[node_idx]["n_total"] # total number of observations
+ n = attributes[node_idx]["n"] # number of observations per cluster
+ sensory_precision = attributes[node_idx][
+ "sensory_precision"
+ ] # the expected precision of the new observation
+
+ # likelihood of the current observation under existing clusters
+ # -------------------------------------------------------------
+ cluster_ll = clusters_likelihood(
+ value=values,
+ expected_mean=attributes[node_idx]["expected_means"],
+ expected_sigma=attributes[node_idx]["expected_sigmas"],
+ )
+
+ # set the likelihood to 0 for inactive clusters
+ cluster_ll *= attributes[node_idx]["activated"]
+
+ # likelihood of the current observation under the best candidate cluster
+ # ----------------------------------------------------------------------
+
+ # find the best cluster candidate given the new observation
+ candidate_mean, candidate_sigma = get_candidate(
+ value=values,
+ sensory_precision=sensory_precision,
+ expected_mean=attributes[node_idx]["expected_means"],
+ expected_sigma=attributes[node_idx]["expected_sigmas"],
+ )
+
+ # get the likelihood under this candidate
+ candidate_ll = clusters_likelihood(
+ value=values,
+ expected_mean=candidate_mean,
+ expected_sigma=candidate_sigma,
+ )
+
+ # DP step: compare the likelihood of existing cluster with a new cluster
+ # ----------------------------------------------------------------------
+
+ # probability of being assigned to a pre-existing cluster
+ cluster_ll *= n / (alpha + n_total)
+
+ # probability to draw a new cluster
+ candidate_ll *= alpha / (alpha + n_total)
+
+ best_val = jnp.max(cluster_ll)
+
+ # set all cluster to non-observed by default
+ for parent_idx in edges[node_idx].value_parents: # type:ignore
+ attributes[parent_idx]["observed"] = 0
+
+ # get the index of the cluster (!= the node index)
+ # depending on whether a new cluster is created or updated
+ cluster_idx = jnp.where(
+ best_val >= candidate_ll,
+ jnp.argmax(cluster_ll),
+ attributes[node_idx]["n_active_cluster"],
+ )
+
+ update_fn = Partial(
+ update_cluster,
+ edges=edges,
+ node_idx=node_idx,
+ )
+
+ create_fn = Partial(
+ create_cluster,
+ edges=edges,
+ node_idx=node_idx,
+ )
+
+ # apply either cluster update or cluster creation
+ operands = attributes, cluster_idx, values, (candidate_mean, candidate_sigma)
+
+ attributes = cond(best_val >= candidate_ll, update_fn, create_fn, operands)
+
+ attributes[node_idx]["n_total"] += 1
+
+ return attributes
+
+
+@partial(jit, static_argnames=("edges", "node_idx"))
+def update_cluster(operands: Tuple, edges: Edges, node_idx: int) -> Attributes:
+ """Update an existing cluster.
+
+ Parameters
+ ----------
+ operands :
+ Non-static parameters.
+ edges :
+ The edges of the neural network as a tuple of
+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
+ For each node, the index lists the value/volatility parents/children.
+ node_idx :
+ Pointer to the Dirichlet process input node.
+
+ Returns
+ -------
+ attributes :
+ The attributes of the probabilistic nodes.
+
+ """
+ attributes, cluster_idx, value, _ = operands
+
+ # activate the corresponding branch and pass the value
+ for i, value_parent_idx in enumerate(edges[node_idx].value_parents): # type: ignore
+
+ attributes[value_parent_idx]["observed"] = jnp.where(cluster_idx == i, 1.0, 0.0)
+ attributes[value_parent_idx]["values"] = value
+
+ attributes[node_idx]["n"] = (
+ attributes[node_idx]["n"]
+ .at[cluster_idx]
+ .set(attributes[node_idx]["n"][cluster_idx] + 1.0)
+ )
+
+ return attributes
+
+
+@partial(jit, static_argnames=("edges", "node_idx"))
+def create_cluster(operands: Tuple, edges: Edges, node_idx: int) -> Attributes:
+ """Create a new cluster.
+
+ Parameters
+ ----------
+ operands :
+ Non-static parameters.
+ edges :
+ The edges of the neural network as a tuple of
+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
+ For each node, the index lists the value/volatility parents/children.
+ node_idx :
+ Pointer to the Dirichlet process input node.
+
+ Returns
+ -------
+ attributes :
+ The attributes of the probabilistic nodes.
+
+ """
+ attributes, cluster_idx, value, (candidate_mean, candidate_sigma) = operands
+
+ # creating a new cluster
+ attributes[node_idx]["activated"] = (
+ attributes[node_idx]["activated"].at[cluster_idx].set(1)
+ )
+
+ for i, value_parent_idx in enumerate(edges[node_idx].value_parents): # type: ignore
+
+ attributes[value_parent_idx]["observed"] = 0.0
+ attributes[value_parent_idx]["values"] = value
+
+ # initialize the new cluster using candidate values
+ attributes[value_parent_idx]["xis"] = jnp.where(
+ cluster_idx == i,
+ Normal().expected_sufficient_statistics(
+ mu=candidate_mean, sigma=candidate_sigma
+ ),
+ attributes[value_parent_idx]["xis"],
+ )
+
+ attributes[node_idx]["n"] = attributes[node_idx]["n"].at[cluster_idx].set(1.0)
+ attributes[node_idx]["n_active_cluster"] += 1
+
+ return attributes
+
+
+@jit
+def get_candidate(
+ value: float,
+ sensory_precision: float,
+ expected_mean: ArrayLike,
+ expected_sigma: ArrayLike,
+ n_samples: int = 20_000,
+) -> Tuple[float, float]:
+ """Find the best cluster candidate given previous clusters and an input value.
+
+ Parameters
+ ----------
+ value :
+ The new observation.
+ sensory_precision :
+ The expected precision of the new observation.
+ expected_mean :
+ The mean of the existing clusters.
+ expected_sigma :
+ The standard deviation of the existing clusters.
+ n_samples :
+ The number of samples that should be simulated.
+
+ Returns
+ -------
+ mean :
+ The mean of the new candidate cluster.
+ sigma :
+ The standard deviation of the new candidate cluster.
+
+ """
+ # sample n likely clusters given the base distribution priors
+ mus, sigmas, weights = likely_cluster_proposal(
+ mean_mu_G0=0.0,
+ sigma_mu_G0=10.0,
+ sigma_pi_G0=3.0,
+ expected_mean=expected_mean,
+ expected_sigma=expected_sigma,
+ key=random.key(42),
+ n_samples=n_samples,
+ )
+
+ # 1 - Likelihood of the new observation under each sampled cluster
+ # ----------------------------------------------------------------
+ ll_value = pdf(value, mus, sigmas)
+ ll_value /= ll_value.sum() # normalize the weights
+
+ # 2- re-scale the weights using expected precision
+ # ------------------------------------------------
+ weights *= ll_value**sensory_precision
+
+ # only use the 1000 best candidates for inference
+ idxs = jnp.argsort(weights)
+ mus, sigmas, weights = (
+ mus[idxs][-1000:],
+ sigmas[idxs][-1000:],
+ weights[idxs][-1000:],
+ )
+
+ # 3 - estimate new mean and standard deviation using the weighted mean
+ # --------------------------------------------------------------------
+ mean = jnp.average(mus, weights=weights)
+ sigma = jnp.average(sigmas, weights=weights)
+
+ return mean, sigma
+
+
+@partial(jit, static_argnames=("n_samples"))
+def likely_cluster_proposal(
+ mean_mu_G0: float,
+ sigma_mu_G0: float,
+ sigma_pi_G0: float,
+ expected_mean=ArrayLike,
+ expected_sigma=ArrayLike,
+ key: KeyArray = random.key(42),
+ n_samples: int = 20_000,
+) -> Tuple[Array, Array, Array]:
+ """Sample likely new belief distributions given pre-existing clusters.
+
+ Parameters
+ ----------
+ mean_mu_G0 :
+ The mean of the mean of the base distribution.
+ sigma_mu_G0 :
+ The standard deviation of the mean of the base distribution.
+ sigma_pi_G0 :
+ The standard deviation of the standard deviation of the base distribution.
+ expected_mean :
+ Pre-existing clusters means.
+ expected_sigma :
+ Pre-existing clusters standard deviation.
+ key :
+ Random state.
+ n_samples :
+ The number of samples used during the simulations.
+
+ Returns
+ -------
+ new_mu :
+ A vector of means candidates.
+ new_sigma :
+ A vector of standard deviation candidates.
+ weights :
+ Weights for each cluster candidate under pre-existing cluster (irrespective of
+ new observations).
+
+ """
+ # sample new candidate for cluster means
+ key, use_key = random.split(key)
+ new_mu = sigma_mu_G0 * random.normal(use_key, (n_samples,)) + mean_mu_G0
+
+ # sample new candidate for cluster standard deviation
+ key, use_key = random.split(key)
+ new_sigma = jnp.abs(random.normal(use_key, (n_samples,)) * sigma_pi_G0)
+
+ # 1 - Cluster specificity
+ # -----------------------
+ # this cluster should explain new dimensions, not explained by other clusters
+
+ # evidence for pre-existing clusters
+ pre_existing_likelihood = jnp.zeros(n_samples)
+ for mu_i, sigma_i in zip(expected_mean, expected_sigma):
+ pre_existing_likelihood += pdf(new_mu, mu_i, sigma_i)
+
+ # evidence for the new cluster proposal
+ new_likelihood = pdf(new_mu, new_mu, new_sigma)
+
+ # standardize the measure of cluster specificity (ratio)
+ ratio = new_likelihood / (new_likelihood + pre_existing_likelihood)
+ ratio -= ratio.min()
+ ratio /= ratio.max()
+ weights = ratio
+
+ # 2 - Cluster isolation
+ # ---------------------
+ # this cluster should not try to explain what was already explained
+
+ # (pre-existing cluster) / (pre-existing cluster + new cluster)
+ cluster_isolation = jnp.ones(n_samples)
+ for mu_i, sigma_i in zip(expected_mean, expected_sigma):
+ ratio = pdf(mu_i, mu_i, sigma_i) / (
+ pdf(mu_i, mu_i, sigma_i) + pdf(mu_i, new_mu, new_sigma)
+ )
+ cluster_isolation *= ratio
+ cluster_isolation -= cluster_isolation.min()
+ cluster_isolation /= cluster_isolation.max()
+
+ weights *= cluster_isolation
+
+ # 3 - Spread of the cluster
+ # -------------------------
+ # large clusters should be favored over small clusters
+ cluster_spread = pdf(1 / (new_sigma**2), 0.0, 5.0)
+ cluster_spread -= cluster_spread.min()
+ cluster_spread /= cluster_spread.max()
+ weights *= cluster_spread
+
+ return new_mu, new_sigma, weights
+
+
+def clusters_likelihood(
+ value: float,
+ expected_mean: ArrayLike,
+ expected_sigma: ArrayLike,
+) -> ArrayLike:
+ """Likelihood of a parametrized candidate under the new observation.
+
+ Parameters
+ ----------
+ value :
+ The new observation.
+ expected_mean :
+ Pre-existing clusters means.
+ expected_sigma :
+ Pre-existing clusters standard deviation.
+
+ Returns
+ -------
+ likelihood :
+ The probability of observing the value under each cluster.
+
+ """
+ return pdf(value, expected_mean, expected_sigma)
diff --git a/src/pyhgf/utils.py b/src/pyhgf/utils.py
index bbc36d656..8602a78e8 100644
--- a/src/pyhgf/utils.py
+++ b/src/pyhgf/utils.py
@@ -1,7 +1,7 @@
# Author: Nicolas Legrand
from functools import partial
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
import jax.numpy as jnp
import numpy as np
@@ -11,7 +11,7 @@
from jax.typing import ArrayLike
from pyhgf.math import Normal, binary_surprise, gaussian_surprise
-from pyhgf.typing import AdjacencyLists, Attributes, Structure, UpdateSequence
+from pyhgf.typing import AdjacencyLists, Attributes, Edges, Structure, UpdateSequence
from pyhgf.updates.posterior.binary import binary_node_update_infinite
from pyhgf.updates.posterior.categorical import categorical_input_update
from pyhgf.updates.posterior.continuous import (
@@ -21,6 +21,7 @@
from pyhgf.updates.posterior.exponential import posterior_update_exponential_family
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
+from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.inputs.binary import (
binary_input_prediction_error_infinite_precision,
)
@@ -34,6 +35,9 @@
from pyhgf.updates.prediction_error.nodes.continuous import (
continuous_node_prediction_error,
)
+from pyhgf.updates.prediction_error.nodes.dirichlet import (
+ dirichlet_node_prediction_error,
+)
if TYPE_CHECKING:
from pyhgf.model import Network
@@ -110,39 +114,6 @@ def beliefs_propagation(
) # ("carryover", "accumulated")
-def trim_sequence(
- exclude_node_idxs: List, update_sequence: UpdateSequence, edges: Tuple
-) -> UpdateSequence:
- """Remove steps from an update sequence that depends on a set of nodes.
-
- Parameters
- ----------
- exclude_node_idxs :
- A list of node indexes. The nodes can be input nodes or any other node in the
- network.
- update_sequence :
- The sequence of updates that will be applied to the node structure.
- edges :
- The nodes structure.
-
- Returns
- -------
- trimmed_update_sequence :
- The update sequence without the update steps for nodes depending on the root
- list.
-
- """
- # list the nodes that depend on the root indexes
- branch_list = list_branches(node_idxs=exclude_node_idxs, edges=edges)
-
- # remove the update steps that are targetting the excluded nodes
- trimmed_update_sequence = tuple(
- [seq for seq in update_sequence if seq[0] not in branch_list]
- )
-
- return trimmed_update_sequence
-
-
def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List:
"""Return the branch of a network from a given set of root nodes.
@@ -322,12 +293,14 @@ def get_update_sequence(network: "Network", update_type: str) -> List:
node_without_update = [i for i in range(n_nodes)]
# start by injecting the observations in all input nodes
+ # ------------------------------------------------------
for input_idx, kind in zip(network.inputs.idx, network.inputs.kind):
if kind == 0:
update_fn = continuous_input_prediction_error
update_sequence.append((input_idx, update_fn))
elif kind == 1:
+
# add the update steps for the binary state node as well
binary_state_idx = network.edges[input_idx].value_parents[0] # type: ignore
@@ -356,13 +329,19 @@ def get_update_sequence(network: "Network", update_type: str) -> List:
update_fn = generic_input_prediction_error
update_sequence.append((input_idx, update_fn))
+ elif kind == 4:
+ update_fn = dirichlet_node_prediction_error
+ update_sequence.append((input_idx, update_fn))
+
# add the PE step to the sequence
node_without_pe.remove(input_idx)
# input node does not need to update the posterior
node_without_update.remove(input_idx)
+ # prediction errors and posterior updates
# will fail if the structure of the network does not allow a consistent update order
+ # ----------------------------------------------------------------------------------
while True:
no_update = True
@@ -400,10 +379,16 @@ def get_update_sequence(network: "Network", update_type: str) -> List:
# for the exponential family node
ef_update = Partial(
posterior_update_exponential_family,
- sufficient_stats_fn=Normal.sufficient_statistics,
+ sufficient_stats_fn=Normal().sufficient_statistics,
)
update_fn = ef_update
+ elif network.edges[idx].node_type == 4:
+
+ update_fn = None
+ # the prediction sequence is the update sequence in reverse order
+ prediction_sequence.insert(0, (idx, dirichlet_node_prediction))
+
update_sequence.append((idx, update_fn))
node_without_update.remove(idx)
@@ -425,15 +410,20 @@ def get_update_sequence(network: "Network", update_type: str) -> List:
else:
# if this node has been updated
if idx not in node_without_update:
+
+ if network.edges[idx].node_type == 2:
+ update_fn = continuous_node_prediction_error
+ elif network.edges[idx].node_type == 4:
+ update_fn = dirichlet_node_prediction_error
+
no_update = False
- update_sequence.append((idx, continuous_node_prediction_error))
+ update_sequence.append((idx, update_fn))
node_without_pe.remove(idx)
if (not node_without_pe) and (not node_without_update):
break
if no_update:
- break
raise Warning(
"The structure of the network cannot be updated consistently."
)
@@ -447,7 +437,10 @@ def get_update_sequence(network: "Network", update_type: str) -> List:
# create a new sequence step and add it to the list
prediction_sequence.append((idx, categorical_input_update))
- return prediction_sequence
+ # remove None steps and return the update sequence
+ sequence = [update for update in prediction_sequence if update[1] is not None]
+
+ return sequence
def to_pandas(network: "Network") -> pd.DataFrame:
@@ -599,3 +592,204 @@ def to_pandas(network: "Network") -> pd.DataFrame:
].sum(axis=1, min_count=1)
return trajectories_df
+
+
+def concatenate_networks(attributes_1, attributes_2, edges_1, edges_2):
+    """Concatenate two networks.
+
+    The second network keeps its node indexes. The first network is placed after
+    it: the keys of `attributes_1` and every index stored in `edges_1` are shifted
+    by `len(attributes_2)`.
+
+    Parameters
+    ----------
+    attributes_1 :
+        The attributes of the first network, indexed by consecutive integers
+        starting at 0.
+    attributes_2 :
+        The attributes of the second network.
+    edges_1 :
+        The edges of the first network.
+    edges_2 :
+        The edges of the second network. Should be a tuple, as it is concatenated
+        with the shifted edges of the first network.
+
+    Returns
+    -------
+    attributes :
+        The attribute of the concatenated networks.
+    edges :
+        The edges of the concatenated networks.
+
+    """
+    n_nodes = len(attributes_2)
+
+    def shift(idxs):
+        """Offset node indexes by `n_nodes`, keeping missing links as `None`."""
+        return None if idxs is None else tuple(e + n_nodes for e in idxs)
+
+    edges_1 = list(edges_1)
+    attributes = {}
+    for i in range(len(attributes_1)):
+        # re-key the attributes of the first network after the second one
+        attributes[i + n_nodes] = attributes_1[i]
+
+        # shift the indexes stored in the edges; the node type has to be carried
+        # over explicitly, otherwise it is lost when the adjacency list is rebuilt
+        node = edges_1[i]
+        edges_1[i] = AdjacencyLists(
+            node_type=node.node_type,
+            value_parents=shift(node.value_parents),
+            volatility_parents=shift(node.volatility_parents),
+            value_children=shift(node.value_children),
+            volatility_children=shift(node.volatility_children),
+        )
+
+    edges_1 = tuple(edges_1)
+
+    # the second network comes first, both in the attributes and in the edges
+    attributes = {**attributes_2, **attributes}
+    edges = edges_2 + edges_1
+
+    return attributes, edges
+
+
+def _link_node(
+    attributes: Dict,
+    node_edges: Tuple,
+    node_idx: int,
+    kind: str,
+    role: str,
+    linked_idxs: Union[List[int], Tuple[int, ...]],
+    coupling_strengths: Union[List[float], Tuple[float, ...]],
+) -> AdjacencyLists:
+    """Register new parents or children on a single node.
+
+    Parameters
+    ----------
+    attributes :
+        Attributes of the neural network. The entry
+        `attributes[node_idx]["{kind}_coupling_{role}"]` is updated in place.
+    node_edges :
+        The current adjacency lists of the node, unpacked as `(node_type,
+        value_parents, volatility_parents, value_children, volatility_children)`.
+    node_idx :
+        The index of the node being updated.
+    kind :
+        The kind of coupling, `"value"` or `"volatility"`.
+    role :
+        The side of the link that is added to this node, `"parents"` or
+        `"children"`.
+    linked_idxs :
+        The indexes of the nodes to link to this node.
+    coupling_strengths :
+        The coupling strengths stored alongside the new links.
+
+    Returns
+    -------
+    node_edges :
+        The updated adjacency lists of the node.
+
+    """
+    (
+        node_type,
+        value_parents,
+        volatility_parents,
+        value_children,
+        volatility_children,
+    ) = node_edges
+    links = {
+        "value_parents": value_parents,
+        "volatility_parents": volatility_parents,
+        "value_children": value_children,
+        "volatility_children": volatility_children,
+    }
+    field = f"{kind}_{role}"
+    coupling_key = f"{kind}_coupling_{role}"
+
+    if links[field] is None:
+        # first link of this kind: create the indexes and the coupling strengths
+        links[field] = tuple(linked_idxs)
+        attributes[node_idx][coupling_key] = tuple(coupling_strengths)
+    else:
+        # extend the existing links, indexes and strengths in the same order
+        links[field] = links[field] + tuple(linked_idxs)
+        attributes[node_idx][coupling_key] += tuple(coupling_strengths)
+
+    return AdjacencyLists(
+        node_type,
+        links["value_parents"],
+        links["volatility_parents"],
+        links["value_children"],
+        links["volatility_children"],
+    )
+
+
+def add_edges(
+    attributes: Dict,
+    edges: Edges,
+    kind="value",
+    parent_idxs: Union[int, List[int], Tuple[int, ...], None] = None,
+    children_idxs: Union[int, List[int], Tuple[int, ...], None] = None,
+    coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
+) -> Tuple:
+    """Add a value or volatility coupling link between a set of nodes.
+
+    Every parent receives all the children and every child receives all the
+    parents. `coupling_strengths` is stored as given on both sides of the link; it
+    is not broadcast to the number of nodes.
+
+    Parameters
+    ----------
+    attributes :
+        Attributes of the neural network. The coupling strengths of the linked
+        nodes are updated in place.
+    edges :
+        Edges of the neural network.
+    kind :
+        The kind of coupling can be `"value"` or `"volatility"`.
+    parent_idxs :
+        The index(es) of the parent node(s). Required.
+    children_idxs :
+        The index(es) of the children node(s). Required.
+    coupling_strengths :
+        The coupling strength between the parents and children.
+
+    Returns
+    -------
+    attributes :
+        The attributes of the neural network, with the new coupling strengths.
+    edges :
+        A new tuple of edges including the new links.
+
+    Raises
+    ------
+    ValueError
+        If `kind` is not `"value"` or `"volatility"`.
+    AssertionError
+        If `parent_idxs`, `children_idxs` or `coupling_strengths` is missing or
+        cannot be interpreted as a list or a tuple.
+
+    """
+    if kind not in ["value", "volatility"]:
+        raise ValueError(
+            f"The kind of coupling should be value or volatility, got {kind}"
+        )
+
+    # assertions are kept so invalid arguments raise the same exception as before
+    if isinstance(children_idxs, int):
+        children_idxs = [children_idxs]
+    assert isinstance(
+        children_idxs, (list, tuple)
+    ), "children_idxs should be an int, a list or a tuple"
+
+    if isinstance(parent_idxs, int):
+        parent_idxs = [parent_idxs]
+    assert isinstance(
+        parent_idxs, (list, tuple)
+    ), "parent_idxs should be an int, a list or a tuple"
+
+    if isinstance(coupling_strengths, int):
+        coupling_strengths = [float(coupling_strengths)]
+    if isinstance(coupling_strengths, float):
+        coupling_strengths = [coupling_strengths]
+    assert isinstance(
+        coupling_strengths, (list, tuple)
+    ), "coupling_strengths should be a number, a list or a tuple"
+
+    edges_as_list = list(edges)
+
+    # update the parent nodes: each parent gains all the children
+    for parent_idx in parent_idxs:
+        edges_as_list[parent_idx] = _link_node(
+            attributes=attributes,
+            node_edges=edges_as_list[parent_idx],
+            node_idx=parent_idx,
+            kind=kind,
+            role="children",
+            linked_idxs=children_idxs,
+            coupling_strengths=coupling_strengths,
+        )
+
+    # update the children nodes: each child gains all the parents
+    for children_idx in children_idxs:
+        edges_as_list[children_idx] = _link_node(
+            attributes=attributes,
+            node_edges=edges_as_list[children_idx],
+            node_idx=children_idx,
+            kind=kind,
+            role="parents",
+            linked_idxs=parent_idxs,
+            coupling_strengths=coupling_strengths,
+        )
+
+    # convert the list back to a tuple
+    edges = tuple(edges_as_list)
+
+    return attributes, edges
diff --git a/tests/test_math.py b/tests/test_math.py
new file mode 100644
index 000000000..146498b0c
--- /dev/null
+++ b/tests/test_math.py
@@ -0,0 +1,56 @@
+# Author: Nicolas Legrand
+
+import unittest
+from unittest import TestCase
+
+import jax.numpy as jnp
+
+from pyhgf.math import (
+ MultivariateNormal,
+ Normal,
+ binary_surprise_finite_precision,
+ gaussian_predictive_distribution,
+)
+
+
+class TestMath(TestCase):
+    """Regression tests for the distributions and surprise functions."""
+
+    def test_multivariate_normal(self):
+        """Check the sufficient statistics and base measure in two dimensions."""
+        ss = MultivariateNormal.sufficient_statistics(jnp.array([1.0, 2.0]))
+        assert jnp.isclose(
+            ss, jnp.array([1.0, 2.0, 1.0, 2.0, 4.0], dtype="float32")
+        ).all()
+
+        # floats are compared with a tolerance: exact equality depends on the
+        # dtype and the backend used for the computation
+        bm = MultivariateNormal.base_measure(2)
+        assert jnp.isclose(bm, 0.15915494309189535)
+
+    def test_normal(self):
+        """Check the statistics and base measure of the univariate normal."""
+        ss = Normal.sufficient_statistics(jnp.array(1.0))
+        assert jnp.isclose(ss, jnp.array([1.0, 1.0], dtype="float32")).all()
+
+        bm = Normal.base_measure()
+        assert jnp.isclose(bm, 0.3989423)
+
+        ess = Normal.expected_sufficient_statistics(mu=0.0, sigma=1.0)
+        assert jnp.isclose(ess, jnp.array([0.0, 1.0], dtype="float32")).all()
+
+    def test_gaussian_predictive_distribution(self):
+        """Check the predictive density against a reference value."""
+        pdf = gaussian_predictive_distribution(x=1.5, xi=[0.0, 1 / 8], nu=5.0)
+        assert jnp.isclose(pdf, jnp.array(0.00845728, dtype="float32"))
+
+    def test_binary_surprise_finite_precision(self):
+        """Check the finite-precision binary surprise against a reference value."""
+        surprise = binary_surprise_finite_precision(
+            value=1.0,
+            expected_mean=0.0,
+            expected_precision=1.0,
+            eta0=0.0,
+            eta1=1.0,
+        )
+        assert jnp.isclose(surprise, 1.4189385)
+
+
+if __name__ == "__main__":
+ unittest.main(argv=["first-arg-is-ignored"], exit=False)
diff --git a/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py b/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py
new file mode 100644
index 000000000..f8d4f63cb
--- /dev/null
+++ b/tests/test_updates/prediction_errors/inputs/test_prediction_errors.py
@@ -0,0 +1,32 @@
+# Author: Nicolas Legrand
+
+import unittest
+from unittest import TestCase
+
+from pyhgf.model import Network
+from pyhgf.updates.prediction_error.inputs.generic import generic_input_prediction_error
+
+
+class TestPredictionErrors(TestCase):
+    def test_generic_input(self):
+        """Run the generic input prediction error on a minimal network."""
+
+        # a generic input node (index 0) with a single value parent (index 1)
+        network = Network()
+        network = network.add_nodes(kind="generic-input")
+        network = network.add_nodes(value_children=0)
+
+        # only the attributes and the edges are needed by the update function
+        attributes, structure, _ = network.get_network()
+        _, edges = structure
+
+        # smoke test: the call should complete, the result is not inspected
+        attributes = generic_input_prediction_error(
+            attributes=attributes,
+            time_step=1.0,
+            edges=edges,
+            node_idx=0,
+            value=10.0,
+            observed=True,
+        )
+
+
+if __name__ == "__main__":
+ unittest.main(argv=["first-arg-is-ignored"], exit=False)
diff --git a/tests/test_updates/prediction_errors/nodes/test_dirichlet.py b/tests/test_updates/prediction_errors/nodes/test_dirichlet.py
new file mode 100644
index 000000000..4f7b1d655
--- /dev/null
+++ b/tests/test_updates/prediction_errors/nodes/test_dirichlet.py
@@ -0,0 +1,51 @@
+# Author: Nicolas Legrand
+
+import unittest
+from unittest import TestCase
+
+import jax.numpy as jnp
+
+from pyhgf.model import Network
+from pyhgf.updates.prediction_error.nodes.dirichlet import (
+ dirichlet_node_prediction_error,
+ get_candidate,
+)
+
+
+class TestDirichletNode(TestCase):
+    def test_get_candidate(self):
+        """Check the candidate returned for an observation and two clusters."""
+        cluster_means = jnp.array([0.0, -5.0])
+        cluster_sigmas = jnp.array([1.0, 3.0])
+
+        new_mean, new_precision = get_candidate(
+            value=5.0,
+            sensory_precision=1.0,
+            expected_mean=cluster_means,
+            expected_sigma=cluster_sigmas,
+        )
+
+        assert jnp.isclose(new_mean, 5.026636)
+        assert jnp.isclose(new_precision, 1.2752448)
+
+    def test_dirichlet_node_prediction_error(self):
+        """Run the Dirichlet node prediction error on a small network."""
+
+        # a generic input (0), a Dirichlet process node (1) and two normal
+        # exponential family nodes as value parents of the Dirichlet node
+        network = Network()
+        network = network.add_nodes(kind="generic-input")
+        network = network.add_nodes(kind="DP-state", value_children=0)
+        network = network.add_nodes(
+            kind="ef-normal",
+            n_nodes=2,
+            value_children=1,
+            xis=jnp.array([0.0, 1 / 8]),
+            nus=15.0,
+        )
+
+        attributes, structure, _ = network.get_network()
+        _, edges = structure
+
+        # smoke test: the call should complete, the result is not inspected
+        dirichlet_node_prediction_error(
+            edges=edges,
+            attributes=attributes,
+            node_idx=1,
+        )
+
+
+if __name__ == "__main__":
+ unittest.main(argv=["first-arg-is-ignored"], exit=False)
diff --git a/tests/test_networks.py b/tests/test_utils.py
similarity index 78%
rename from tests/test_networks.py
rename to tests/test_utils.py
index f37801dd2..6d99e7434 100644
--- a/tests/test_networks.py
+++ b/tests/test_utils.py
@@ -5,6 +5,7 @@
import jax.numpy as jnp
+from pyhgf.model import Network
from pyhgf.typing import AdjacencyLists, Inputs
from pyhgf.updates.posterior.continuous import (
continuous_node_update,
@@ -16,7 +17,7 @@
from pyhgf.utils import beliefs_propagation, list_branches
-class TestNetworks(TestCase):
+class TestUtils(TestCase):
def test_beliefs_propagation(self):
"""Test the loop_inputs function"""
@@ -109,7 +110,7 @@ def test_beliefs_propagation(self):
assert new_attributes[2]["precision"] == 1.5
def test_find_branch(self):
- """Test the find_branch function"""
+ """Test the find_branch function."""
edges = (
AdjacencyLists(0, (1,), None, None, None),
AdjacencyLists(2, None, (2,), (0,), None),
@@ -120,31 +121,37 @@ def test_find_branch(self):
branch_list = list_branches([0], edges, branch_list=[])
assert branch_list == [0, 1, 2]
- def test_trim_sequence(self):
- """Test the trim_sequence function"""
- # TODO: need to rewrite the trim sequence method
- # edges = (
- # Indexes((1,), None, None, None),
- # Indexes(None, (2,), (0,), None),
- # Indexes(None, None, None, (1,)),
- # Indexes((4,), None, None, None),
- # Indexes(None, None, (3,), None),
- # )
- # update_sequence = (
- # (0, continuous_input_prediction_error),
- # (1, continuous_node_prediction_error),
- # (2, continuous_node_prediction_error),
- # (3, continuous_node_prediction_error),
- # (4, continuous_node_prediction_error),
- # )
- # new_sequence = trim_sequence(
- # exclude_node_idxs=[0],
- # update_sequence=update_sequence,
- # edges=edges,
- # )
- # assert len(new_sequence) == 2
- # assert new_sequence[0][0] == 3
- # assert new_sequence[1][0] == 4
+    def test_set_update_sequence(self):
+        """Check the length of the update sequence built for three networks."""
+
+        # binary HGF: a binary input, its binary state and one value parent
+        binary_hgf = Network()
+        binary_hgf = binary_hgf.add_nodes(kind="binary-input")
+        binary_hgf = binary_hgf.add_nodes(kind="binary-state", value_children=0)
+        binary_hgf = binary_hgf.add_nodes(value_children=1)
+        binary_hgf = binary_hgf.set_update_sequence()
+        assert len(binary_hgf.update_sequence) == 6
+
+        # continuous HGF: a continuous input, a value parent and a volatility parent
+        continuous_hgf = Network()
+        continuous_hgf = continuous_hgf.add_nodes(kind="continuous-input")
+        continuous_hgf = continuous_hgf.add_nodes(value_children=0)
+        continuous_hgf = continuous_hgf.add_nodes(volatility_children=1)
+        continuous_hgf = continuous_hgf.set_update_sequence(update_type="standard")
+        assert len(continuous_hgf.update_sequence) == 6
+
+        # a generic input followed by a normal exponential family node
+        generic_network = Network()
+        generic_network = generic_network.add_nodes(kind="generic-input")
+        generic_network = generic_network.add_nodes(kind="ef-normal")
+        generic_network = generic_network.set_update_sequence()
+        assert len(generic_network.update_sequence) == 2
if __name__ == "__main__":