Skip to content

Commit

Permalink
posterior updates for mean and precision - not yet working
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 13, 2024
1 parent b3c39f9 commit aff16cd
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

from pyhgf.typing import Edges

from .posterior_update_mean_continuous_node import posterior_update_mean_continuous_node
from .posterior_update_precision_continuous_node import (
posterior_update_precision_continuous_node,
from .posterior_update_mean_continuous_node_unbounded import (
posterior_update_mean_continuous_node_unbounded,
)
from .posterior_update_precision_continuous_node_unbounded import (
posterior_update_precision_continuous_node_unbounded,
)


Expand Down Expand Up @@ -43,19 +45,22 @@ def continuous_node_posterior_update_unbounded(
"""
# update the posterior mean and precision using the eHGF update step
# we start with the mean update using the expected precision as an approximation
posterior_mean = posterior_update_mean_continuous_node(
attributes,
edges,
node_idx,
node_precision=attributes[node_idx]["expected_precision"],
posterior_precision, precision_l1, precision_l2 = (
posterior_update_precision_continuous_node_unbounded(
attributes,
edges,
node_idx,
)
)
attributes[node_idx]["mean"] = posterior_mean
attributes[node_idx]["precision"] = posterior_precision

posterior_precision = posterior_update_precision_continuous_node(
attributes,
edges,
node_idx,
posterior_mean = posterior_update_mean_continuous_node_unbounded(
attributes=attributes,
edges=edges,
node_idx=node_idx,
precision_l1=precision_l1,
precision_l2=precision_l2,
)
attributes[node_idx]["precision"] = posterior_precision
attributes[node_idx]["mean"] = posterior_mean

return attributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_mean_continuous_node_unbounded(
attributes: Dict,
edges: Edges,
node_idx: int,
precision_l1: float,
precision_l2: float,
) -> float:
"""Posterior update of mean using ubounded update."""
volatility_child_idx = edges[node_idx].volatility_children[0]
volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0]
gamma = attributes[node_idx]["expected_mean"]
phi = jnp.log(
(1 / attributes[volatility_child_idx]["precision"]) * (2 + jnp.sqrt(3))
)

# first approximation ------------------------------------------------------
delta_l1 = (
(
(1 / attributes[volatility_child_idx]["precision"])
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"] ** 2
)
)
/ (
(1 / attributes[volatility_child_idx]["expected_precision"])
+ jnp.exp(
volatility_coupling * phi
+ attributes[volatility_child_idx]["tonic_volatility"]
)
)
) - 1
mean_l1 = (
attributes[node_idx]["expected_mean"]
+ (
(volatility_coupling * attributes[node_idx]["tonic_volatility"])
/ (2 * precision_l1)
)
* delta_l1
)

# second approximation -----------------------------------------------------
omega_phi = jnp.exp(
volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]
) / (
(1 / attributes[volatility_child_idx]["precision"])
+ jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"])
)
delta_phi = (
(1 / attributes[volatility_child_idx]["precision"])
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
) / (
(1 / attributes[volatility_child_idx]["expected_precision"])
+ jnp.exp(
volatility_coupling * phi
+ attributes[volatility_child_idx]["tonic_volatility"]
)
) - 1

mu_phi = ((2 * precision_l2 - 1) * phi + attributes[node_idx]["expected_mean"]) / (
2 * precision_l2
)

mean_l2 = (
mu_phi + (volatility_coupling * omega_phi) / (2 * precision_l2) * delta_phi
)

# weigthed interpolation
theta_l = jnp.sqrt(
1.2
* (
(1 / attributes[volatility_child_idx]["precision"])
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
)
/ ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1)
)
phi_l = 8.0
theta_r = 0.0
phi_r = 1.0
mean = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * mean_l1 + b(
gamma, theta_l, phi_l, theta_r, phi_r
) * mean_l2

return mean


def s(x, theta, phi):
return 1 / (1 + jnp.exp(-phi * (x - theta)))


def b(x, theta_l, phi_l, theta_r, phi_r):
return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r))
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Edges


@partial(jit, static_argnames=("edges", "node_idx"))
def posterior_update_precision_continuous_node_unbounded(
attributes: Dict, edges: Edges, node_idx: int
) -> float:
"""Posterior update of precision using ubounded update."""
volatility_child_idx = edges[node_idx].volatility_children[0]
volatility_coupling = attributes[node_idx]["volatility_coupling_children"][0]
gamma = attributes[node_idx]["expected_mean"]

# first approximation ------------------------------------------------------
precision_l1 = attributes[node_idx][
"expected_precision"
] + 0.5 * volatility_coupling**2 * attributes[node_idx]["tonic_volatility"] * (
1 - attributes[node_idx]["tonic_volatility"]
)

# second approximation -----------------------------------------------------
phi = jnp.log(
(1 / attributes[volatility_child_idx]["expected_precision"]) * (2 + jnp.sqrt(3))
)
omega_phi = jnp.exp(
volatility_coupling * phi + attributes[node_idx]["tonic_volatility"]
) / (
(1 / attributes[volatility_child_idx]["expected_precision"])
+ jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"])
)
delta_phi = (
(1 / attributes[volatility_child_idx]["precision"])
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
) / (
(1 / attributes[volatility_child_idx]["expected_precision"])
+ jnp.exp(volatility_coupling * phi + attributes[node_idx]["tonic_volatility"])
) - 1

precision_l2 = attributes[node_idx][
"expected_precision"
] + 0.5 * volatility_coupling**2 * omega_phi * (
omega_phi + (2 * omega_phi - 1) * delta_phi
)

# weigthed interpolation
theta_l = jnp.sqrt(
1.2
* (
(1 / attributes[volatility_child_idx]["precision"])
+ (
attributes[volatility_child_idx]["mean"]
- attributes[volatility_child_idx]["expected_mean"]
)
** 2
)
/ ((1 / attributes[volatility_child_idx]["expected_precision"]) * precision_l1)
)
phi_l = 8.0
theta_r = 0.0
phi_r = 1.0
precision = (1 - b(gamma, theta_l, phi_l, theta_r, phi_r)) * precision_l1 + b(
gamma, theta_l, phi_l, theta_r, phi_r
) * precision_l2

return precision, precision_l1, precision_l2


def s(x, theta, phi):
return 1 / (1 + jnp.exp(-phi * (x - theta)))


def b(x, theta_l, phi_l, theta_r, phi_r):
return s(x, theta_l, phi_l) - (1 - s(x, theta_r, phi_r))

0 comments on commit aff16cd

Please sign in to comment.