Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unbounded posterior updates for volatility parents #262

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def input_idxs(self, value):
self.input_idxs = value

def create_belief_propagation_fn(
self, overwrite: bool = True, update_type: str = "eHGF"
self, overwrite: bool = True, update_type: str = "unbounded"
) -> "Network":
"""Create the belief propagation function.

Expand All @@ -97,11 +97,16 @@ def create_belief_propagation_fn(
preexisting values. Otherwise, do not create a new function if the attribute
`scan_fn` is already defined.
update_type :
The type of update to perform for volatility coupling. Can be `"eHGF"`
(defaults) or `"standard"`. The eHGF update step was proposed as an
The type of update to perform for volatility coupling. Can be `"unbounded"`
(defaults), `"ehgf"` or `"standard"`. The unbounded approximation was
recently introduced to avoid negative precisions updates, which greatly
improve sampling performance. The eHGF update step was proposed as an
alternative to the original definition in that it starts by updating the
mean and then the precision of the parent node, which generally reduces the
errors associated with impossible parameter space and improves sampling.
occurence of negative precision updates, while not removing them entirely.
.. note:
The different update steps only apply to nodes having at least one
volatility parents. In other cases, the regular HGF updates are applied.

"""
# create the update sequence if it does not already exist
Expand Down
3 changes: 3 additions & 0 deletions pyhgf/updates/posterior/continuous/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .continuous_node_posterior_update import continuous_node_posterior_update
from .continuous_node_posterior_update_ehgf import continuous_node_posterior_update_ehgf
from .continuous_node_posterior_update_unbounded import (
continuous_node_posterior_update_unbounded,
)

__all__ = [
"continuous_node_posterior_update_ehgf",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def continuous_node_posterior_update_unbounded(
attributes: Dict, node_idx: int, edges: Edges, **args
) -> Dict:
"""Update the posterior of a continuous node using an unbounded approximation.

Parameters
----------
attributes :
The attributes of the probabilistic nodes.
node_idx :
Pointer to the node that needs to be updated. After continuous updates, the
parameters of value and volatility parents (if any) will be different.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.

Returns
-------
attributes :
The updated attributes of the probabilistic nodes.

See Also
--------
continuous_node_posterior_update_ehgf

"""
# update the posterior mean and precision using the eHGF update step
# we start with the mean update using the expected precision as an approximation
posterior_precision = posterior_update_precision_continuous_node_unbounded(
attributes=attributes,
edges=edges,
node_idx=node_idx,
)
attributes[node_idx]["precision"] = posterior_precision

posterior_mean = posterior_update_mean_continuous_node_unbounded(
attributes=attributes,
edges=edges,
node_idx=node_idx,
)
attributes[node_idx]["mean"] = posterior_mean

return attributes


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_mean_continuous_node_unbounded(
attributes: Dict,
edges: Edges,
node_idx: int,
) -> float:
"""Posterior update of mean using ubounded update."""
volatility_child_idx = edges[node_idx].volatility_children[0] # type: ignore
# volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0]
gamma = attributes[node_idx]["expected_mean"]

# previous child uncertainty
alpha = 1 / attributes[volatility_child_idx]["expected_precision"]

# posterior total uncertainty about the child
beta = (
1 / attributes[volatility_child_idx]["precision"]
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
)

new_mu = new_mu_l1(alpha, beta, gamma, attributes, node_idx)

return new_mu


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_precision_continuous_node_unbounded(
attributes: Dict,
edges: Edges,
node_idx: int,
) -> float:
"""Posterior update of mean using ubounded update."""
volatility_child_idx = edges[node_idx].volatility_children[0] # type: ignore
# volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0]
gamma = attributes[node_idx]["expected_mean"]

# previous child uncertainty
alpha = 1 / attributes[volatility_child_idx]["expected_precision"]

# posterior total uncertainty about the child
beta = (
1 / attributes[volatility_child_idx]["precision"]
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
)

new_pi = new_pi_l1(alpha, gamma, attributes, node_idx)

return new_pi

def new_pi_l1(alpha, gamma, attributes, node_idx):
return attributes[node_idx]["expected_precision"] + attributes[node_idx]["volatility_coupling_children"][0]**2 * 0.5 * omega(alpha, gamma) * (1 - omega(alpha, gamma))


def new_mu_l1(alpha, beta, gamma, attributes, node_idx):
return gamma + 0.5 / pi_l1(alpha, gamma) * omega(alpha, gamma) * delta(
alpha, beta, gamma
) * attributes[node_idx]["volatility_coupling_children"][0]


def s(x, theta, psi):
return 1 / (1 + jnp.exp(-psi * (x - theta)))


def b(x, theta_l, phi_l, theta_r, phi_r):
return s(x, theta_l, phi_l) * (1 - s(x, theta_r, phi_r))


def pi_l1(alpha, gamma):
return 0.5 * omega(alpha, gamma) * (1 - omega(alpha, gamma)) + 0.5


def mu_l1(alpha, beta, gamma):
return gamma + 0.5 / pi_l1(alpha, gamma) * omega(alpha, gamma) * delta(
alpha, beta, gamma
)


def omega(alpha, x):
return jnp.exp(x) / (alpha + jnp.exp(x))


def delta(alpha, beta, x):
return beta / (alpha + jnp.exp(x)) - 1


def phi(alpha):
return jnp.log(alpha * (2 + jnp.sqrt(3)))


def pi_l2(alpha, beta):
return -ddJ(phi(alpha), alpha, beta)


def dJ(x, alpha, beta, gamma):
return 0.5 * omega(alpha, x) * delta(alpha, beta, x) - 0.5 * (x - gamma)


def ddJ(x, alpha, beta):
return (
-0.5
* omega(alpha, x)
* (omega(alpha, x) + (2 * omega(alpha, x) - 1) * delta(alpha, beta, x))
- 0.5
)


def mu_l2(alpha, beta, gamma):
return phi(alpha) - dJ(phi(alpha), alpha, beta, gamma) / ddJ(
phi(alpha), alpha, beta
)


def mu_l(alpha, beta, gamma):
return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * mu_l1(
alpha, beta, gamma
) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * mu_l2(
alpha, beta, gamma
)


def pi_l(alpha, beta, gamma):
return (1 - b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0)) * pi_l1(
alpha, gamma
) + b(gamma, -jnp.sqrt(1.2 * 2 * beta / alpha), 8.0, 0.0, 1.0) * pi_l2(alpha, beta)
8 changes: 7 additions & 1 deletion pyhgf/utils/get_update_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
continuous_node_posterior_update_unbounded,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
Expand Down Expand Up @@ -135,7 +136,12 @@ def get_update_sequence(
if all([i not in nodes_without_prediction_error for i in all_children]):
no_update = False
if network.edges[idx].node_type == 2:
if update_type == "eHGF":
if update_type == "unbounded":
if network.edges[idx].volatility_children is not None:
update_fn = continuous_node_posterior_update_unbounded
else:
update_fn = continuous_node_posterior_update
elif update_type == "eHGF":
if network.edges[idx].volatility_children is not None:
update_fn = continuous_node_posterior_update_ehgf
else:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_updates/posterior/continuous.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

import jax.numpy as jnp

from pyhgf.model import Network
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
continuous_node_posterior_update_unbounded,
)
from pyhgf.updates.posterior.continuous.continuous_node_posterior_update_unbounded import (
b,
delta,
mu_l,
mu_l1,
mu_l2,
omega,
pi_l,
pi_l1,
pi_l2,
s,
)


def test_continuous_posterior_updates():
Expand Down Expand Up @@ -34,3 +48,24 @@ def test_continuous_posterior_updates():
_ = continuous_node_posterior_update_unbounded(
attributes=attributes, node_idx=2, edges=edges
)


def test_unbounded_hgf_equations():

alpha = 1.0
beta = 5.0
gamma = 4.0

assert jnp.isclose(omega(alpha, gamma), 0.98201376)
assert jnp.isclose(delta(alpha, beta, gamma), -0.9100689)

assert b(1.0, 1.0, 1.0, 1.0, 1.0) == 0.25
assert s(1.0, 1.0, 1.0) == 0.5

assert jnp.isclose(pi_l1(alpha, gamma), 0.5088314)
assert jnp.isclose(pi_l2(alpha, beta), 0.82389593)
assert jnp.isclose(pi_l(alpha, beta, gamma), 0.51449823)

assert jnp.isclose(mu_l1(alpha, beta, gamma), 3.1218112)
assert jnp.isclose(mu_l2(alpha, beta, gamma), 2.9723248)
assert jnp.isclose(mu_l(alpha, beta, gamma), 3.1191223)
Loading