Skip to content

Commit

Permalink
Non-linear tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
KoraTMontemagno authored and LegrandNico committed Aug 13, 2024
1 parent fc1b6f2 commit 6ca85aa
Showing 1 changed file with 150 additions and 0 deletions.
150 changes: 150 additions & 0 deletions tests/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jax.tree_util import Partial

from pyhgf import load_data
from pyhgf.model import Network
from pyhgf.math import gaussian_surprise
from pyhgf.typing import AdjacencyLists, Inputs
from pyhgf.updates.posterior.continuous import (
Expand Down Expand Up @@ -208,6 +209,62 @@ def identity(x):
):
assert jnp.isclose(new_attributes_nonlinear[2][idx], val)

def test_continuous_input_update_nonlinear(nodes_attributes):
###############################################
# one value parent with one volatility parent #
###############################################
attributes = nodes_attributes
def identity(x):
return(x)

# define a non linear network behaving like the linear one
edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 1, continuous_node_prediction
sequence2 = 2, continuous_node_prediction
sequence3 = 0, continuous_input_prediction_error
sequence4 = 1, continuous_node_update
sequence5 = 1, continuous_node_prediction_error
sequence6 = 2, continuous_node_update
update_sequence = (
sequence1,
sequence2,
sequence3,
sequence4,
sequence5,
sequence6,
)
data = jnp.array([0.2])
time_steps = jnp.ones(1)
observed = jnp.ones(1)
inputs = Inputs(0, 0)

# apply beliefs propagation updates
new_attributes_nonlinear, _ = beliefs_propagation(
structure=(inputs, edges_nonlinear),
attributes=attributes,
update_sequence=update_sequence,
input_data=(data, time_steps, observed),
)

for idx, val in zip(["time_step", "values"], [1.0, 0.2]):
assert jnp.isclose(new_attributes_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[10000.881, 0.880797, 0.20007047, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[0.9794834, 0.95257413, 0.97345114, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[2][idx], val)


def test_scan_loop(nodes_attributes):
timeserie = load_data("continuous")
Expand Down Expand Up @@ -299,6 +356,99 @@ def identity(x):
):
assert jnp.isclose(last_nonlinear[2][idx], val)

def test_scan_loop_nonlinear(nodes_attributes):
timeserie = load_data("continuous")

###############################################
# one value parent with one volatility parent #
###############################################
attributes = nodes_attributes

#defining an identity function
def identity(x):
return(x)

edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 2, continuous_node_prediction
sequence2 = 1, continuous_node_prediction
sequence3 = 0, continuous_input_prediction_error
sequence4 = 1, continuous_node_update
sequence5 = 1, continuous_node_prediction_error
sequence6 = 2, continuous_node_update_ehgf
update_sequence = (
sequence1,
sequence2,
sequence3,
sequence4,
sequence5,
sequence6,
)

inputs = Inputs(0, 1)

# create the function that will be scaned
scan_fn_nonlinear = Partial(
beliefs_propagation,
update_sequence=update_sequence,
structure=(inputs, edges_nonlinear),
)

# Create the data (value and time steps vectors)
time_steps = jnp.ones((len(timeserie), 1))
observed = jnp.ones((len(timeserie), 1))

# Run the entire for loop
last_nonlinear, _ = scan(scan_fn_nonlinear, attributes,
(timeserie, time_steps, observed))
for idx, val in zip(["time_step", "values"], [1.0, 0.8241]):
assert jnp.isclose(last_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[24557.84, 14557.839, 0.8041823, 0.79050046],
):
assert jnp.isclose(last_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[1.3334407, 1.3493799, -7.1686087, -7.509615],
):
assert jnp.isclose(last_nonlinear[2][idx], val)


def test_coupling_fn_multiple_children():
""" Tests if the coupling function is passed correctly
into the network"""

# creating a simple coupling function
def identity(x):
return x

# I create a network with a node with 2 value children
test_HGF = (
Network()
.add_nodes(kind="continuous-input")
.add_nodes(kind="continuous-input")
.add_nodes(value_children=0, n_nodes=1)
.add_nodes(value_children=[1, 2], n_nodes=1,
coupling_fn=(None,identity)
)
)

# check if the number of coupling fn matches the number of children
coupling_fn_length = []
children_number = []
for node_idx in range(2,len(test_HGF.edges)):
coupling_fn_length.append(len(test_HGF.edges[node_idx].coupling_fn))
children_number.append(len(test_HGF.edges[node_idx].value_children))

assert children_number == coupling_fn_length



if __name__ == "__main__":
unittest.main(argv=["first-arg-is-ignored"], exit=False)

0 comments on commit 6ca85aa

Please sign in to comment.