Skip to content

Commit

Permalink
dirichlet distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 19, 2023
1 parent 6ff25f8 commit 4b5cdb2
Showing 1 changed file with 178 additions and 0 deletions.
178 changes: 178 additions & 0 deletions src/pyhgf/dirichlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict, Tuple

import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnames=("node_idx", "node_structure"))
def dirichlet_node_update(
parameters_structure: Dict,
node_structure,
H,
value: float,
time_step: float,
node_idx: int,
**args,
) -> Tuple:
"""Dirichlet process node update.
When receiving a new input, this node chose either:
1. Allocation the value to a pre-existing cluster and updating its value and
volatility parents accordingly.
2. Creating a new cluster (branching) with value and volatility parents.
3. Merging clusters.
Parameters
----------
parameters_structure :
The structure of nodes' parameters. Each parameter is a dictionary with the
following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
continuous nodes.
.. note::
`"psis"` is the value coupling strength. It should have same length than the
volatility parents' indexes. `"kappas"` is the volatility coupling strength.
It should have same length than the volatility parents' indexes.
node_structure :
Tuple of :py:class:`pyhgf.typing.StandardNode` with same length than number of
node. For each node, the index list value and volatility parents.
H :
Function and parameters controlling the creation and evaluation of the base
distribution (`H`). The following variables are required:
- `create_distribution` : a function creating a new branch starting from the
Dirichlet node.
- `updates` : a sequence of updates to propagate beliefs in the
cluster for a given set of observations.
- `pdf` : a function to evaluate the likelihood of a given observation under
the base distribution.
- `theta` : the distribution parameters.
value :
The new observed value.
time_step :
Interval between the previous time point and the current time point.
node_idx :
Pointer to the node that need to be updated. After continuous update, the
parameters of value and volatility parents (if any) will be different.
Returns
-------
parameters_structure :
The updated node structure.
node_structure :
The updated node structure.
See Also
--------
update_continuous_input_parents, update_binary_input_parents
"""
(
value_parents,
volatility_parents,
) = node_structure[node_idx]
create_distribution, updates, pdf_distribution, theta = H

# using the current node index, unwrap parameters
node_parameters = parameters_structure[node_idx]

# store value and time step in the input node
for idx, node in enumerate(node_structure):
if node.value_parents is not None:
if node_idx in node.value_parents:
parameters_structure[idx]["time_step"] = time_step
parameters_structure[idx]["value"] = value

alpha = node_parameters["alpha"] # concentration value
cluster_idxs = node_parameters["cluster_idx"] # list of clusters with nodes indexes

# probability to draw a new cluster
n_total = jnp.sum(
jnp.array(node_parameters["n"])
) # n of observation before this one
pi_new = alpha / (alpha + n_total)

# likelihood for new cluster given the base distribution
new_likelihood = pdf_distribution(value=value, theta=theta)

# joint likelihoods for a new clusters
new_likelihood *= pi_new

# probability of being assignet to pre-existing cluster
pi_clusters = jnp.where(
node_parameters["k"] == 0,
jnp.array([10.0]),
jnp.array(node_parameters["n"]) / (alpha + n_total),
)
from jax.debug import print as pt

pt("{x}", x=pi_clusters)
# estimate the likelihood of current observation under each value/volatility pairs
likelihoods_cluster = [
pdf_distribution(
value=value,
cluster_idxs=c,
)
for c in cluster_idxs
]
likelihoods_cluster = jnp.where(
len(likelihoods_cluster) == 0,
jnp.array([0.0], ndmin=1),
jnp.array(likelihoods_cluster),
)

# joint likelihoods for previous clusters
likelihoods_cluster *= pi_clusters
# ############################
# # 1- Creating a new cluser #
# ############################
# if new_likelihood > likelihoods_cluster.max():
# print("Creating a new cluster")

# # define the parameters for the new distribution
# theta = value, 1.0, -3.0

# # create a new distribution and update the cluster index accordingly
# node_structure, parameters_structure, cluster_idx = create_distribution(
# dirichlet_idx=node_idx,
# node_structure=node_structure,
# parameters_structure=parameters_structure,
# cluster_idx=cluster_idx,
# theta=theta,
# )

# # append a new cluster count
# node_parameters["n"].append(1)
# node_parameters["z"] = idx
# node_parameters["k"] += 1 # number of clusters

# #######################################################
# # 2 - Assigning observation to a pre-existing cluster #
# #######################################################

# # pass the value to the selected cluster and update each branch of the network
# # pass None to clusters that are not updated
# if new_likelihood <= jnp.array(likelihoods_cluster).max():
# # which cluster should be updated
# idx = jnp.argmax(likelihoods_cluster)
# c_idx = cluster_idx[idx] # the cluster nodes

# print(f"Assigning value to cluster {idx}")

# # create the update sequence for this distribution
# update_sequence = tuple([(c, f[1]) for c, f in zip(c_idx, cluster_updates)])

# parameters_structure = apply_sequence(
# value=value,
# time_step=time_step,
# parameters_structure=parameters_structure,
# node_structure=node_structure,
# update_sequence=update_sequence,
# )

# # update the number of observation in this cluster
# node_parameters["n"][idx] += 1
# node_parameters["z"] = idx

return parameters_structure, node_structure

0 comments on commit 4b5cdb2

Please sign in to comment.