diff --git a/pyhgf/updates/posterior/continuous/__init__.py b/pyhgf/updates/posterior/continuous/__init__.py index fd874075..a87285b6 100644 --- a/pyhgf/updates/posterior/continuous/__init__.py +++ b/pyhgf/updates/posterior/continuous/__init__.py @@ -3,6 +3,5 @@ __all__ = [ "continuous_node_posterior_update_ehgf", - "continuous_node_posterior_update_unbounded", "continuous_node_posterior_update", ] diff --git a/tests/test_updates/posterior/continuous.py b/tests/test_updates/posterior/continuous.py index 6bc59f39..6982a63c 100644 --- a/tests/test_updates/posterior/continuous.py +++ b/tests/test_updates/posterior/continuous.py @@ -1,10 +1,11 @@ # Author: Nicolas Legrand +import jax.numpy as jnp + from pyhgf.model import Network from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, continuous_node_posterior_update_ehgf, - continuous_node_posterior_update_unbounded, ) @@ -20,17 +21,56 @@ def test_continuous_posterior_updates(): # Standard HGF updates ------------------------------------------------------------- # ---------------------------------------------------------------------------------- + + # value update + attributes, edges, _ = network.get_network() + attributes[0]["temp"]["value_prediction_error"] = 1.0357 + attributes[0]["mean"] = 1.0357 + + new_attributes = continuous_node_posterior_update( + attributes=attributes, node_idx=1, edges=edges + ) + assert jnp.isclose(new_attributes[1]["mean"], 0.51785) + + # volatility update attributes, edges, _ = network.get_network() - _ = continuous_node_posterior_update(attributes=attributes, node_idx=2, edges=edges) + attributes[1]["temp"]["effective_precision"] = 0.01798621006309986 + attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467 + attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894 + attributes[1]["expected_precision"] = 0.9820137619972229 + attributes[1]["mean"] = 0.5225493907928467 + attributes[1]["precision"] = 1.9820137023925781 + + new_attributes = continuous_node_posterior_update( + attributes=attributes, node_idx=2, edges=edges + ) + assert jnp.isclose(new_attributes[1]["mean"], -0.0021212) + assert jnp.isclose(new_attributes[1]["precision"], 1.0022112) # eHGF updates --------------------------------------------------------------------- # ---------------------------------------------------------------------------------- - _ = continuous_node_posterior_update_ehgf( + + # value update + attributes, edges, _ = network.get_network() + attributes[0]["temp"]["value_prediction_error"] = 1.0357 + attributes[0]["mean"] = 1.0357 + + new_attributes = continuous_node_posterior_update_ehgf( attributes=attributes, node_idx=2, edges=edges ) + assert jnp.isclose(new_attributes[1]["mean"], 0.51785) - # unbounded updates ---------------------------------------------------------------- - # ---------------------------------------------------------------------------------- - _ = continuous_node_posterior_update_unbounded( + # volatility update + attributes, edges, _ = network.get_network() + attributes[1]["temp"]["effective_precision"] = 0.01798621006309986 + attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467 + attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894 + attributes[1]["expected_precision"] = 0.9820137619972229 + attributes[1]["mean"] = 0.5225493907928467 + attributes[1]["precision"] = 1.9820137023925781 + + new_attributes = continuous_node_posterior_update_ehgf( attributes=attributes, node_idx=2, edges=edges ) + assert jnp.isclose(new_attributes[1]["mean"], -0.00212589) + assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)