Skip to content

Commit

Permalink
test model
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 26, 2024
1 parent 1da3dd3 commit 407d299
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 74 deletions.
40 changes: 20 additions & 20 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

78 changes: 33 additions & 45 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,6 @@ def create_belief_propagation_fn(

return self

def cache_belief_propagation_fn(self) -> "Network":
"""Blank call to the belief propagation function.
.. note:
This step is called by default when using py:meth:`input_data`. It can
sometimes be convenient to call this step independently to chache the JITed
function before fitting the model.
"""
if self.scan_fn is None:
self = self.create_belief_propagation_fn()

# blanck call to cache the JIT-ed functions
_ = scan(
self.scan_fn,
self.attributes,
(
jnp.ones((1, len(self.input_idxs))),
jnp.ones((1, 1)),
jnp.ones((1, 1)),
),
)

return self

def input_data(
self,
input_data: Union[np.ndarray, tuple],
Expand Down Expand Up @@ -199,7 +174,11 @@ def input_data(

# Interleave observations and masks
input_data = tuple(
[item for pair in zip(observations, observed) for item in pair]
[
item.flatten()
for pair in zip(observations, observed)
for item in pair
]
)

# time steps vector
Expand All @@ -226,6 +205,7 @@ def input_custom_sequence(
input_data: np.ndarray,
time_steps: Optional[np.ndarray] = None,
observed: Optional[np.ndarray] = None,
input_idxs: Optional[Tuple[int]] = None,
):
"""Add new observations with custom update sequences.
Expand Down Expand Up @@ -263,52 +243,60 @@ def input_custom_sequence(
Missing inputs are missing observations from the agent's perspective and
should not be used to handle missing data points that were observed (e.g.
missing in the event log, or rejected trials).
input_idxs :
Indexes on the state nodes receiving observations.
"""
if time_steps is None:
time_steps = np.ones(len(input_data)) # time steps vector
# set the input nodes indexes
if input_idxs is not None:
self.input_idxs = input_idxs

# concatenate data and time
if time_steps is None:
time_steps = np.ones((len(input_data), 1)) # time steps vector
else:
time_steps = time_steps[..., jnp.newaxis]
if input_data.ndim == 1:
input_data = input_data[..., jnp.newaxis]
# input_data should be a tuple of n by time_steps arrays
if not isinstance(input_data, tuple):
if observed is None:
observed = np.ones(input_data.shape, dtype=int)
if input_data.ndim == 1:

# Interleave observations and masks
input_data = (input_data, observed)
else:
observed = jnp.hsplit(observed, input_data.shape[1])
observations = jnp.hsplit(input_data, input_data.shape[1])

# is it observation or missing inputs
if observed is None:
observed = np.ones(input_data.shape, dtype=int)
# Interleave observations and masks
input_data = tuple(
[item for pair in zip(observations, observed) for item in pair]
)

# time steps vector
if time_steps is None:
time_steps = np.ones(input_data[0].shape[0])

# create the update functions that will be scanned
branches_fn = [
Partial(
beliefs_propagation,
update_sequence=seq,
edges=self.edges,
input_idxs=self.input_idxs,
)
for seq in update_branches
]

# create the function that will be scanned
def switching_propagation(attributes, scan_input):
data, idx = scan_input
(*data, idx) = scan_input
return switch(idx, branches_fn, attributes, data)

# wrap the inputs
scan_input = (input_data, time_steps, observed), branches_idx
scan_input = (*input_data, time_steps, branches_idx)

# scan over the input data and apply the switching belief propagation functions
_, node_trajectories = scan(switching_propagation, self.attributes, scan_input)

# the node structure at each value updates
self.node_trajectories = node_trajectories

# because some of the input nodes might not have been updated, here we manually
# insert the input data to the input node (without triggering updates)
for idx, inp in zip(self.input_idxs, range(input_data.shape[1])):
self.node_trajectories[idx]["values"] = input_data[inp]

return self

def get_network(self) -> NetworkParameters:
Expand Down
4 changes: 2 additions & 2 deletions src/pyhgf/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ def total_gaussian_surprise(
"""
# compute the sum of Gaussian surprise across every node
surprise = jnp.zeros(len(hgf.node_trajectories[0]["values"]))
surprise = jnp.zeros(len(hgf.node_trajectories[0]["mean"]))

# first we start with nodes that are value parents to input nodes
input_parents_list = []
for idx in hgf.input_idxs:
va_pa = hgf.edges[idx].value_parents[0] # type: ignore
input_parents_list.append(va_pa)
surprise += gaussian_surprise(
x=hgf.node_trajectories[idx]["values"],
x=hgf.node_trajectories[idx]["mean"],
expected_mean=hgf.node_trajectories[va_pa]["expected_mean"],
expected_precision=hgf.node_trajectories[va_pa]["expected_precision"],
)
Expand Down
24 changes: 17 additions & 7 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@
from pyhgf.response import total_gaussian_surprise


def test_HGF():
"""Test the model class"""
def test_network():
"""Test the network class"""

#####################
# Creating networks #
#####################

custom_hgf = (
Network()
.add_nodes(kind="continuous-input")
.add_nodes(kind="binary-input")
.add_nodes(kind="continuous-state")
.add_nodes(kind="binary-state")
.add_nodes(value_children=0)
.add_nodes(
kind="binary-state",
value_children=1,
)
.add_nodes(value_children=[2, 3])
Expand All @@ -31,12 +30,14 @@ def test_HGF():
.add_nodes(volatility_children=7)
)

custom_hgf.cache_belief_propagation_fn()
custom_hgf.create_belief_propagation_fn(overwrite=False)
custom_hgf.create_belief_propagation_fn(overwrite=True)

custom_hgf.input_data(input_data=np.array([0.2, 1]))
custom_hgf.input_data(input_data=np.ones((10, 2)))


def test_continuous_hgf():
"""Test the continuous HGF"""
##############
# Continuous #
##############
Expand Down Expand Up @@ -79,6 +80,10 @@ def test_HGF():
sp = total_gaussian_surprise(three_level_continuous_hgf)
assert jnp.isclose(sp.sum(), 1159.1089)


def test_binary_hgf():
"""Test the binary HGF"""

##########
# Binary #
##########
Expand Down Expand Up @@ -122,9 +127,14 @@ def test_HGF():
surprise = three_level_binary_hgf.surprise()
assert jnp.isclose(surprise.sum(), 215.59067)


def test_custom_sequence():
"""Test the continuous HGF"""

############################
# dynamic update sequences #
############################
u, _ = load_data("binary")

three_level_binary_hgf = HGF(
n_levels=3,
Expand Down

0 comments on commit 407d299

Please sign in to comment.