From c76360566e12802b6155412c874b04352a4b46e7 Mon Sep 17 00:00:00 2001 From: LegrandNico <nicolas.legrand@cas.au.dk> Date: Wed, 4 Oct 2023 18:02:07 +0200 Subject: [PATCH] update the dirichlet step --- src/pyhgf/model.py | 19 ++- src/pyhgf/networks.py | 55 +++++++++ src/pyhgf/plots.py | 2 + src/pyhgf/updates/dirichlet.py | 203 ++++++++++++--------------------- 4 files changed, 143 insertions(+), 136 deletions(-) diff --git a/src/pyhgf/model.py b/src/pyhgf/model.py index 22017e4c1..4d4d5fe9e 100644 --- a/src/pyhgf/model.py +++ b/src/pyhgf/model.py @@ -462,11 +462,12 @@ def add_input_node( "binary_precision": jnp.inf, }, categorical_parameters: Dict = {"n_categories": 4}, + dirichlet_parameters: Dict = {"alpha": 1.0}, additional_parameters: Optional[Dict] = None, ): """Create an input node. - Three types of input nodes are supported: + Four types of input nodes are supported: - `continuous`: receive a continuous observation as input. The parameter `continuous_precision` is required. @@ -474,6 +475,10 @@ def add_input_node( `binary_precision`, `eta0` and `eta1` are required. - `categorical` receive a boolean array as observation. The parameters are `n_categories` required. + - `dirichlet_process` A Dirichlet process node is a distribution over a network + of node. The type of input and the size of inputs is the one supported by the + base network. This input node has the capability of merging and creating new + networks to fit the complexity of the observation. .. note: When using `categorical`, the implied `n` binary HGFs are automatically @@ -483,8 +488,8 @@ def add_input_node( Parameters ---------- kind : - The kind of input that should be created (can be `"continuous"`, `"binary"` - or `"categorical"`). + The kind of input that should be created (can be `"continuous"`, `"binary"`, + `"categorical"` or `"dirichlet_process"`). input_idxs : The index of the new input (defaults to `0`). continuous_parameters : @@ -504,6 +509,8 @@ def add_input_node( .. note:: When using a categorical state node, the `binary_parameters` can be used to parametrize the implied collection of binray HGFs. + dirichlet_parameters : + Additional parameters for the Dirichlet process node. additional_parameters : Add more custom parameters to the input node. @@ -544,6 +551,12 @@ def add_input_node( ), "value": jnp.zeros(categorical_parameters["n_categories"]), } + elif kind == "dirichlet_process": + input_node_parameters = { + "n": None, + "alpha": dirichlet_parameters["alpha"], + "cluster_idx": None, + } # add more parameters (optional) if additional_parameters is not None: diff --git a/src/pyhgf/networks.py b/src/pyhgf/networks.py index aa8a17097..d75e8bd65 100644 --- a/src/pyhgf/networks.py +++ b/src/pyhgf/networks.py @@ -521,3 +521,58 @@ def to_pandas(hgf: "HGF") -> pd.DataFrame: ].sum(axis=1, min_count=1) return structure_df + + +def concatenate_networks(attributes_1, attributes_2, edges_1, edges_2): + """ "Concatenate two networks. + + Parameters + ---------- + attributes_1 : + + attributes_2 : + + edges_1 : + + edges_2 : + + Returns + ------- + attributes : + + edges : + + """ + n_nodes = len(attributes_2) + edges_1 = list(edges_1) + attributes = {} + for i in range(len(attributes_1)): + # update the attributes + attributes[i + n_nodes] = attributes_1[i] + + # update the edges + edges_1[i] = Indexes( + value_parents=tuple([e + n_nodes for e in list(edges_1[i].value_parents)]) + if edges_1[i].value_parents is not None + else None, + volatility_parents=tuple( + [e + n_nodes for e in list(edges_1[i].volatility_parents)] + ) + if edges_1[i].volatility_parents is not None + else None, + value_children=tuple([e + n_nodes for e in list(edges_1[i].value_children)]) + if edges_1[i].value_children is not None + else None, + volatility_children=tuple( + [e + n_nodes for e in list(edges_1[i].volatility_children)] + ) + if edges_1[i].volatility_children is not None + else None, + ) + + edges_1 = tuple(edges_1) + + attributes = {**attributes_2, **attributes} + edges = edges_2 + edges_1 + + return attributes, edges diff --git a/src/pyhgf/plots.py b/src/pyhgf/plots.py index 52ccde447..52a7afed7 100644 --- a/src/pyhgf/plots.py +++ b/src/pyhgf/plots.py @@ -261,6 +261,8 @@ def plot_network(hgf: "HGF") -> "Source": label, shape = f"Bi-{idx}", "box" elif kind == "categorical": label, shape = f"Ca-{idx}", "diamond" + elif kind == "dirichlet_process": + label, shape = f"DP-{idx}", "doublecircle" graphviz_structure.node( f"x_{idx}", label=label, diff --git a/src/pyhgf/updates/dirichlet.py b/src/pyhgf/updates/dirichlet.py index aaed679b9..702ae5f7b 100644 --- a/src/pyhgf/updates/dirichlet.py +++ b/src/pyhgf/updates/dirichlet.py @@ -1,55 +1,44 @@ # Author: Nicolas Legrand <nicolas.legrand@cas.au.dk> -from functools import partial -from typing import Dict, Tuple +from typing import Callable, Dict, Tuple import jax.numpy as jnp -from jax import jit +from pyhgf.model import HGF +from pyhgf.typing import Edges -@partial(jit, static_argnames=("node_idx", "node_structure")) -def dirichlet_node_update( - parameters_structure: Dict, - node_structure, - H, - value: float, + +# @partial(jit, static_argnames=("edges", "cluster_creation", "child_network")) +def dirichlet_process_node_prediction_error( + edges: Edges, + attributes: Dict, + values: float, time_step: float, node_idx: int, + cluster_creation: Callable, + parametrize_cluster: Callable, + child_network: "HGF", + likelihood_function: Callable, **args, ) -> Tuple: - """Dirichlet process node update. + """Prediction error and update the child networks of a Dirichlet process node. When receiving a new input, this node chose either: - 1. Allocation the value to a pre-existing cluster and updating its value and - volatility parents accordingly. - 2. Creating a new cluster (branching) with value and volatility parents. - 3. Merging clusters. + 1. Allocation the value to a pre-existing cluster. + 2. Creating a new cluster. + 3. Merging two pre-existing clusters. Parameters ---------- - parameters_structure : - The structure of nodes' parameters. Each parameter is a dictionary with the - following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for - continuous nodes. - .. note:: - `"psis"` is the value coupling strength. It should have same length than the - volatility parents' indexes. `"kappas"` is the volatility coupling strength. - It should have same length than the volatility parents' indexes. - node_structure : - Tuple of :py:class:`pyhgf.typing.StandardNode` with same length than number of - node. For each node, the index list value and volatility parents. - H : - Function and parameters controlling the creation and evaluation of the base - distribution (`H`). The following variables are required: - - `create_distribution` : a function creating a new branch starting from the - Dirichlet node. - - `updates` : a sequence of updates to propagate beliefs in the - cluster for a given set of observations. - - `pdf` : a function to evaluate the likelihood of a given observation under - the base distribution. - - `theta` : the distribution parameters. - value : - The new observed value. + edges : + The edges of the probabilistic nodes as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. + For each node, the index list value and volatility parents and children. + attributes : + The attributes of the probabilistic nodes. + values : + The new observed value(s). The shape of the input should match the input shape + at the child level. time_step : Interval between the previous time point and the current time point. node_idx : @@ -58,9 +47,9 @@ def dirichlet_node_update( Returns ------- - parameters_structure : + edges : The updated node structure. - node_structure : + attributes : The updated node structure. See Also @@ -68,111 +57,59 @@ def dirichlet_node_update( update_continuous_input_parents, update_binary_input_parents """ - ( - value_parents, - volatility_parents, - ) = node_structure[node_idx] - create_distribution, updates, pdf_distribution, theta = H + alpha = attributes[node_idx]["alpha"] # concentration parameter + cluster_idxs = attributes[node_idx]["cluster_idx"] # clusters inputs indexes + + ############################################ + # Case 1: no cluster available, create one # + ############################################ + if attributes[node_idx]["cluster_idx"] is None: + # increment the number of observations + attributes[node_idx]["n"] = [1] + attributes[node_idx]["cluster_idx"] = len(attributes) - # using the current node index, unwrap parameters - node_parameters = parameters_structure[node_idx] + # create a new branch + attributes, edges = cluster_creation(values, child_network, attributes, edges) - # store value and time step in the input node - for idx, node in enumerate(node_structure): - if node.value_parents is not None: - if node_idx in node.value_parents: - parameters_structure[idx]["time_step"] = time_step - parameters_structure[idx]["value"] = value + return cluster_creation(values, child_network, attributes, edges) - alpha = node_parameters["alpha"] # concentration value - cluster_idxs = node_parameters["cluster_idx"] # list of clusters with nodes indexes + # - A new cluster -# + # --------------- # # probability to draw a new cluster - n_total = jnp.sum( - jnp.array(node_parameters["n"]) - ) # n of observation before this one + n_total = jnp.sum(jnp.array(attributes[node_idx]["n"])) pi_new = alpha / (alpha + n_total) - # likelihood for new cluster given the base distribution - new_likelihood = pdf_distribution(value=value, theta=theta) + # likelihood for a new cluster given the base network + new_attributes, new_edges = parametrize_cluster(values, child_network) + new_likelihood = likelihood_function(0, values, new_attributes, new_edges) # joint likelihoods for a new clusters new_likelihood *= pi_new - # probability of being assignet to pre-existing cluster - pi_clusters = jnp.where( - node_parameters["k"] == 0, - jnp.array([10.0]), - jnp.array(node_parameters["n"]) / (alpha + n_total), - ) - from jax.debug import print as pt - - pt("{x}", x=pi_clusters) - # estimate the likelihood of current observation under each value/volatility pairs - likelihoods_cluster = [ - pdf_distribution( - value=value, - cluster_idxs=c, + # - An existing cluster -# + # --------------------- # + + # likelihood for an existing cluster given the base network + clusters_likelihood = [] + for input_idx in cluster_idxs: + clusters_likelihood.append( + likelihood_function(input_idx, values, attributes, edges) ) - for c in cluster_idxs - ] - likelihoods_cluster = jnp.where( - len(likelihoods_cluster) == 0, - jnp.array([0.0], ndmin=1), - jnp.array(likelihoods_cluster), - ) + + # probability of being assignet to pre-existing cluster + pi_clusters = jnp.array(attributes[node_idx]["n"]) / (alpha + n_total) # joint likelihoods for previous clusters - likelihoods_cluster *= pi_clusters - # ############################ - # # 1- Creating a new cluser # - # ############################ - # if new_likelihood > likelihoods_cluster.max(): - # print("Creating a new cluster") - - # # define the parameters for the new distribution - # theta = value, 1.0, -3.0 - - # # create a new distribution and update the cluster index accordingly - # node_structure, parameters_structure, cluster_idx = create_distribution( - # dirichlet_idx=node_idx, - # node_structure=node_structure, - # parameters_structure=parameters_structure, - # cluster_idx=cluster_idx, - # theta=theta, - # ) - - # # append a new cluster count - # node_parameters["n"].append(1) - # node_parameters["z"] = idx - # node_parameters["k"] += 1 # number of clusters - - # ####################################################### - # # 2 - Assigning observation to a pre-existing cluster # - # ####################################################### - - # # pass the value to the selected cluster and update each branch of the network - # # pass None to clusters that are not updated - # if new_likelihood <= jnp.array(likelihoods_cluster).max(): - # # which cluster should be updated - # idx = jnp.argmax(likelihoods_cluster) - # c_idx = cluster_idx[idx] # the cluster nodes - - # print(f"Assigning value to cluster {idx}") - - # # create the update sequence for this distribution - # update_sequence = tuple([(c, f[1]) for c, f in zip(c_idx, cluster_updates)]) - - # parameters_structure = apply_sequence( - # value=value, - # time_step=time_step, - # parameters_structure=parameters_structure, - # node_structure=node_structure, - # update_sequence=update_sequence, - # ) - - # # update the number of observation in this cluster - # node_parameters["n"][idx] += 1 - # node_parameters["z"] = idx - - return parameters_structure, node_structure + clusters_likelihood *= pi_clusters + + print(clusters_likelihood) + ################################# + # Case 2: Creating a new cluser # + ################################# + + ########################################################### + # Case 3: Assigning observation to a pre-existing cluster # + ########################################################### + + return edges, attributes