diff --git a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py index 5ed7041d..e35702e1 100644 --- a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py +++ b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py @@ -3,17 +3,11 @@ from functools import partial from typing import Dict +import jax.numpy as jnp from jax import jit from pyhgf.typing import Edges -from .posterior_update_mean_continuous_node_unbounded import ( - posterior_update_mean_continuous_node_unbounded, -) -from .posterior_update_precision_continuous_node_unbounded import ( - posterior_update_precision_continuous_node_unbounded, -) - @partial(jit, static_argnames=("edges", "node_idx")) def continuous_node_posterior_update_unbounded( @@ -45,12 +39,10 @@ def continuous_node_posterior_update_unbounded( """ # update the posterior mean and precision using the eHGF update step # we start with the mean update using the expected precision as an approximation - posterior_precision, precision_l1, precision_l2 = ( - posterior_update_precision_continuous_node_unbounded( - attributes, - edges, - node_idx, - ) + posterior_precision = posterior_update_precision_continuous_node_unbounded( + attributes=attributes, + edges=edges, + node_idx=node_idx, ) attributes[node_idx]["precision"] = posterior_precision @@ -58,9 +50,128 @@ def continuous_node_posterior_update_unbounded( attributes=attributes, edges=edges, node_idx=node_idx, - precision_l1=precision_l1, - precision_l2=precision_l2, ) attributes[node_idx]["mean"] = posterior_mean return attributes + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_mean_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, +) -> float: + """Posterior update of mean using unbounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] # 
type: ignore + # volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # previous child uncertainty + alpha = 1 / attributes[volatility_child_idx]["expected_precision"] + + # posterior total uncertainty about the child + beta = ( + 1 / attributes[volatility_child_idx]["expected_precision"] + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + + return mu_l(alpha, beta, gamma) + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_precision_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, +) -> float: + """Posterior update of precision using unbounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] # type: ignore + # volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # previous child uncertainty + alpha = 1 / attributes[volatility_child_idx]["expected_precision"] + + # posterior total uncertainty about the child + beta = ( + 1 / attributes[volatility_child_idx]["expected_precision"] + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + + return pi_l(alpha, beta, gamma) + + +def s(x, theta, psi): + return 1 / (1 + jnp.exp(-psi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) * (1 - s(x, theta_r, phi_r)) + + +def pi_l1(alpha, gamma): + return 0.5 * omega(alpha, gamma) * (1 - omega(alpha, gamma)) + 0.5 + + +def mu_l1(alpha, beta, gamma): + return gamma + 0.5 / pi_l1(alpha, gamma) * omega(alpha, gamma) * delta( + alpha, beta, gamma + ) + + +def omega(alpha, x): + return jnp.exp(x) / (alpha + jnp.exp(x)) + + +def delta(alpha, beta, x): + return beta / (alpha + jnp.exp(x)) - 1 + + +def phi(alpha): + return jnp.log(alpha * (2 + jnp.sqrt(3))) + + +def pi_l2(alpha, beta): + 
return -ddJ(phi(alpha), alpha, beta) + + +def dJ(x, alpha, beta, gamma): + return 0.5 * omega(alpha, x) * delta(alpha, beta, x) - 0.5 * (x - gamma) + + +def ddJ(x, alpha, beta): + return ( + -0.5 + * omega(alpha, x) + * (omega(alpha, x) + (2 * omega(alpha, x) - 1) * delta(alpha, beta, x)) + - 0.5 + ) + + +def mu_l2(alpha, beta, gamma): + return phi(alpha) - dJ(phi(alpha), alpha, beta, gamma) / ddJ( + phi(alpha), alpha, beta + ) + + +def mu_l(alpha, beta, gamma): + return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * mu_l1( + alpha, beta, gamma + ) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * mu_l2( + alpha, beta, gamma + ) + + +def pi_l(alpha, beta, gamma): + return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * pi_l1( + alpha, gamma + ) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * pi_l2(alpha, beta) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py deleted file mode 100644 index 0f4ed643..00000000 --- a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py +++ /dev/null @@ -1,112 +0,0 @@ -# Author: Nicolas Legrand - -from functools import partial -from typing import Dict - -import jax.numpy as jnp -from jax import jit - -from pyhgf.typing import Edges - - -@partial(jit, static_argnames=("edges", "node_idx")) -def posterior_update_mean_continuous_node_unbounded( - attributes: Dict, - edges: Edges, - node_idx: int, - precision_l1: float, - precision_l2: float, -) -> float: - """Posterior update of mean using ubounded update.""" - volatility_child_idx = edges[node_idx].volatility_children[0] - volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] - gamma = attributes[node_idx]["expected_mean"] - phi = jnp.log( - (1 / attributes[volatility_child_idx]["precision"]) * (2 + jnp.sqrt(3)) - ) - - # first 
approximation ------------------------------------------------------ - delta_l1 = ( - ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] ** 2 - ) - ) - / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp( - volatility_coupling * phi - + attributes[volatility_child_idx]["tonic_volatility"] - ) - ) - ) - 1 - mean_l1 = ( - attributes[node_idx]["expected_mean"] - + ( - (volatility_coupling * attributes[node_idx]["tonic_volatility"]) - / (2 * precision_l1) - ) - * delta_l1 - ) - - # second approximation ----------------------------------------------------- - omega_phi = jnp.exp( - volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] - ) / ( - (1 / attributes[volatility_child_idx]["precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - delta_phi = ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp( - volatility_coupling * phi - + attributes[volatility_child_idx]["tonic_volatility"] - ) - ) - 1 - - mu_phi = ((2 * precision_l2 - 1) * phi + attributes[node_idx]["expected_mean"]) / ( - 2 * precision_l2 - ) - - mean_l2 = ( - mu_phi + (volatility_coupling * omega_phi) / (2 * precision_l2) * delta_phi - ) - - # weigthed interpolation - theta_l = jnp.sqrt( - 1.2 - * ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) - / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) - ) - phi_l = 8.0 - theta_r = 0.0 - phi_r = 1.0 - mean = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * mean_l1 + b( - gamma, theta_l, phi_l, theta_r, phi_r - ) * mean_l2 
- - return mean - - -def s(x, theta, phi): - return 1 / (1 + jnp.exp(-phi * (x - theta))) - - -def b(x, theta_l, phi_l, theta_r, phi_r): - return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py deleted file mode 100644 index 87222d27..00000000 --- a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py +++ /dev/null @@ -1,84 +0,0 @@ -# Author: Nicolas Legrand - -from functools import partial -from typing import Dict - -import jax.numpy as jnp -from jax import jit - -from pyhgf.typing import Edges - - -@partial(jit, static_argnames=("edges", "node_idx")) -def posterior_update_precision_continuous_node_unbounded( - attributes: Dict, edges: Edges, node_idx: int -) -> float: - """Posterior update of precision using ubounded update.""" - volatility_child_idx = edges[node_idx].volatility_children[0] - volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] - gamma = attributes[node_idx]["expected_mean"] - - # first approximation ------------------------------------------------------ - precision_l1 = attributes[node_idx][ - "expected_precision" - ] + 0.5 * volatility_coupling**2 * attributes[node_idx]["tonic_volatility"] * ( - 1 - attributes[node_idx]["tonic_volatility"] - ) - - # second approximation ----------------------------------------------------- - phi = jnp.log( - (1 / attributes[volatility_child_idx]["expected_precision"]) * (2 + jnp.sqrt(3)) - ) - omega_phi = jnp.exp( - volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - delta_phi = ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] 
- - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - 1 - - precision_l2 = attributes[node_idx][ - "expected_precision" - ] + 0.5 * volatility_coupling**2 * omega_phi * ( - omega_phi + (2 * omega_phi - 1) * delta_phi - ) - - # weigthed interpolation - theta_l = jnp.sqrt( - 1.2 - * ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) - / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) - ) - phi_l = 8.0 - theta_r = 0.0 - phi_r = 1.0 - precision = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * precision_l1 + b( - gamma, theta_l, phi_l, theta_r, phi_r - ) * precision_l2 - - return precision, precision_l1, precision_l2 - - -def s(x, theta, phi): - return 1 / (1 + jnp.exp(-phi * (x - theta))) - - -def b(x, theta_l, phi_l, theta_r, phi_r): - return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/tests/test_updates/posterior/continuous.py b/tests/test_updates/posterior/continuous.py index 6bc59f39..5141ab4d 100644 --- a/tests/test_updates/posterior/continuous.py +++ b/tests/test_updates/posterior/continuous.py @@ -1,11 +1,25 @@ # Author: Nicolas Legrand +import jax.numpy as jnp + from pyhgf.model import Network from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, continuous_node_posterior_update_ehgf, continuous_node_posterior_update_unbounded, ) +from pyhgf.updates.posterior.continuous.continuous_node_posterior_update_unbounded import ( + b, + delta, + mu_l, + mu_l1, + mu_l2, + omega, + pi_l, + pi_l1, + pi_l2, + s, +) def test_continuous_posterior_updates(): @@ -34,3 +48,24 @@ def test_continuous_posterior_updates(): _ = continuous_node_posterior_update_unbounded( attributes=attributes, node_idx=2, edges=edges 
) + + +def test_unbounded_hgf_equations(): + + alpha = 1.0 + beta = 5.0 + gamma = 4.0 + + assert jnp.isclose(omega(alpha, gamma), 0.98201376) + assert jnp.isclose(delta(alpha, beta, gamma), -0.9100689) + + assert b(1.0, 1.0, 1.0, 1.0, 1.0) == 0.25 + assert s(1.0, 1.0, 1.0) == 0.5 + + assert jnp.isclose(pi_l1(alpha, gamma), 0.5088314) + assert jnp.isclose(pi_l2(alpha, beta), 0.82389593) + assert jnp.isclose(pi_l(alpha, beta, gamma), 0.51449823) + + assert jnp.isclose(mu_l1(alpha, beta, gamma), 3.1218112) + assert jnp.isclose(mu_l2(alpha, beta, gamma), 2.9723248) + assert jnp.isclose(mu_l(alpha, beta, gamma), 3.1191223)