diff --git a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py index 8d4ec4d8..5ed7041d 100644 --- a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py +++ b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py @@ -7,9 +7,11 @@ from pyhgf.typing import Edges -from .posterior_update_mean_continuous_node import posterior_update_mean_continuous_node -from .posterior_update_precision_continuous_node import ( - posterior_update_precision_continuous_node, +from .posterior_update_mean_continuous_node_unbounded import ( + posterior_update_mean_continuous_node_unbounded, +) +from .posterior_update_precision_continuous_node_unbounded import ( + posterior_update_precision_continuous_node_unbounded, ) @@ -43,19 +45,22 @@ def continuous_node_posterior_update_unbounded( """ # update the posterior mean and precision using the eHGF update step # we start with the mean update using the expected precision as an approximation - posterior_mean = posterior_update_mean_continuous_node( - attributes, - edges, - node_idx, - node_precision=attributes[node_idx]["expected_precision"], + posterior_precision, precision_l1, precision_l2 = ( + posterior_update_precision_continuous_node_unbounded( + attributes, + edges, + node_idx, + ) ) - attributes[node_idx]["mean"] = posterior_mean + attributes[node_idx]["precision"] = posterior_precision - posterior_precision = posterior_update_precision_continuous_node( - attributes, - edges, - node_idx, + posterior_mean = posterior_update_mean_continuous_node_unbounded( + attributes=attributes, + edges=edges, + node_idx=node_idx, + precision_l1=precision_l1, + precision_l2=precision_l2, ) - attributes[node_idx]["precision"] = posterior_precision + attributes[node_idx]["mean"] = posterior_mean return attributes diff --git a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py new file mode 100644 index 00000000..0f4ed643 --- /dev/null +++ b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py @@ -0,0 +1,112 @@ +# Author: Nicolas Legrand + +from functools import partial +from typing import Dict + +import jax.numpy as jnp +from jax import jit + +from pyhgf.typing import Edges + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_mean_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, + precision_l1: float, + precision_l2: float, +) -> float: + """Posterior update of mean using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] + volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + phi = jnp.log( + (1 / attributes[volatility_child_idx]["precision"]) * (2 + jnp.sqrt(3)) + ) + + # first approximation ------------------------------------------------------ + delta_l1 = ( + ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] ** 2 + ) + ) + / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp( + volatility_coupling * phi + + attributes[volatility_child_idx]["tonic_volatility"] + ) + ) + ) - 1 + mean_l1 = ( + attributes[node_idx]["expected_mean"] + + ( + (volatility_coupling * attributes[node_idx]["tonic_volatility"]) + / (2 * precision_l1) + ) + * delta_l1 + ) + + # second approximation ----------------------------------------------------- + omega_phi = jnp.exp( + volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] + ) / ( + (1 / attributes[volatility_child_idx]["precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) + delta_phi = ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp( + volatility_coupling * phi + + attributes[volatility_child_idx]["tonic_volatility"] + ) + ) - 1 + + mu_phi = ((2 * precision_l2 - 1) * phi + attributes[node_idx]["expected_mean"]) / ( + 2 * precision_l2 + ) + + mean_l2 = ( + mu_phi + (volatility_coupling * omega_phi) / (2 * precision_l2) * delta_phi + ) + + # weigthed interpolation + theta_l = jnp.sqrt( + 1.2 + * ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) + ) + phi_l = 8.0 + theta_r = 0.0 + phi_r = 1.0 + mean = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * mean_l1 + b( + gamma, theta_l, phi_l, theta_r, phi_r + ) * mean_l2 + + return mean + + +def s(x, theta, phi): + return 1 / (1 + jnp.exp(-phi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py new file mode 100644 index 00000000..87222d27 --- /dev/null +++ b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py @@ -0,0 +1,84 @@ +# Author: Nicolas Legrand + +from functools import partial +from typing import Dict + +import jax.numpy as jnp +from jax import jit + +from pyhgf.typing import Edges + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_precision_continuous_node_unbounded( + attributes: Dict, edges: Edges, node_idx: int +) -> float: + """Posterior update of precision using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] + volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # first approximation ------------------------------------------------------ + precision_l1 = attributes[node_idx][ + "expected_precision" + ] + 0.5 * volatility_coupling**2 * attributes[node_idx]["tonic_volatility"] * ( + 1 - attributes[node_idx]["tonic_volatility"] + ) + + # second approximation ----------------------------------------------------- + phi = jnp.log( + (1 / attributes[volatility_child_idx]["expected_precision"]) * (2 + jnp.sqrt(3)) + ) + omega_phi = jnp.exp( + volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) + delta_phi = ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) - 1 + + precision_l2 = attributes[node_idx][ + "expected_precision" + ] + 0.5 * volatility_coupling**2 * omega_phi * ( + omega_phi + (2 * omega_phi - 1) * delta_phi + ) + + # weigthed interpolation + theta_l = jnp.sqrt( + 1.2 + * ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) + ) + phi_l = 8.0 + theta_r = 0.0 + phi_r = 1.0 + precision = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * precision_l1 + b( + gamma, theta_l, phi_l, theta_r, phi_r + ) * precision_l2 + + return precision, precision_l1, precision_l2 + + +def s(x, theta, phi): + return 1 / (1 + jnp.exp(-phi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r))