From c76360566e12802b6155412c874b04352a4b46e7 Mon Sep 17 00:00:00 2001
From: LegrandNico <nicolas.legrand@cas.au.dk>
Date: Wed, 4 Oct 2023 18:02:07 +0200
Subject: [PATCH] update the dirichlet step

---
 src/pyhgf/model.py             |  19 ++-
 src/pyhgf/networks.py          |  55 +++++++++
 src/pyhgf/plots.py             |   2 +
 src/pyhgf/updates/dirichlet.py | 203 ++++++++++++---------------------
 4 files changed, 143 insertions(+), 136 deletions(-)

diff --git a/src/pyhgf/model.py b/src/pyhgf/model.py
index 22017e4c1..4d4d5fe9e 100644
--- a/src/pyhgf/model.py
+++ b/src/pyhgf/model.py
@@ -462,11 +462,12 @@ def add_input_node(
             "binary_precision": jnp.inf,
         },
         categorical_parameters: Dict = {"n_categories": 4},
+        dirichlet_parameters: Dict = {"alpha": 1.0},
         additional_parameters: Optional[Dict] = None,
     ):
         """Create an input node.
 
-        Three types of input nodes are supported:
+        Four types of input nodes are supported:
 
         - `continuous`: receive a continuous observation as input. The parameter
         `continuous_precision` is required.
@@ -474,6 +475,10 @@ def add_input_node(
         `binary_precision`, `eta0` and `eta1` are required.
         - `categorical` receive a boolean array as observation. The parameters are
         `n_categories` required.
+        - `dirichlet_process` A Dirichlet process node is a distribution over a network
+        of node. The type of input and the size of inputs is the one supported by the
+        base network. This input node has the capability of merging and creating new
+        networks to fit the complexity of the observation.
 
         .. note:
            When using `categorical`, the implied `n` binary HGFs are automatically
@@ -483,8 +488,8 @@ def add_input_node(
         Parameters
         ----------
         kind :
-            The kind of input that should be created (can be `"continuous"`, `"binary"`
-            or `"categorical"`).
+            The kind of input that should be created (can be `"continuous"`, `"binary"`,
+            `"categorical"` or `"dirichlet_process"`).
         input_idxs :
             The index of the new input (defaults to `0`).
         continuous_parameters :
@@ -504,6 +509,8 @@ def add_input_node(
             .. note::
                When using a categorical state node, the `binary_parameters` can be used
                to parametrize the implied collection of binray HGFs.
+        dirichlet_parameters :
+            Additional parameters for the Dirichlet process node.
         additional_parameters :
             Add more custom parameters to the input node.
 
@@ -544,6 +551,12 @@ def add_input_node(
                 ),
                 "value": jnp.zeros(categorical_parameters["n_categories"]),
             }
+        elif kind == "dirichlet_process":
+            input_node_parameters = {
+                "n": None,
+                "alpha": dirichlet_parameters["alpha"],
+                "cluster_idx": None,
+            }
 
         # add more parameters (optional)
         if additional_parameters is not None:
diff --git a/src/pyhgf/networks.py b/src/pyhgf/networks.py
index aa8a17097..d75e8bd65 100644
--- a/src/pyhgf/networks.py
+++ b/src/pyhgf/networks.py
@@ -521,3 +521,58 @@ def to_pandas(hgf: "HGF") -> pd.DataFrame:
     ].sum(axis=1, min_count=1)
 
     return structure_df
+
+
+def concatenate_networks(attributes_1, attributes_2, edges_1, edges_2):
+    """ "Concatenate two networks.
+
+    Parameters
+    ----------
+    attributes_1 :
+
+    attributes_2 :
+
+    edges_1 :
+
+    edges_2 :
+
+    Returns
+    -------
+    attributes :
+
+    edges :
+
+    """
+    n_nodes = len(attributes_2)
+    edges_1 = list(edges_1)
+    attributes = {}
+    for i in range(len(attributes_1)):
+        # update the attributes
+        attributes[i + n_nodes] = attributes_1[i]
+
+        # update the edges
+        edges_1[i] = Indexes(
+            value_parents=tuple([e + n_nodes for e in list(edges_1[i].value_parents)])
+            if edges_1[i].value_parents is not None
+            else None,
+            volatility_parents=tuple(
+                [e + n_nodes for e in list(edges_1[i].volatility_parents)]
+            )
+            if edges_1[i].volatility_parents is not None
+            else None,
+            value_children=tuple([e + n_nodes for e in list(edges_1[i].value_children)])
+            if edges_1[i].value_children is not None
+            else None,
+            volatility_children=tuple(
+                [e + n_nodes for e in list(edges_1[i].volatility_children)]
+            )
+            if edges_1[i].volatility_children is not None
+            else None,
+        )
+
+    edges_1 = tuple(edges_1)
+
+    attributes = {**attributes_2, **attributes}
+    edges = edges_2 + edges_1
+
+    return attributes, edges
diff --git a/src/pyhgf/plots.py b/src/pyhgf/plots.py
index 52ccde447..52a7afed7 100644
--- a/src/pyhgf/plots.py
+++ b/src/pyhgf/plots.py
@@ -261,6 +261,8 @@ def plot_network(hgf: "HGF") -> "Source":
             label, shape = f"Bi-{idx}", "box"
         elif kind == "categorical":
             label, shape = f"Ca-{idx}", "diamond"
+        elif kind == "dirichlet_process":
+            label, shape = f"DP-{idx}", "doublecircle"
         graphviz_structure.node(
             f"x_{idx}",
             label=label,
diff --git a/src/pyhgf/updates/dirichlet.py b/src/pyhgf/updates/dirichlet.py
index aaed679b9..702ae5f7b 100644
--- a/src/pyhgf/updates/dirichlet.py
+++ b/src/pyhgf/updates/dirichlet.py
@@ -1,55 +1,44 @@
 # Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
 
-from functools import partial
-from typing import Dict, Tuple
+from typing import Callable, Dict, Tuple
 
 import jax.numpy as jnp
-from jax import jit
 
+from pyhgf.model import HGF
+from pyhgf.typing import Edges
 
-@partial(jit, static_argnames=("node_idx", "node_structure"))
-def dirichlet_node_update(
-    parameters_structure: Dict,
-    node_structure,
-    H,
-    value: float,
+
+# @partial(jit, static_argnames=("edges", "cluster_creation", "child_network"))
+def dirichlet_process_node_prediction_error(
+    edges: Edges,
+    attributes: Dict,
+    values: float,
     time_step: float,
     node_idx: int,
+    cluster_creation: Callable,
+    parametrize_cluster: Callable,
+    child_network: "HGF",
+    likelihood_function: Callable,
     **args,
 ) -> Tuple:
-    """Dirichlet process node update.
+    """Prediction error and update the child networks of a Dirichlet process node.
 
     When receiving a new input, this node chose either:
-    1. Allocation the value to a pre-existing cluster and updating its value and
-    volatility parents accordingly.
-    2. Creating a new cluster (branching) with value and volatility parents.
-    3. Merging clusters.
+    1. Allocation the value to a pre-existing cluster.
+    2. Creating a new cluster.
+    3. Merging two pre-existing clusters.
 
     Parameters
     ----------
-    parameters_structure :
-        The structure of nodes' parameters. Each parameter is a dictionary with the
-        following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
-        continuous nodes.
-    .. note::
-        `"psis"` is the value coupling strength. It should have same length than the
-        volatility parents' indexes. `"kappas"` is the volatility coupling strength.
-        It should have same length than the volatility parents' indexes.
-    node_structure :
-        Tuple of :py:class:`pyhgf.typing.StandardNode` with same length than number of
-        node. For each node, the index list value and volatility parents.
-    H :
-        Function and parameters controlling the creation and evaluation of the base
-        distribution (`H`). The following variables are required:
-        - `create_distribution` : a function creating a new branch starting from the
-        Dirichlet node.
-        - `updates` : a sequence of updates to propagate beliefs in the
-        cluster for a given set of observations.
-        - `pdf` : a function to evaluate the likelihood of a given observation under
-        the base distribution.
-        - `theta` : the distribution parameters.
-    value :
-        The new observed value.
+    edges :
+        The edges of the probabilistic nodes as a tuple of
+        :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
+        For each node, the index list value and volatility parents and children.
+    attributes :
+        The attributes of the probabilistic nodes.
+    values :
+        The new observed value(s). The shape of the input should match the input shape
+        at the child level.
     time_step :
         Interval between the previous time point and the current time point.
     node_idx :
@@ -58,9 +47,9 @@ def dirichlet_node_update(
 
     Returns
     -------
-    parameters_structure :
+    edges :
         The updated node structure.
-    node_structure :
+    attributes :
         The updated node structure.
 
     See Also
@@ -68,111 +57,59 @@ def dirichlet_node_update(
     update_continuous_input_parents, update_binary_input_parents
 
     """
-    (
-        value_parents,
-        volatility_parents,
-    ) = node_structure[node_idx]
-    create_distribution, updates, pdf_distribution, theta = H
+    alpha = attributes[node_idx]["alpha"]  # concentration parameter
+    cluster_idxs = attributes[node_idx]["cluster_idx"]  # clusters inputs indexes
+
+    ############################################
+    # Case 1: no cluster available, create one #
+    ############################################
+    if attributes[node_idx]["cluster_idx"] is None:
+        # increment the number of observations
+        attributes[node_idx]["n"] = [1]
+        attributes[node_idx]["cluster_idx"] = len(attributes)
 
-    # using the current node index, unwrap parameters
-    node_parameters = parameters_structure[node_idx]
+        # create a new branch
+        attributes, edges = cluster_creation(values, child_network, attributes, edges)
 
-    # store value and time step in the input node
-    for idx, node in enumerate(node_structure):
-        if node.value_parents is not None:
-            if node_idx in node.value_parents:
-                parameters_structure[idx]["time_step"] = time_step
-                parameters_structure[idx]["value"] = value
+        return cluster_creation(values, child_network, attributes, edges)
 
-    alpha = node_parameters["alpha"]  # concentration value
-    cluster_idxs = node_parameters["cluster_idx"]  # list of clusters with nodes indexes
+    # - A new cluster -#
+    # --------------- #
 
     # probability to draw a new cluster
-    n_total = jnp.sum(
-        jnp.array(node_parameters["n"])
-    )  # n of observation before this one
+    n_total = jnp.sum(jnp.array(attributes[node_idx]["n"]))
     pi_new = alpha / (alpha + n_total)
 
-    # likelihood for new cluster given the base distribution
-    new_likelihood = pdf_distribution(value=value, theta=theta)
+    # likelihood for a new cluster given the base network
+    new_attributes, new_edges = parametrize_cluster(values, child_network)
+    new_likelihood = likelihood_function(0, values, new_attributes, new_edges)
 
     # joint likelihoods for a new clusters
     new_likelihood *= pi_new
 
-    # probability of being assignet to pre-existing cluster
-    pi_clusters = jnp.where(
-        node_parameters["k"] == 0,
-        jnp.array([10.0]),
-        jnp.array(node_parameters["n"]) / (alpha + n_total),
-    )
-    from jax.debug import print as pt
-
-    pt("{x}", x=pi_clusters)
-    # estimate the likelihood of current observation under each value/volatility pairs
-    likelihoods_cluster = [
-        pdf_distribution(
-            value=value,
-            cluster_idxs=c,
+    # - An existing cluster -#
+    # --------------------- #
+
+    # likelihood for an existing cluster given the base network
+    clusters_likelihood = []
+    for input_idx in cluster_idxs:
+        clusters_likelihood.append(
+            likelihood_function(input_idx, values, attributes, edges)
         )
-        for c in cluster_idxs
-    ]
-    likelihoods_cluster = jnp.where(
-        len(likelihoods_cluster) == 0,
-        jnp.array([0.0], ndmin=1),
-        jnp.array(likelihoods_cluster),
-    )
+
+    # probability of being assignet to pre-existing cluster
+    pi_clusters = jnp.array(attributes[node_idx]["n"]) / (alpha + n_total)
 
     # joint likelihoods for previous clusters
-    likelihoods_cluster *= pi_clusters
-    # ############################
-    # # 1- Creating a new cluser #
-    # ############################
-    # if new_likelihood > likelihoods_cluster.max():
-    #     print("Creating a new cluster")
-
-    #     # define the parameters for the new distribution
-    #     theta = value, 1.0, -3.0
-
-    #     # create a new distribution and update the cluster index accordingly
-    #     node_structure, parameters_structure, cluster_idx = create_distribution(
-    #         dirichlet_idx=node_idx,
-    #         node_structure=node_structure,
-    #         parameters_structure=parameters_structure,
-    #         cluster_idx=cluster_idx,
-    #         theta=theta,
-    #     )
-
-    #     # append a new cluster count
-    #     node_parameters["n"].append(1)
-    #     node_parameters["z"] = idx
-    #     node_parameters["k"] += 1  # number of clusters
-
-    # #######################################################
-    # # 2 - Assigning observation to a pre-existing cluster #
-    # #######################################################
-
-    # # pass the value to the selected cluster and update each branch of the network
-    # # pass None to clusters that are not updated
-    # if new_likelihood <= jnp.array(likelihoods_cluster).max():
-    #     # which cluster should be updated
-    #     idx = jnp.argmax(likelihoods_cluster)
-    #     c_idx = cluster_idx[idx]  # the cluster nodes
-
-    #     print(f"Assigning value to cluster {idx}")
-
-    #     # create the update sequence for this distribution
-    #     update_sequence = tuple([(c, f[1]) for c, f in zip(c_idx, cluster_updates)])
-
-    #     parameters_structure = apply_sequence(
-    #         value=value,
-    #         time_step=time_step,
-    #         parameters_structure=parameters_structure,
-    #         node_structure=node_structure,
-    #         update_sequence=update_sequence,
-    #     )
-
-    #     # update the number of observation in this cluster
-    #     node_parameters["n"][idx] += 1
-    #     node_parameters["z"] = idx
-
-    return parameters_structure, node_structure
+    clusters_likelihood *= pi_clusters
+
+    print(clusters_likelihood)
+    #################################
+    # Case 2: Creating a new cluser #
+    #################################
+
+    ###########################################################
+    # Case 3: Assigning observation to a pre-existing cluster #
+    ###########################################################
+
+    return edges, attributes