From 8bb57ea1065bf17ca2aebadd4c80ef02022334f7 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 16 Dec 2024 16:08:21 +0100 Subject: [PATCH 1/3] add uhgf updates --- pyhgf/model/network.py | 13 +- .../updates/posterior/continuous/__init__.py | 3 + ...tinuous_node_posterior_update_unbounded.py | 66 +++++++++++ ...r_update_mean_continuous_node_unbounded.py | 112 ++++++++++++++++++ ...ate_precision_continuous_node_unbounded.py | 84 +++++++++++++ pyhgf/utils/get_update_sequence.py | 8 +- 6 files changed, 281 insertions(+), 5 deletions(-) create mode 100644 pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py create mode 100644 pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py create mode 100644 pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index 3eec09908..082b929a4 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -83,7 +83,7 @@ def input_idxs(self, value): self.input_idxs = value def create_belief_propagation_fn( - self, overwrite: bool = True, update_type: str = "eHGF" + self, overwrite: bool = True, update_type: str = "unbounded" ) -> "Network": """Create the belief propagation function. @@ -97,11 +97,16 @@ def create_belief_propagation_fn( preexisting values. Otherwise, do not create a new function if the attribute `scan_fn` is already defined. update_type : - The type of update to perform for volatility coupling. Can be `"eHGF"` - (defaults) or `"standard"`. The eHGF update step was proposed as an + The type of update to perform for volatility coupling. Can be `"unbounded"` + (defaults), `"ehgf"` or `"standard"`. The unbounded approximation was + recently introduced to avoid negative precisions updates, which greatly + improve sampling performance. The eHGF update step was proposed as an alternative to the original definition in that it starts by updating the mean and then the precision of the parent node, which generally reduces the - errors associated with impossible parameter space and improves sampling. + occurence of negative precision updates, while not removing them entirely. + .. note: + The different update steps only apply to nodes having at least one + volatility parents. In other cases, the regular HGF updates are applied. """ # create the update sequence if it does not already exist diff --git a/pyhgf/updates/posterior/continuous/__init__.py b/pyhgf/updates/posterior/continuous/__init__.py index fd8740754..e64c9523a 100644 --- a/pyhgf/updates/posterior/continuous/__init__.py +++ b/pyhgf/updates/posterior/continuous/__init__.py @@ -1,5 +1,8 @@ from .continuous_node_posterior_update import continuous_node_posterior_update from .continuous_node_posterior_update_ehgf import continuous_node_posterior_update_ehgf +from .continuous_node_posterior_update_unbounded import ( + continuous_node_posterior_update_unbounded, +) __all__ = [ "continuous_node_posterior_update_ehgf", diff --git a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py new file mode 100644 index 000000000..5ed7041d6 --- /dev/null +++ b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py @@ -0,0 +1,66 @@ +# Author: Nicolas Legrand + +from functools import partial +from typing import Dict + +from jax import jit + +from pyhgf.typing import Edges + +from .posterior_update_mean_continuous_node_unbounded import ( + posterior_update_mean_continuous_node_unbounded, +) +from .posterior_update_precision_continuous_node_unbounded import ( + posterior_update_precision_continuous_node_unbounded, +) + + +@partial(jit, static_argnames=("edges", "node_idx")) +def continuous_node_posterior_update_unbounded( + attributes: Dict, node_idx: int, edges: Edges, **args +) -> Dict: + """Update the posterior of a continuous node using an unbounded approximation. + + Parameters + ---------- + attributes : + The attributes of the probabilistic nodes. + node_idx : + Pointer to the node that needs to be updated. After continuous updates, the + parameters of value and volatility parents (if any) will be different. + edges : + The edges of the probabilistic nodes as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. + For each node, the index list value and volatility parents and children. + + Returns + ------- + attributes : + The updated attributes of the probabilistic nodes. + + See Also + -------- + continuous_node_posterior_update_ehgf + + """ + # update the posterior mean and precision using the eHGF update step + # we start with the mean update using the expected precision as an approximation + posterior_precision, precision_l1, precision_l2 = ( + posterior_update_precision_continuous_node_unbounded( + attributes, + edges, + node_idx, + ) + ) + attributes[node_idx]["precision"] = posterior_precision + + posterior_mean = posterior_update_mean_continuous_node_unbounded( + attributes=attributes, + edges=edges, + node_idx=node_idx, + precision_l1=precision_l1, + precision_l2=precision_l2, + ) + attributes[node_idx]["mean"] = posterior_mean + + return attributes diff --git a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py new file mode 100644 index 000000000..0f4ed643b --- /dev/null +++ b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py @@ -0,0 +1,112 @@ +# Author: Nicolas Legrand + +from functools import partial +from typing import Dict + +import jax.numpy as jnp +from jax import jit + +from pyhgf.typing import Edges + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_mean_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, + precision_l1: float, + precision_l2: float, +) -> float: + """Posterior update of mean using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] + volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + phi = jnp.log( + (1 / attributes[volatility_child_idx]["precision"]) * (2 + jnp.sqrt(3)) + ) + + # first approximation ------------------------------------------------------ + delta_l1 = ( + ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] ** 2 + ) + ) + / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp( + volatility_coupling * phi + + attributes[volatility_child_idx]["tonic_volatility"] + ) + ) + ) - 1 + mean_l1 = ( + attributes[node_idx]["expected_mean"] + + ( + (volatility_coupling * attributes[node_idx]["tonic_volatility"]) + / (2 * precision_l1) + ) + * delta_l1 + ) + + # second approximation ----------------------------------------------------- + omega_phi = jnp.exp( + volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] + ) / ( + (1 / attributes[volatility_child_idx]["precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) + delta_phi = ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp( + volatility_coupling * phi + + attributes[volatility_child_idx]["tonic_volatility"] + ) + ) - 1 + + mu_phi = ((2 * precision_l2 - 1) * phi + attributes[node_idx]["expected_mean"]) / ( + 2 * precision_l2 + ) + + mean_l2 = ( + mu_phi + (volatility_coupling * omega_phi) / (2 * precision_l2) * delta_phi + ) + + # weigthed interpolation + theta_l = jnp.sqrt( + 1.2 + * ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) + ) + phi_l = 8.0 + theta_r = 0.0 + phi_r = 1.0 + mean = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * mean_l1 + b( + gamma, theta_l, phi_l, theta_r, phi_r + ) * mean_l2 + + return mean + + +def s(x, theta, phi): + return 1 / (1 + jnp.exp(-phi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py new file mode 100644 index 000000000..87222d277 --- /dev/null +++ b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py @@ -0,0 +1,84 @@ +# Author: Nicolas Legrand + +from functools import partial +from typing import Dict + +import jax.numpy as jnp +from jax import jit + +from pyhgf.typing import Edges + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_precision_continuous_node_unbounded( + attributes: Dict, edges: Edges, node_idx: int +) -> float: + """Posterior update of precision using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] + volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # first approximation ------------------------------------------------------ + precision_l1 = attributes[node_idx][ + "expected_precision" + ] + 0.5 * volatility_coupling**2 * attributes[node_idx]["tonic_volatility"] * ( + 1 - attributes[node_idx]["tonic_volatility"] + ) + + # second approximation ----------------------------------------------------- + phi = jnp.log( + (1 / attributes[volatility_child_idx]["expected_precision"]) * (2 + jnp.sqrt(3)) + ) + omega_phi = jnp.exp( + volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) + delta_phi = ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) / ( + (1 / attributes[volatility_child_idx]["expected_precision"]) + + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) + ) - 1 + + precision_l2 = attributes[node_idx][ + "expected_precision" + ] + 0.5 * volatility_coupling**2 * omega_phi * ( + omega_phi + (2 * omega_phi - 1) * delta_phi + ) + + # weigthed interpolation + theta_l = jnp.sqrt( + 1.2 + * ( + (1 / attributes[volatility_child_idx]["precision"]) + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) + ) + phi_l = 8.0 + theta_r = 0.0 + phi_r = 1.0 + precision = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * precision_l1 + b( + gamma, theta_l, phi_l, theta_r, phi_r + ) * precision_l2 + + return precision, precision_l1, precision_l2 + + +def s(x, theta, phi): + return 1 / (1 + jnp.exp(-phi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/pyhgf/utils/get_update_sequence.py b/pyhgf/utils/get_update_sequence.py index 1ad304ccc..b385b245a 100644 --- a/pyhgf/utils/get_update_sequence.py +++ b/pyhgf/utils/get_update_sequence.py @@ -9,6 +9,7 @@ from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, continuous_node_posterior_update_ehgf, + continuous_node_posterior_update_unbounded, ) from pyhgf.updates.prediction.binary import binary_state_node_prediction from pyhgf.updates.prediction.continuous import continuous_node_prediction @@ -135,7 +136,12 @@ def get_update_sequence( if all([i not in nodes_without_prediction_error for i in all_children]): no_update = False if network.edges[idx].node_type == 2: - if update_type == "eHGF": + if update_type == "unbounded": + if network.edges[idx].volatility_children is not None: + update_fn = continuous_node_posterior_update_unbounded + else: + update_fn = continuous_node_posterior_update + elif update_type == "eHGF": if network.edges[idx].volatility_children is not None: update_fn = continuous_node_posterior_update_ehgf else: From d12f9c89944111d1a4e60a5bca951325af89a697 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 17 Dec 2024 11:23:39 +0100 Subject: [PATCH 2/3] new equations --- ...tinuous_node_posterior_update_unbounded.py | 141 ++++++++++++++++-- ...r_update_mean_continuous_node_unbounded.py | 112 -------------- ...ate_precision_continuous_node_unbounded.py | 84 ----------- tests/test_updates/posterior/continuous.py | 35 +++++ 4 files changed, 161 insertions(+), 211 deletions(-) delete mode 100644 pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py delete mode 100644 pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py diff --git a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py index 5ed7041d6..e35702e19 100644 --- a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py +++ b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py @@ -3,17 +3,11 @@ from functools import partial from typing import Dict +import jax.numpy as jnp from jax import jit from pyhgf.typing import Edges -from .posterior_update_mean_continuous_node_unbounded import ( - posterior_update_mean_continuous_node_unbounded, -) -from .posterior_update_precision_continuous_node_unbounded import ( - posterior_update_precision_continuous_node_unbounded, -) - @partial(jit, static_argnames=("edges", "node_idx")) def continuous_node_posterior_update_unbounded( @@ -45,12 +39,10 @@ def continuous_node_posterior_update_unbounded( """ # update the posterior mean and precision using the eHGF update step # we start with the mean update using the expected precision as an approximation - posterior_precision, precision_l1, precision_l2 = ( - posterior_update_precision_continuous_node_unbounded( - attributes, - edges, - node_idx, - ) + posterior_precision = posterior_update_precision_continuous_node_unbounded( + attributes=attributes, + edges=edges, + node_idx=node_idx, ) attributes[node_idx]["precision"] = posterior_precision @@ -58,9 +50,128 @@ def continuous_node_posterior_update_unbounded( attributes=attributes, edges=edges, node_idx=node_idx, - precision_l1=precision_l1, - precision_l2=precision_l2, ) attributes[node_idx]["mean"] = posterior_mean return attributes + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_mean_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, +) -> float: + """Posterior update of mean using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] # type: ignore + # volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # previous child uncertainty + alpha = 1 / attributes[volatility_child_idx]["expected_precision"] + + # posterior total uncertainty about the child + beta = ( + 1 / attributes[volatility_child_idx]["expected_precision"] + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + + return mu_l(alpha, beta, gamma) + + +@partial(jit, static_argnames=("edges", "node_idx")) +def posterior_update_precision_continuous_node_unbounded( + attributes: Dict, + edges: Edges, + node_idx: int, +) -> float: + """Posterior update of mean using ubounded update.""" + volatility_child_idx = edges[node_idx].volatility_children[0] # type: ignore + # volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] + gamma = attributes[node_idx]["expected_mean"] + + # previous child uncertainty + alpha = 1 / attributes[volatility_child_idx]["expected_precision"] + + # posterior total uncertainty about the child + beta = ( + 1 / attributes[volatility_child_idx]["expected_precision"] + + ( + attributes[volatility_child_idx]["mean"] + - attributes[volatility_child_idx]["expected_mean"] + ) + ** 2 + ) + + return pi_l(alpha, beta, gamma) + + +def s(x, theta, psi): + return 1 / (1 + jnp.exp(-psi * (x - theta))) + + +def b(x, theta_l, phi_l, theta_r, phi_r): + return s(x, theta_l, phi_l) * (1 - s(x, theta_r, phi_r)) + + +def pi_l1(alpha, gamma): + return 0.5 * omega(alpha, gamma) * (1 - omega(alpha, gamma)) + 0.5 + + +def mu_l1(alpha, beta, gamma): + return gamma + 0.5 / pi_l1(alpha, gamma) * omega(alpha, gamma) * delta( + alpha, beta, gamma + ) + + +def omega(alpha, x): + return jnp.exp(x) / (alpha + jnp.exp(x)) + + +def delta(alpha, beta, x): + return beta / (alpha + jnp.exp(x)) - 1 + + +def phi(alpha): + return jnp.log(alpha * (2 + jnp.sqrt(3))) + + +def pi_l2(alpha, beta): + return -ddJ(phi(alpha), alpha, beta) + + +def dJ(x, alpha, beta, gamma): + return 0.5 * omega(alpha, x) * delta(alpha, beta, x) - 0.5 * (x - gamma) + + +def ddJ(x, alpha, beta): + return ( + -0.5 + * omega(alpha, x) + * (omega(alpha, x) + (2 * omega(alpha, x) - 1) * delta(alpha, beta, x)) + - 0.5 + ) + + +def mu_l2(alpha, beta, gamma): + return phi(alpha) - dJ(phi(alpha), alpha, beta, gamma) / ddJ( + phi(alpha), alpha, beta + ) + + +def mu_l(alpha, beta, gamma): + return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * mu_l1( + alpha, beta, gamma + ) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * mu_l2( + alpha, beta, gamma + ) + + +def pi_l(alpha, beta, gamma): + return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * pi_l1( + alpha, gamma + ) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * pi_l2(alpha, beta) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py deleted file mode 100644 index 0f4ed643b..000000000 --- a/pyhgf/updates/posterior/continuous/posterior_update_mean_continuous_node_unbounded.py +++ /dev/null @@ -1,112 +0,0 @@ -# Author: Nicolas Legrand - -from functools import partial -from typing import Dict - -import jax.numpy as jnp -from jax import jit - -from pyhgf.typing import Edges - - -@partial(jit, static_argnames=("edges", "node_idx")) -def posterior_update_mean_continuous_node_unbounded( - attributes: Dict, - edges: Edges, - node_idx: int, - precision_l1: float, - precision_l2: float, -) -> float: - """Posterior update of mean using ubounded update.""" - volatility_child_idx = edges[node_idx].volatility_children[0] - volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] - gamma = attributes[node_idx]["expected_mean"] - phi = jnp.log( - (1 / attributes[volatility_child_idx]["precision"]) * (2 + jnp.sqrt(3)) - ) - - # first approximation ------------------------------------------------------ - delta_l1 = ( - ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] ** 2 - ) - ) - / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp( - volatility_coupling * phi - + attributes[volatility_child_idx]["tonic_volatility"] - ) - ) - ) - 1 - mean_l1 = ( - attributes[node_idx]["expected_mean"] - + ( - (volatility_coupling * attributes[node_idx]["tonic_volatility"]) - / (2 * precision_l1) - ) - * delta_l1 - ) - - # second approximation ----------------------------------------------------- - omega_phi = jnp.exp( - volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] - ) / ( - (1 / attributes[volatility_child_idx]["precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - delta_phi = ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp( - volatility_coupling * phi - + attributes[volatility_child_idx]["tonic_volatility"] - ) - ) - 1 - - mu_phi = ((2 * precision_l2 - 1) * phi + attributes[node_idx]["expected_mean"]) / ( - 2 * precision_l2 - ) - - mean_l2 = ( - mu_phi + (volatility_coupling * omega_phi) / (2 * precision_l2) * delta_phi - ) - - # weigthed interpolation - theta_l = jnp.sqrt( - 1.2 - * ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) - / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) - ) - phi_l = 8.0 - theta_r = 0.0 - phi_r = 1.0 - mean = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * mean_l1 + b( - gamma, theta_l, phi_l, theta_r, phi_r - ) * mean_l2 - - return mean - - -def s(x, theta, phi): - return 1 / (1 + jnp.exp(-phi * (x - theta))) - - -def b(x, theta_l, phi_l, theta_r, phi_r): - return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py deleted file mode 100644 index 87222d277..000000000 --- a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node_unbounded.py +++ /dev/null @@ -1,84 +0,0 @@ -# Author: Nicolas Legrand - -from functools import partial -from typing import Dict - -import jax.numpy as jnp -from jax import jit - -from pyhgf.typing import Edges - - -@partial(jit, static_argnames=("edges", "node_idx")) -def posterior_update_precision_continuous_node_unbounded( - attributes: Dict, edges: Edges, node_idx: int -) -> float: - """Posterior update of precision using ubounded update.""" - volatility_child_idx = edges[node_idx].volatility_children[0] - volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0] - gamma = attributes[node_idx]["expected_mean"] - - # first approximation ------------------------------------------------------ - precision_l1 = attributes[node_idx][ - "expected_precision" - ] + 0.5 * volatility_coupling**2 * attributes[node_idx]["tonic_volatility"] * ( - 1 - attributes[node_idx]["tonic_volatility"] - ) - - # second approximation ----------------------------------------------------- - phi = jnp.log( - (1 / attributes[volatility_child_idx]["expected_precision"]) * (2 + jnp.sqrt(3)) - ) - omega_phi = jnp.exp( - volatility_coupling * phi + attributes[node_idx]["tonic_volatility"] - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - delta_phi = ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) / ( - (1 / attributes[volatility_child_idx]["expected_precision"]) - + jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]) - ) - 1 - - precision_l2 = attributes[node_idx][ - "expected_precision" - ] + 0.5 * volatility_coupling**2 * omega_phi * ( - omega_phi + (2 * omega_phi - 1) * delta_phi - ) - - # weigthed interpolation - theta_l = jnp.sqrt( - 1.2 - * ( - (1 / attributes[volatility_child_idx]["precision"]) - + ( - attributes[volatility_child_idx]["mean"] - - attributes[volatility_child_idx]["expected_mean"] - ) - ** 2 - ) - / ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1) - ) - phi_l = 8.0 - theta_r = 0.0 - phi_r = 1.0 - precision = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * precision_l1 + b( - gamma, theta_l, phi_l, theta_r, phi_r - ) * precision_l2 - - return precision, precision_l1, precision_l2 - - -def s(x, theta, phi): - return 1 / (1 + jnp.exp(-phi * (x - theta))) - - -def b(x, theta_l, phi_l, theta_r, phi_r): - return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r)) diff --git a/tests/test_updates/posterior/continuous.py b/tests/test_updates/posterior/continuous.py index 6bc59f396..5141ab4d9 100644 --- a/tests/test_updates/posterior/continuous.py +++ b/tests/test_updates/posterior/continuous.py @@ -1,11 +1,25 @@ # Author: Nicolas Legrand +import jax.numpy as jnp + from pyhgf.model import Network from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, continuous_node_posterior_update_ehgf, continuous_node_posterior_update_unbounded, ) +from pyhgf.updates.posterior.continuous.continuous_node_posterior_update_unbounded import ( + b, + delta, + mu_l, + mu_l1, + mu_l2, + omega, + pi_l, + pi_l1, + pi_l2, + s, +) def test_continuous_posterior_updates(): @@ -34,3 +48,24 @@ def test_continuous_posterior_updates(): _ = continuous_node_posterior_update_unbounded( attributes=attributes, node_idx=2, edges=edges ) + + +def test_unbounded_hgf_equations(): + + alpha = 1.0 + beta = 5.0 + gamma = 4.0 + + assert jnp.isclose(omega(alpha, gamma), 0.98201376) + assert jnp.isclose(delta(alpha, beta, gamma), -0.9100689) + + assert b(1.0, 1.0, 1.0, 1.0, 1.0) == 0.25 + assert s(1.0, 1.0, 1.0) == 0.5 + + assert jnp.isclose(pi_l1(alpha, gamma), 0.5088314) + assert jnp.isclose(pi_l2(alpha, beta), 0.82389593) + assert jnp.isclose(pi_l(alpha, beta, gamma), 0.51449823) + + assert jnp.isclose(mu_l1(alpha, beta, gamma), 3.1218112) + assert jnp.isclose(mu_l2(alpha, beta, gamma), 2.9723248) + assert jnp.isclose(mu_l(alpha, beta, gamma), 3.1191223) From 50581bef1bc33a3412004440b5ea98bd63aa23b7 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Tue, 17 Dec 2024 14:15:16 +0100 Subject: [PATCH 3/3] more equations --- ...tinuous_node_posterior_update_unbounded.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py index e35702e19..5f13a4b1f 100644 --- a/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py +++ b/pyhgf/updates/posterior/continuous/continuous_node_posterior_update_unbounded.py @@ -72,7 +72,7 @@ def posterior_update_mean_continuous_node_unbounded( # posterior total uncertainty about the child beta = ( - 1 / attributes[volatility_child_idx]["expected_precision"] + 1 / attributes[volatility_child_idx]["precision"] + ( attributes[volatility_child_idx]["mean"] - attributes[volatility_child_idx]["expected_mean"] @@ -80,7 +80,9 @@ def posterior_update_mean_continuous_node_unbounded( ** 2 ) - return mu_l(alpha, beta, gamma) + new_mu = new_mu_l1(alpha, beta, gamma, attributes, node_idx) + + return new_mu @partial(jit, static_argnames=("edges", "node_idx")) @@ -99,7 +101,7 @@ def posterior_update_precision_continuous_node_unbounded( # posterior total uncertainty about the child beta = ( - 1 / attributes[volatility_child_idx]["expected_precision"] + 1 / attributes[volatility_child_idx]["precision"] + ( attributes[volatility_child_idx]["mean"] - attributes[volatility_child_idx]["expected_mean"] @@ -107,7 +109,18 @@ def posterior_update_precision_continuous_node_unbounded( ** 2 ) - return pi_l(alpha, beta, gamma) + new_pi = new_pi_l1(alpha, gamma, attributes, node_idx) + + return new_pi + +def new_pi_l1(alpha, gamma, attributes, node_idx): + return attributes[node_idx]["expected_precision"] + attributes[node_idx]["volatility_coupling_children"][0]**2 * 0.5 * omega(alpha, gamma) * (1 - omega(alpha, gamma)) + + +def new_mu_l1(alpha, beta, gamma, attributes, node_idx): + return gamma + 0.5 / pi_l1(alpha, gamma) * omega(alpha, gamma) * delta( + alpha, beta, gamma + ) * attributes[node_idx]["volatility_coupling_children"][0] def s(x, theta, psi):