Skip to content

Commit

Permalink
update the dirichlet step
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 4, 2023
1 parent 72cc684 commit c763605
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 136 deletions.
19 changes: 16 additions & 3 deletions src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,23 @@ def add_input_node(
"binary_precision": jnp.inf,
},
categorical_parameters: Dict = {"n_categories": 4},
dirichlet_parameters: Dict = {"alpha": 1.0},
additional_parameters: Optional[Dict] = None,
):
"""Create an input node.
Three types of input nodes are supported:
Four types of input nodes are supported:
- `continuous`: receive a continuous observation as input. The parameter
`continuous_precision` is required.
- `binary` receive a single boolean as observation. The parameters
`binary_precision`, `eta0` and `eta1` are required.
- `categorical` receive a boolean array as observation. The parameter
`n_categories` is required.
- `dirichlet_process`: a Dirichlet process node is a distribution over networks
of nodes. The type and the size of the inputs are those supported by the
base network. This input node has the capability of merging and creating new
networks to fit the complexity of the observations.
.. note::
When using `categorical`, the implied `n` binary HGFs are automatically
Expand All @@ -483,8 +488,8 @@ def add_input_node(
Parameters
----------
kind :
The kind of input that should be created (can be `"continuous"`, `"binary"`
or `"categorical"`).
The kind of input that should be created (can be `"continuous"`, `"binary"`,
`"categorical"` or `"dirichlet_process"`).
input_idxs :
The index of the new input (defaults to `0`).
continuous_parameters :
Expand All @@ -504,6 +509,8 @@ def add_input_node(
.. note::
When using a categorical state node, the `binary_parameters` can be used
to parametrize the implied collection of binary HGFs.
dirichlet_parameters :
Additional parameters for the Dirichlet process node.
additional_parameters :
Add more custom parameters to the input node.
Expand Down Expand Up @@ -544,6 +551,12 @@ def add_input_node(
),
"value": jnp.zeros(categorical_parameters["n_categories"]),
}
elif kind == "dirichlet_process":
input_node_parameters = {
"n": None,
"alpha": dirichlet_parameters["alpha"],
"cluster_idx": None,
}

# add more parameters (optional)
if additional_parameters is not None:
Expand Down
55 changes: 55 additions & 0 deletions src/pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,58 @@ def to_pandas(hgf: "HGF") -> pd.DataFrame:
].sum(axis=1, min_count=1)

return structure_df


def concatenate_networks(attributes_1, attributes_2, edges_1, edges_2):
    """Concatenate two networks.

    The nodes of the second network keep their indexes. The nodes of the first
    network are placed after them: each node index of the first network, and each
    parent/child index stored in its edges, is shifted by the number of nodes in the
    second network.

    Parameters
    ----------
    attributes_1 :
        The attributes of the first network, as a dictionary keyed by node index
        (from `0` to `len(attributes_1) - 1`).
    attributes_2 :
        The attributes of the second network, as a dictionary keyed by node index.
    edges_1 :
        The edges of the first network, one entry per node. Each entry exposes
        `value_parents`, `volatility_parents`, `value_children` and
        `volatility_children`, each being either `None` or an iterable of node
        indexes.
    edges_2 :
        The edges of the second network, one entry per node (tuple or list).

    Returns
    -------
    attributes :
        A new dictionary containing the attributes of the second network followed
        by the re-indexed attributes of the first network.
    edges :
        A tuple containing the edges of the second network followed by the
        re-indexed edges of the first network.

    """
    n_nodes = len(attributes_2)

    def _shift(indexes):
        """Offset node indexes by `n_nodes`, keeping unset (`None`) entries."""
        if indexes is None:
            return None
        return tuple(idx + n_nodes for idx in indexes)

    # start from a copy of the second network so the inputs are never mutated
    attributes = dict(attributes_2)
    shifted_edges = list(edges_1)

    for i in range(len(attributes_1)):
        # re-index the attributes of the first network
        attributes[i + n_nodes] = attributes_1[i]

        # re-index every connection of this node
        node_edges = shifted_edges[i]
        shifted_edges[i] = Indexes(
            value_parents=_shift(node_edges.value_parents),
            volatility_parents=_shift(node_edges.volatility_parents),
            value_children=_shift(node_edges.value_children),
            volatility_children=_shift(node_edges.volatility_children),
        )

    # coerce both sides so that a list of edges can be concatenated as well
    edges = tuple(edges_2) + tuple(shifted_edges)

    return attributes, edges
2 changes: 2 additions & 0 deletions src/pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def plot_network(hgf: "HGF") -> "Source":
label, shape = f"Bi-{idx}", "box"
elif kind == "categorical":
label, shape = f"Ca-{idx}", "diamond"
elif kind == "dirichlet_process":
label, shape = f"DP-{idx}", "doublecircle"
graphviz_structure.node(
f"x_{idx}",
label=label,
Expand Down
203 changes: 70 additions & 133 deletions src/pyhgf/updates/dirichlet.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,44 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict, Tuple
from typing import Callable, Dict, Tuple

import jax.numpy as jnp
from jax import jit

from pyhgf.model import HGF
from pyhgf.typing import Edges

@partial(jit, static_argnames=("node_idx", "node_structure"))
def dirichlet_node_update(
parameters_structure: Dict,
node_structure,
H,
value: float,

# @partial(jit, static_argnames=("edges", "cluster_creation", "child_network"))
def dirichlet_process_node_prediction_error(
edges: Edges,
attributes: Dict,
values: float,
time_step: float,
node_idx: int,
cluster_creation: Callable,
parametrize_cluster: Callable,
child_network: "HGF",
likelihood_function: Callable,
**args,
) -> Tuple:
"""Dirichlet process node update.
"""Prediction error and update the child networks of a Dirichlet process node.
When receiving a new input, this node chooses either:
1. Allocating the value to a pre-existing cluster and updating its value and
volatility parents accordingly.
2. Creating a new cluster (branching) with value and volatility parents.
3. Merging clusters.
1. Allocating the value to a pre-existing cluster.
2. Creating a new cluster.
3. Merging two pre-existing clusters.
Parameters
----------
parameters_structure :
The structure of nodes' parameters. Each parameter is a dictionary with the
following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
continuous nodes.
.. note::
`"psis"` is the value coupling strength. It should have same length than the
volatility parents' indexes. `"kappas"` is the volatility coupling strength.
It should have same length than the volatility parents' indexes.
node_structure :
Tuple of :py:class:`pyhgf.typing.StandardNode` with same length than number of
node. For each node, the index list value and volatility parents.
H :
Function and parameters controlling the creation and evaluation of the base
distribution (`H`). The following variables are required:
- `create_distribution` : a function creating a new branch starting from the
Dirichlet node.
- `updates` : a sequence of updates to propagate beliefs in the
cluster for a given set of observations.
- `pdf` : a function to evaluate the likelihood of a given observation under
the base distribution.
- `theta` : the distribution parameters.
value :
The new observed value.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
attributes :
The attributes of the probabilistic nodes.
values :
The new observed value(s). The shape of the input should match the input shape
at the child level.
time_step :
Interval between the previous time point and the current time point.
node_idx :
Expand All @@ -58,121 +47,69 @@ def dirichlet_node_update(
Returns
-------
parameters_structure :
edges :
The updated node structure.
node_structure :
attributes :
The updated node structure.
See Also
--------
update_continuous_input_parents, update_binary_input_parents
"""
(
value_parents,
volatility_parents,
) = node_structure[node_idx]
create_distribution, updates, pdf_distribution, theta = H
alpha = attributes[node_idx]["alpha"] # concentration parameter
cluster_idxs = attributes[node_idx]["cluster_idx"] # clusters inputs indexes

############################################
# Case 1: no cluster available, create one #
############################################
if attributes[node_idx]["cluster_idx"] is None:
# increment the number of observations
attributes[node_idx]["n"] = [1]
attributes[node_idx]["cluster_idx"] = len(attributes)

# using the current node index, unwrap parameters
node_parameters = parameters_structure[node_idx]
# create a new branch
attributes, edges = cluster_creation(values, child_network, attributes, edges)

# store value and time step in the input node
for idx, node in enumerate(node_structure):
if node.value_parents is not None:
if node_idx in node.value_parents:
parameters_structure[idx]["time_step"] = time_step
parameters_structure[idx]["value"] = value
return cluster_creation(values, child_network, attributes, edges)

alpha = node_parameters["alpha"] # concentration value
cluster_idxs = node_parameters["cluster_idx"] # list of clusters with nodes indexes
# - A new cluster -#
# --------------- #

# probability to draw a new cluster
n_total = jnp.sum(
jnp.array(node_parameters["n"])
) # n of observation before this one
n_total = jnp.sum(jnp.array(attributes[node_idx]["n"]))
pi_new = alpha / (alpha + n_total)

# likelihood for new cluster given the base distribution
new_likelihood = pdf_distribution(value=value, theta=theta)
# likelihood for a new cluster given the base network
new_attributes, new_edges = parametrize_cluster(values, child_network)
new_likelihood = likelihood_function(0, values, new_attributes, new_edges)

# joint likelihoods for a new clusters
new_likelihood *= pi_new

# probability of being assignet to pre-existing cluster
pi_clusters = jnp.where(
node_parameters["k"] == 0,
jnp.array([10.0]),
jnp.array(node_parameters["n"]) / (alpha + n_total),
)
from jax.debug import print as pt

pt("{x}", x=pi_clusters)
# estimate the likelihood of current observation under each value/volatility pairs
likelihoods_cluster = [
pdf_distribution(
value=value,
cluster_idxs=c,
# - An existing cluster -#
# --------------------- #

# likelihood for an existing cluster given the base network
clusters_likelihood = []
for input_idx in cluster_idxs:
clusters_likelihood.append(
likelihood_function(input_idx, values, attributes, edges)
)
for c in cluster_idxs
]
likelihoods_cluster = jnp.where(
len(likelihoods_cluster) == 0,
jnp.array([0.0], ndmin=1),
jnp.array(likelihoods_cluster),
)

# probability of being assigned to a pre-existing cluster
pi_clusters = jnp.array(attributes[node_idx]["n"]) / (alpha + n_total)

# joint likelihoods for previous clusters
likelihoods_cluster *= pi_clusters
# ############################
# # 1- Creating a new cluser #
# ############################
# if new_likelihood > likelihoods_cluster.max():
# print("Creating a new cluster")

# # define the parameters for the new distribution
# theta = value, 1.0, -3.0

# # create a new distribution and update the cluster index accordingly
# node_structure, parameters_structure, cluster_idx = create_distribution(
# dirichlet_idx=node_idx,
# node_structure=node_structure,
# parameters_structure=parameters_structure,
# cluster_idx=cluster_idx,
# theta=theta,
# )

# # append a new cluster count
# node_parameters["n"].append(1)
# node_parameters["z"] = idx
# node_parameters["k"] += 1 # number of clusters

# #######################################################
# # 2 - Assigning observation to a pre-existing cluster #
# #######################################################

# # pass the value to the selected cluster and update each branch of the network
# # pass None to clusters that are not updated
# if new_likelihood <= jnp.array(likelihoods_cluster).max():
# # which cluster should be updated
# idx = jnp.argmax(likelihoods_cluster)
# c_idx = cluster_idx[idx] # the cluster nodes

# print(f"Assigning value to cluster {idx}")

# # create the update sequence for this distribution
# update_sequence = tuple([(c, f[1]) for c, f in zip(c_idx, cluster_updates)])

# parameters_structure = apply_sequence(
# value=value,
# time_step=time_step,
# parameters_structure=parameters_structure,
# node_structure=node_structure,
# update_sequence=update_sequence,
# )

# # update the number of observation in this cluster
# node_parameters["n"][idx] += 1
# node_parameters["z"] = idx

return parameters_structure, node_structure
clusters_likelihood *= pi_clusters

print(clusters_likelihood)
##################################
# Case 2: Creating a new cluster #
##################################

###########################################################
# Case 3: Assigning observation to a pre-existing cluster #
###########################################################

return edges, attributes

0 comments on commit c763605

Please sign in to comment.