Skip to content

Commit

Permalink
test categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 26, 2024
1 parent f374d8d commit 1da3dd3
Show file tree
Hide file tree
Showing 12 changed files with 536 additions and 513 deletions.
143 changes: 87 additions & 56 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

472 changes: 149 additions & 323 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

106 changes: 67 additions & 39 deletions docs/source/notebooks/1.3-Continuous_HGF.ipynb

Large diffs are not rendered by default.

78 changes: 67 additions & 11 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pyhgf.utils import (
add_edges,
beliefs_propagation,
fill_categorical_state_node,
get_input_idxs,
get_update_sequence,
to_pandas,
Expand Down Expand Up @@ -146,9 +147,9 @@ def cache_belief_propagation_fn(self) -> "Network":

def input_data(
self,
input_data: np.ndarray,
input_data: Union[np.ndarray, tuple],
time_steps: Optional[np.ndarray] = None,
observed: Optional[np.ndarray] = None,
observed: Optional[Union[np.ndarray, tuple]] = None,
input_idxs: Optional[Tuple[int]] = None,
):
"""Add new observations.
Expand Down Expand Up @@ -184,21 +185,32 @@ def input_data(
if self.scan_fn is None:
self = self.create_belief_propagation_fn()

# input_data should be a tuple of n by time_steps arrays
if not isinstance(input_data, tuple):
if observed is None:
observed = np.ones(input_data.shape, dtype=int)
if input_data.ndim == 1:

# Interleave observations and masks
input_data = (input_data, observed)
else:
observed = jnp.hsplit(observed, input_data.shape[1])
observations = jnp.hsplit(input_data, input_data.shape[1])

# Interleave observations and masks
input_data = tuple(
[item for pair in zip(observations, observed) for item in pair]
)

# time steps vector
if time_steps is None:
time_steps = np.ones(input_data.shape[0])
if input_data.ndim == 1:
input_data = input_data[..., jnp.newaxis]

# is it an observation or a missing input
if observed is None:
observed = np.ones(input_data.shape, dtype=int)
time_steps = np.ones(input_data[0].shape[0])

# this is where the model loops over the whole input time series
# at each time point, the node structure is traversed and beliefs are updated
# using precision-weighted prediction errors
last_attributes, node_trajectories = scan(
self.scan_fn, self.attributes, (input_data, time_steps, observed)
self.scan_fn, self.attributes, (*input_data, time_steps)
)

# belief trajectories
Expand Down Expand Up @@ -397,6 +409,7 @@ def add_nodes(
if kind not in [
"DP-state",
"ef-normal",
"categorical-state",
"continuous-state",
"binary-state",
"generic-state",
Expand All @@ -405,7 +418,7 @@ def add_nodes(
(
"Invalid node type. Should be one of the following: "
"'DP-state', 'continuous-state', 'binary-state', 'ef-normal'."
"'generic-state'"
"'generic-state' or 'categorical-state'"
)
)

Expand Down Expand Up @@ -489,6 +502,37 @@ def add_nodes(
"mean": 0.0,
"observed": 1.0,
}
elif kind == "categorical-state":
if "n_categories" in node_parameters:
n_categories = node_parameters["n_categories"]
elif "n_categories" in additional_parameters:
n_categories = additional_parameters["n_categories"]
else:
n_categories = 4
binary_parameters = {
"n_categories": n_categories,
"precision_1": 1.0,
"precision_2": 1.0,
"precision_3": 1.0,
"mean_1": 1 / n_categories,
"mean_2": -jnp.log(n_categories - 1),
"mean_3": 0.0,
"tonic_volatility_2": -4.0,
"tonic_volatility_3": -4.0,
}
binary_idxs: List[int] = [
1 + i + len(self.edges) for i in range(n_categories)
]
default_parameters = {
"binary_idxs": binary_idxs, # type: ignore
"n_categories": n_categories,
"surprise": 0.0,
"kl_divergence": 0.0,
"alpha": jnp.ones(n_categories),
"observed": jnp.ones(n_categories, dtype=int),
"mean": jnp.array([1.0 / n_categories] * n_categories),
"binary_parameters": binary_parameters,
}
elif kind == "DP-state":

if "batch_size" in additional_parameters.keys():
Expand Down Expand Up @@ -549,6 +593,8 @@ def add_nodes(
node_type = 3
elif kind == "DP-state":
node_type = 4
elif kind == "categorical-state":
node_type = 5

for _ in range(n_nodes):
# convert the structure to a list to modify it
Expand Down Expand Up @@ -610,6 +656,16 @@ def add_nodes(
coupling_strengths=volatility_parents[1], # type: ignore
)

if kind == "categorical-state":
# if we are creating a categorical state or state-transition node
# we have to generate the implied binary network(s) here
self = fill_categorical_state_node(
self,
node_idx=node_idx,
binary_states_idxs=node_parameters["binary_idxs"], # type: ignore
binary_parameters=binary_parameters,
)

return self

def plot_nodes(self, node_idxs: Union[int, List[int]], **kwargs):
Expand Down
16 changes: 13 additions & 3 deletions src/pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def plot_correlations(network: "Network") -> Axes:
)
ax.set_title("Correlations between the model trajectories")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", size=8)
ax.set_yticklabels(ax.get_xticklabels(), size=8)
ax.set_yticklabels(ax.get_yticklabels(), size=8)

return ax

Expand All @@ -250,7 +250,7 @@ def plot_network(network: "Network") -> "Source":
except ImportError:
print(
(
"Graphviz is required to plot the nodes structure. "
"Graphviz is required to plot networks. "
"See https://pypi.org/project/graphviz/"
)
)
Expand Down Expand Up @@ -287,7 +287,7 @@ def plot_network(network: "Network") -> "Source":
)

elif network.edges[idx].node_type == 4:
# Dirichlet PRocess state node
# Dirichlet Process state node
graphviz_structure.node(
f"x_{idx}",
label=f"DP-{idx}",
Expand All @@ -296,6 +296,16 @@ def plot_network(network: "Network") -> "Source":
fillcolor="#e2d8c1",
)

elif network.edges[idx].node_type == 5:
# Categorical state node
graphviz_structure.node(
f"x_{idx}",
label=f"Ca-{idx}",
style=style,
shape="diamond",
fillcolor="#e2d8c1",
)

# connect value parents
for i, index in enumerate(network.edges):
value_parents = index.value_parents
Expand Down
6 changes: 3 additions & 3 deletions src/pyhgf/updates/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def set_observation(
attributes: Dict,
node_idx: int,
value: float,
values: float,
observed: int,
) -> Dict:
r"""Add observations to the target node by setting the posterior to a given value.
Expand All @@ -21,7 +21,7 @@ def set_observation(
The attributes of the probabilistic network.
node_idx :
Pointer to the input node.
value :
values :
The new observed value.
observed :
Whether value was observed or not.
Expand All @@ -32,7 +32,7 @@ def set_observation(
The attributes of the probabilistic network.
"""
attributes[node_idx]["mean"] = value
attributes[node_idx]["mean"] = values
attributes[node_idx]["observed"] = observed

return attributes
61 changes: 27 additions & 34 deletions src/pyhgf/updates/posterior/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


@partial(jit, static_argnames=("edges", "node_idx"))
def categorical_input_update(
attributes: Dict, time_step: float, node_idx: int, edges: Edges, **args
def categorical_state_update(
attributes: Dict, node_idx: int, edges: Edges, **args
) -> Dict:
"""Update the categorical input node given an array of binary observations.
Expand All @@ -29,8 +29,6 @@ def categorical_input_update(
`"psis"` is the value coupling strength. It should have the same length as the
volatility parents' indexes. `"volatility_coupling"` is the volatility coupling
strength. It should have the same length as the volatility parents' indexes.
time_step :
The interval between the previous time point and the current time point.
node_idx :
Pointer to the node that needs to be updated.
edges :
Expand All @@ -49,56 +47,51 @@ def categorical_input_update(
binary_input_update, continuous_input_update
"""
# get the expected values at time k from the binary inputs (X_1)
new_xi = jnp.array(
# get the expected values before the update
expected_mean = jnp.array(
[
attributes[edges[vapa].value_parents[0]]["expected_mean"]
for vapa in edges[node_idx].value_parents # type: ignore
attributes[value_parent_idx]["expected_mean"]
for value_parent_idx in edges[node_idx].value_parents # type: ignore
]
)

# the differential of expectations (parents predictions at time k and k-1)
delta_xi = new_xi - attributes[node_idx]["xi"]

# using the PE for the previous time point, we can compute nu and the alpha vector
pe = attributes[node_idx]["pe"]
nu = (pe / delta_xi) - 1
alpha = (nu * new_xi) + 1 # concentration parameters for the Dirichlet

# in case alpha contains NaNs (e.g. after the first time step,
# due to the absence of belief evolution)
alpha = jnp.where(jnp.isnan(alpha), 1.0, alpha)

# now retrieve the values observed at time k
attributes[node_idx]["values"] = jnp.array(
# get the new values after the update from the continuous state
updated_mean = jnp.array(
[
attributes[vapa]["values"]
for vapa in edges[node_idx].value_parents # type: ignore
attributes[edges[parent_idx].value_parents[0]]["mean"]
for parent_idx in edges[node_idx].value_parents # type: ignore
]
)
updated_mean = 1 / (1 + jnp.exp(-updated_mean)) # logit transform

# compute the prediction error (observed - expected) at time K
pe = attributes[node_idx]["mean"] - expected_mean

# compute the prediction error at time K
pe = attributes[node_idx]["values"] - new_xi
attributes[node_idx]["pe"] = pe # keep PE for later use at k+1
attributes[node_idx]["xi"] = new_xi # keep expectation for later use at k+1
# the differential of expectations (parent posterior - parents expectation)
delta_xi = updated_mean - expected_mean

# using the new PE, we can update nu and the alpha vector
nu = (pe / delta_xi) - 1
alpha = (nu * expected_mean) + 1 # concentration parameters for the Dirichlet

# in case alpha contains NaNs
# alpha = jnp.where(jnp.isnan(alpha), 1.0, alpha)

# compute Bayesian surprise as :
# 1 - KL divergence from the concentration parameters
# 2 - the sum of binary surprises observed in the parents nodes
attributes[node_idx]["kl_divergence"] = dirichlet_kullback_leibler(
attributes[node_idx]["alpha"], alpha
)
# 2 - the sum of binary surprises observed in the parents nodes
attributes[node_idx]["surprise"] = jnp.sum(
binary_surprise(
x=attributes[node_idx]["values"], expected_mean=attributes[node_idx]["xi"]
)
binary_surprise(x=attributes[node_idx]["mean"], expected_mean=expected_mean)
)

# save the new concentration parameters
attributes[node_idx]["alpha"] = alpha
attributes[node_idx]["mean"] = updated_mean

# prediction mean
attributes[node_idx]["mean"] = alpha / jnp.sum(alpha)
attributes[node_idx]["time_step"] = time_step
# attributes[node_idx]["mean"] = alpha / jnp.sum(alpha)

return attributes
54 changes: 54 additions & 0 deletions src/pyhgf/updates/prediction_error/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict

from jax import jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def categorical_state_prediction_error(
attributes: Dict, node_idx: int, edges: Edges, **args
) -> Dict:
"""Prediction error from a categorical state node.
The update will pass the input observations to the binary state nodes.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
.. note::
`"psis"` is the value coupling strength. It should have the same length as the
volatility parents' indexes. `"volatility_coupling"` is the volatility coupling
strength. It should have the same length as the volatility parents' indexes.
node_idx :
Pointer to the node that needs to be updated.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
number. For each node, the index lists the value and volatility parents and
children.
Returns
-------
attributes :
The updated parameters structure.
See Also
--------
binary_input_update, continuous_input_update
"""
# pass the mean to the binary state nodes
for mean, observed, value_parent_idx in zip(
attributes[node_idx]["mean"],
attributes[node_idx]["observed"],
edges[node_idx].value_parents, # type: ignore
):
attributes[value_parent_idx]["mean"] = mean
attributes[value_parent_idx]["observed"] = observed

return attributes
Loading

0 comments on commit 1da3dd3

Please sign in to comment.