# Reconstructed from a git patch: commit 4b5cdb26b0828a692042a00d8a712cc96a2dd0ac,
# "dirichlet distribution" (LegrandNico, 2023-09-20), new file src/pyhgf/dirichlet.py.
# Author: Nicolas Legrand

from functools import partial
from typing import Dict, Tuple

import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnames=("node_idx", "node_structure"))
def dirichlet_node_update(
    parameters_structure: Dict,
    node_structure,
    H,
    value: float,
    time_step: float,
    node_idx: int,
    **args,
) -> Tuple:
    """Dirichlet process node update.

    When receiving a new input, this node is meant to choose between:
    1. Allocating the value to a pre-existing cluster and updating its value and
    volatility parents accordingly.
    2. Creating a new cluster (branching) with value and volatility parents.
    3. Merging clusters.

    .. note::
       Only the first part of the update is implemented so far: the observation is
       stored in the node(s) that have this Dirichlet node as value parent, and the
       joint likelihoods of the observation under a new cluster and under every
       existing cluster are evaluated. The cluster creation / assignment steps are
       still to be written (see the TODO at the end of the function body).

    Parameters
    ----------
    parameters_structure :
        The structure of nodes' parameters. Each parameter is a dictionary with the
        following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
        continuous nodes. The entry at `node_idx` (the Dirichlet node) is read for the
        keys `"alpha"` (concentration), `"cluster_idx"` (node indexes of each
        cluster), `"n"` (number of observations per cluster) and `"k"` (number of
        clusters).
        .. note::
           `"psis"` is the value coupling strength. It should have the same length as
           the value parents' indexes. `"kappas"` is the volatility coupling
           strength. It should have the same length as the volatility parents'
           indexes.
    node_structure :
        Tuple of :py:class:`pyhgf.typing.StandardNode` with the same length as the
        number of nodes. For each node, the index list of value and volatility
        parents. This argument is static under `jit` and must therefore be hashable.
    H :
        Function and parameters controlling the creation and evaluation of the base
        distribution (`H`). It is unpacked, in this order, as
        `(create_distribution, updates, pdf, theta)`:
        - `create_distribution` : a function creating a new branch starting from the
          Dirichlet node.
        - `updates` : a sequence of updates to propagate beliefs in the
          cluster for a given set of observations.
        - `pdf` : a function to evaluate the likelihood of a given observation under
          the base distribution. It is called as `pdf(value=..., theta=...)` for a
          new cluster and as `pdf(value=..., cluster_idxs=...)` for an existing one.
        - `theta` : the distribution parameters.
        NOTE(review): `H` is not listed in `static_argnames`, so it is traced by
        `jit`; callables inside it presumably have to be pytree-compatible (e.g.
        `jax.tree_util.Partial`) - confirm against callers.
    value :
        The new observed value.
    time_step :
        Interval between the previous time point and the current time point.
    node_idx :
        Pointer to the node that needs to be updated.
    **args :
        Extra keyword arguments. They are accepted for call compatibility and are
        not used.

    Returns
    -------
    parameters_structure :
        The updated parameters structure.
    node_structure :
        The node structure (currently returned as received).

    See Also
    --------
    update_continuous_input_parents, update_binary_input_parents

    """
    # `create_distribution` and `updates` are only needed by the steps that are not
    # implemented yet; unpacking all four still checks that `H` has the right length
    create_distribution, updates, pdf_distribution, theta = H

    # using the current node index, unwrap parameters
    node_parameters = parameters_structure[node_idx]

    # store value and time step in the node(s) that have this node as value parent
    # (i.e. the input node)
    for idx, node in enumerate(node_structure):
        if node.value_parents is not None and node_idx in node.value_parents:
            parameters_structure[idx]["time_step"] = time_step
            parameters_structure[idx]["value"] = value

    alpha = node_parameters["alpha"]  # concentration value
    cluster_idxs = node_parameters["cluster_idx"]  # list of clusters with nodes indexes

    # number of observations before this one
    cluster_counts = jnp.array(node_parameters["n"])
    n_total = jnp.sum(cluster_counts)

    # probability to draw a new cluster
    pi_new = alpha / (alpha + n_total)

    # joint likelihood for a new cluster given the base distribution
    new_likelihood = pdf_distribution(value=value, theta=theta) * pi_new

    # probability of being assigned to a pre-existing cluster
    # `k` is a traced value here, so the selection has to go through `jnp.where`
    # NOTE(review): 10.0 is not a probability; it looks like a placeholder used
    # while no cluster exists - confirm the intended value
    pi_clusters = jnp.where(
        node_parameters["k"] == 0,
        jnp.array([10.0]),
        cluster_counts / (alpha + n_total),
    )

    # estimate the likelihood of current observation under each value/volatility pairs
    cluster_likelihoods = [
        pdf_distribution(value=value, cluster_idxs=c) for c in cluster_idxs
    ]

    # The number of clusters is known at trace time, so branch in Python. Using
    # `jnp.where` here would broadcast the one-element placeholder against an empty
    # array and silently return an empty array.
    if cluster_likelihoods:
        likelihoods_cluster = jnp.array(cluster_likelihoods)
    else:
        likelihoods_cluster = jnp.array([0.0])

    # joint likelihoods for previous clusters
    likelihoods_cluster = likelihoods_cluster * pi_clusters

    # TODO: the decision steps below are not implemented yet, so `new_likelihood`
    # and `likelihoods_cluster` are currently computed but not used.
    #
    # 1 - Creating a new cluster, when `new_likelihood > likelihoods_cluster.max()`:
    #     - define the parameters of the new distribution (draft: `value, 1.0, -3.0`);
    #     - call `create_distribution(dirichlet_idx=node_idx, node_structure=...,
    #       parameters_structure=..., cluster_idx=..., theta=...)` to get the updated
    #       node structure, parameters structure and cluster indexes;
    #     - append a new count of 1 to `node_parameters["n"]`, store the selected
    #       cluster in `node_parameters["z"]` and increment `node_parameters["k"]`.
    #
    # 2 - Assigning the observation to a pre-existing cluster, otherwise:
    #     - select the cluster with `jnp.argmax(likelihoods_cluster)`;
    #     - build the update sequence of this cluster from `updates` and apply it
    #       with `value` and `time_step` (clusters that are not selected are not
    #       updated);
    #     - increment `node_parameters["n"]` for this cluster and store its index in
    #       `node_parameters["z"]`.

    return parameters_structure, node_structure