-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ff25f8
commit 4b5cdb2
Showing
1 changed file
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk> | ||
|
||
from functools import partial | ||
from typing import Dict, Tuple | ||
|
||
import jax.numpy as jnp | ||
from jax import jit | ||
|
||
|
||
@partial(jit, static_argnames=("node_idx", "node_structure")) | ||
def dirichlet_node_update( | ||
parameters_structure: Dict, | ||
node_structure, | ||
H, | ||
value: float, | ||
time_step: float, | ||
node_idx: int, | ||
**args, | ||
) -> Tuple: | ||
"""Dirichlet process node update. | ||
When receiving a new input, this node chose either: | ||
1. Allocation the value to a pre-existing cluster and updating its value and | ||
volatility parents accordingly. | ||
2. Creating a new cluster (branching) with value and volatility parents. | ||
3. Merging clusters. | ||
Parameters | ||
---------- | ||
parameters_structure : | ||
The structure of nodes' parameters. Each parameter is a dictionary with the | ||
following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for | ||
continuous nodes. | ||
.. note:: | ||
`"psis"` is the value coupling strength. It should have same length than the | ||
volatility parents' indexes. `"kappas"` is the volatility coupling strength. | ||
It should have same length than the volatility parents' indexes. | ||
node_structure : | ||
Tuple of :py:class:`pyhgf.typing.StandardNode` with same length than number of | ||
node. For each node, the index list value and volatility parents. | ||
H : | ||
Function and parameters controlling the creation and evaluation of the base | ||
distribution (`H`). The following variables are required: | ||
- `create_distribution` : a function creating a new branch starting from the | ||
Dirichlet node. | ||
- `updates` : a sequence of updates to propagate beliefs in the | ||
cluster for a given set of observations. | ||
- `pdf` : a function to evaluate the likelihood of a given observation under | ||
the base distribution. | ||
- `theta` : the distribution parameters. | ||
value : | ||
The new observed value. | ||
time_step : | ||
Interval between the previous time point and the current time point. | ||
node_idx : | ||
Pointer to the node that need to be updated. After continuous update, the | ||
parameters of value and volatility parents (if any) will be different. | ||
Returns | ||
------- | ||
parameters_structure : | ||
The updated node structure. | ||
node_structure : | ||
The updated node structure. | ||
See Also | ||
-------- | ||
update_continuous_input_parents, update_binary_input_parents | ||
""" | ||
( | ||
value_parents, | ||
volatility_parents, | ||
) = node_structure[node_idx] | ||
create_distribution, updates, pdf_distribution, theta = H | ||
|
||
# using the current node index, unwrap parameters | ||
node_parameters = parameters_structure[node_idx] | ||
|
||
# store value and time step in the input node | ||
for idx, node in enumerate(node_structure): | ||
if node.value_parents is not None: | ||
if node_idx in node.value_parents: | ||
parameters_structure[idx]["time_step"] = time_step | ||
parameters_structure[idx]["value"] = value | ||
|
||
alpha = node_parameters["alpha"] # concentration value | ||
cluster_idxs = node_parameters["cluster_idx"] # list of clusters with nodes indexes | ||
|
||
# probability to draw a new cluster | ||
n_total = jnp.sum( | ||
jnp.array(node_parameters["n"]) | ||
) # n of observation before this one | ||
pi_new = alpha / (alpha + n_total) | ||
|
||
# likelihood for new cluster given the base distribution | ||
new_likelihood = pdf_distribution(value=value, theta=theta) | ||
|
||
# joint likelihoods for a new clusters | ||
new_likelihood *= pi_new | ||
|
||
# probability of being assignet to pre-existing cluster | ||
pi_clusters = jnp.where( | ||
node_parameters["k"] == 0, | ||
jnp.array([10.0]), | ||
jnp.array(node_parameters["n"]) / (alpha + n_total), | ||
) | ||
from jax.debug import print as pt | ||
|
||
pt("{x}", x=pi_clusters) | ||
# estimate the likelihood of current observation under each value/volatility pairs | ||
likelihoods_cluster = [ | ||
pdf_distribution( | ||
value=value, | ||
cluster_idxs=c, | ||
) | ||
for c in cluster_idxs | ||
] | ||
likelihoods_cluster = jnp.where( | ||
len(likelihoods_cluster) == 0, | ||
jnp.array([0.0], ndmin=1), | ||
jnp.array(likelihoods_cluster), | ||
) | ||
|
||
# joint likelihoods for previous clusters | ||
likelihoods_cluster *= pi_clusters | ||
# ############################ | ||
# # 1- Creating a new cluser # | ||
# ############################ | ||
# if new_likelihood > likelihoods_cluster.max(): | ||
# print("Creating a new cluster") | ||
|
||
# # define the parameters for the new distribution | ||
# theta = value, 1.0, -3.0 | ||
|
||
# # create a new distribution and update the cluster index accordingly | ||
# node_structure, parameters_structure, cluster_idx = create_distribution( | ||
# dirichlet_idx=node_idx, | ||
# node_structure=node_structure, | ||
# parameters_structure=parameters_structure, | ||
# cluster_idx=cluster_idx, | ||
# theta=theta, | ||
# ) | ||
|
||
# # append a new cluster count | ||
# node_parameters["n"].append(1) | ||
# node_parameters["z"] = idx | ||
# node_parameters["k"] += 1 # number of clusters | ||
|
||
# ####################################################### | ||
# # 2 - Assigning observation to a pre-existing cluster # | ||
# ####################################################### | ||
|
||
# # pass the value to the selected cluster and update each branch of the network | ||
# # pass None to clusters that are not updated | ||
# if new_likelihood <= jnp.array(likelihoods_cluster).max(): | ||
# # which cluster should be updated | ||
# idx = jnp.argmax(likelihoods_cluster) | ||
# c_idx = cluster_idx[idx] # the cluster nodes | ||
|
||
# print(f"Assigning value to cluster {idx}") | ||
|
||
# # create the update sequence for this distribution | ||
# update_sequence = tuple([(c, f[1]) for c, f in zip(c_idx, cluster_updates)]) | ||
|
||
# parameters_structure = apply_sequence( | ||
# value=value, | ||
# time_step=time_step, | ||
# parameters_structure=parameters_structure, | ||
# node_structure=node_structure, | ||
# update_sequence=update_sequence, | ||
# ) | ||
|
||
# # update the number of observation in this cluster | ||
# node_parameters["n"][idx] += 1 | ||
# node_parameters["z"] = idx | ||
|
||
return parameters_structure, node_structure |