Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 18, 2024
1 parent c2a1204 commit 5f690de
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
1 change: 0 additions & 1 deletion pyhgf/updates/posterior/continuous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@

__all__ = [
"continuous_node_posterior_update_ehgf",
"continuous_node_posterior_update_unbounded",
"continuous_node_posterior_update",
]
52 changes: 46 additions & 6 deletions tests/test_updates/posterior/continuous.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

import jax.numpy as jnp

from pyhgf.model import Network
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
continuous_node_posterior_update_unbounded,
)


Expand All @@ -20,17 +21,56 @@ def test_continuous_posterior_updates():

# Standard HGF updates -------------------------------------------------------------
# ----------------------------------------------------------------------------------

# value update
attributes, edges, _ = network.get_network()
attributes[0]["temp"]["value_prediction_error"] = 1.0357
attributes[0]["mean"] = 1.0357

new_attributes = continuous_node_posterior_update(
attributes=attributes, node_idx=1, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)

# volatility update
attributes, edges, _ = network.get_network()
_ = continuous_node_posterior_update(attributes=attributes, node_idx=2, edges=edges)
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
attributes[1]["expected_precision"] = 0.9820137619972229
attributes[1]["mean"] = 0.5225493907928467
attributes[1]["precision"] = 1.9820137023925781

new_attributes = continuous_node_posterior_update(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], -0.0021212)
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)

# eHGF updates ---------------------------------------------------------------------
# ----------------------------------------------------------------------------------
_ = continuous_node_posterior_update_ehgf(

# value update
attributes, edges, _ = network.get_network()
attributes[0]["temp"]["value_prediction_error"] = 1.0357
attributes[0]["mean"] = 1.0357

new_attributes = continuous_node_posterior_update_ehgf(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)

# unbounded updates ----------------------------------------------------------------
# ----------------------------------------------------------------------------------
_ = continuous_node_posterior_update_unbounded(
# volatility update
attributes, edges, _ = network.get_network()
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
attributes[1]["expected_precision"] = 0.9820137619972229
attributes[1]["mean"] = 0.5225493907928467
attributes[1]["precision"] = 1.9820137023925781

new_attributes = continuous_node_posterior_update_ehgf(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], -0.00212589)
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)

0 comments on commit 5f690de

Please sign in to comment.