-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
647931d
commit 6ff25f8
Showing
2 changed files
with
1,059 additions
and
0 deletions.
There are no files selected for viewing
747 changes: 747 additions & 0 deletions
747
docs/source/notebooks/Example_2_Dirichlet_process.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
--- | ||
jupytext: | ||
formats: ipynb,md:myst | ||
text_representation: | ||
extension: .md | ||
format_name: myst | ||
format_version: 0.13 | ||
jupytext_version: 1.14.5 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
(example_2)= | ||
# Dirichlet nodes | ||
|
||
```{code-cell} ipython3 | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
from pyhgf.model import HGF | ||
from scipy.stats import multinomial | ||
import jax.numpy as jnp | ||
``` | ||
|
||
Cite https://arxiv.org/pdf/2205.10787.pdf | ||
|
||
```{code-cell} ipython3 | ||
# simulate a dataset | ||
n_trials = 1000 | ||
def get_value():
    """Draw one observation from a three-component Gaussian mixture.

    A component is selected with probabilities (.2, .2, .6) through a one-hot
    multinomial draw; the components are centred on 0, 2 and 4 and share a
    standard deviation of .2.

    Returns
    -------
    The value sampled from the selected component.
    """
    # one-hot indicator of the selected component
    one_hot = multinomial(n=1, p=[.2, .2, .6]).rvs()
    # one candidate draw per component; the indicator keeps a single one
    candidates = np.random.normal([0, 2, 4], .2)
    return np.sum(one_hot * candidates)
# one simulated observation per trial
input_data = [get_value() for _ in range(n_trials)]
``` | ||
|
||
```{code-cell} ipython3 | ||
sns.histplot(input_data, alpha=.2, bins=40) | ||
``` | ||
|
||
## Using the standard HGF | ||
|
||
```{code-cell} ipython3 | ||
standard_hgf = HGF( | ||
n_levels=2, | ||
model_type="continuous", | ||
initial_mu={"1": 0.0, "2": 0.0}, | ||
initial_pi={"1": 1e4, "2": 1e1}, | ||
omega={"1": -4.0, "2": -4.0} | ||
) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
standard_hgf.input_data(input_data=input_data) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
standard_hgf.plot_trajectories(); | ||
``` | ||
|
||
## Using a Dirichlet node | ||
|
||
A Dirichlet node has the following parameters: | ||
- $\sigma_\xi$, the default cluster precision. | ||
- $\alpha$, the concentration parameter. | ||
- $\pi_i = \frac{n_i}{\alpha + k - 1}$ | ||
- $k$ the current number of clusters | ||
- $n_i$ the number of observations in cluster $C_i$ | ||
|
||
+++ | ||
|
||
A Dirichlet process over a measurable set $S$ is specified by a base distribution $H$ and a concentration parameter $\alpha$. In pyhgf, the base distribution is a tree probabilistic neural network that can operate over $S$. When a new value is observed, the Dirichlet node will choose between the following alternatives:
|
||
1. Create a new cluster $C_n$ with probability: | ||
- $\frac{\alpha}{\alpha + n - 1}p(x|\mu_n, \sigma_n)$ | ||
- with $\mu_n = x$ and $\sigma_n = \sigma_\xi$ | ||
2. Merge two clusters $j$ and $k$ into a new cluster $l$ if:
- $p(x|\mu_l, \sigma_l)\pi_l < p(x|\mu_k, \sigma_k)\pi_k + p(x|\mu_j, \sigma_j)\pi_j$ | ||
- $\mu_l = \frac{\mu_j + \mu_k}{2}$ | ||
- $\sigma_l = \frac{\sigma_j + \sigma_k}{2}$ | ||
- $\pi_l = \frac{n_j + n_k}{\alpha + k - 1}$ | ||
3. Otherwise, sample from cluster $C_i$ where: | ||
- $p(x|\mu_i, \sigma_i)\pi_i \le p(x|\mu_j, \sigma_j)\pi_j \forall j \in C $. | ||
|
||
+++ | ||
|
||
First we start by defining the base distribution $H$. This requires us to detail how the node can create a new distribution, and how to update this distribution. Those steps are declared in the function `create_distribution` and the variable `cluster_updates` (respectively).
|
||
```{code-cell} ipython3 | ||
from pyhgf.dirichlet import dirichlet_node_update | ||
from pyhgf.structure import add_input_node, add_value_parent, add_volatility_parent, apply_sequence | ||
from jax.scipy.stats import norm | ||
from pyhgf.continuous import continuous_node_update, continuous_input_update | ||
from pyhgf.typing import UpdateSequence, NodeStructure, StandardNode | ||
from typing import List, Tuple | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
def create_distribution(
    dirichlet_idx: int,
    node_structure: NodeStructure,
    parameters_structure,
    cluster_idx: List,
    theta: Tuple
) -> Tuple:
    """Create a new distribution as a branch of the probabilistic network.

    Here we create a two-level continuous HGF: a continuous input node, a value
    parent of that input, and a volatility parent of the value parent. The new
    input node is also registered as a value parent of the Dirichlet node.

    Parameters
    ----------
    dirichlet_idx :
        Index of the Dirichlet node in `node_structure`.
    node_structure :
        The current node structure. The indexes of the new nodes are derived from
        its length, which assumes that each `add_*` call below appends exactly
        one node at the end of the structure.
    parameters_structure :
        The parameters associated with `node_structure`.
    cluster_idx :
        The list of clusters, one `(input, value parent, volatility parent)`
        tuple of node indexes per cluster. It is extended in place.
    theta :
        The parameters defining the new distribution, unpacked as
        `(mu, pi, omega)`. `mu` and `pi` are passed to the value parent and
        `omega` is passed to both the value parent and the volatility parent.

    Returns
    -------
    The updated node structure, the updated parameters structure and the updated
    list of clusters.
    """
    # `Indexes` is needed to rebuild the Dirichlet node entry but is not part of
    # the names imported by the notebook, hence the local import.
    # NOTE(review): assumes `pyhgf.typing` exposes `Indexes` - confirm.
    from pyhgf.typing import Indexes

    # theta is a sequence of parameters defining this distribution
    mu, pi, omega = theta

    val_idx = len(node_structure)  # the continuous input node of the branch
    val2_idx = val_idx + 1  # the value parent of the input node (x1)
    vol_idx = val_idx + 2  # the volatility parent of x1

    # update the cluster list
    cluster_idx.append((val_idx, val2_idx, vol_idx))

    # add a continuous input
    node_structure, parameters_structure = add_input_node(
        kind="continuous",
        node_structure=node_structure,
        parameters_structure=parameters_structure
    )

    # manually set this continuous input as value parent of the Dirichlet node
    structure_as_list = list(node_structure)
    current_parents = structure_as_list[dirichlet_idx].value_parents
    if current_parents is None:
        new_value_parents = (val_idx,)
    else:
        new_value_parents = current_parents + (val_idx,)
    # the second field of the entry is set to None, as before
    structure_as_list[dirichlet_idx] = Indexes(new_value_parents, None)
    node_structure = tuple(structure_as_list)

    # add a value parent
    node_structure, parameters_structure = add_value_parent(
        children_idxs=[val_idx],
        mu=mu,
        pi=pi,
        omega=omega,
        node_structure=node_structure,
        parameters_structure=parameters_structure
    )

    # add a volatility parent
    node_structure, parameters_structure = add_volatility_parent(
        children_idxs=[val2_idx],
        node_structure=node_structure,
        parameters_structure=parameters_structure,
        omega=omega
    )

    return node_structure, parameters_structure, cluster_idx
``` | ||
|
||
Next, we describe the update sequence (how we want to propagate the prediction error when a branch of the Dirichlet node observes a new value). Here, we create a *generic* update sequence as the indexes of the nodes are not fixed, but will depend on the branch that needs to be updated. The indexes of the nodes are saved separately in the `cluster_idx` variable (see above).
|
||
```{code-cell} ipython3 | ||
cluster_updates: UpdateSequence = ( | ||
(None, continuous_input_update), | ||
(None, continuous_node_update), | ||
(None, continuous_node_update) | ||
) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
def pdf_distribution(value, cluster_idxs=None, theta=None):
    """Likelihood of a value under the default distribution or a cluster.

    Parameters
    ----------
    value :
        The observed value.
    cluster_idxs :
        The node indexes of an existing cluster. The node at position 1 provides
        the mean (`"mu"`) and the precision (`"pi"`) of the Gaussian. Only used
        when `theta` is None.
    theta :
        The precision of the default distribution. When provided, the density
        is evaluated for a new cluster centred on `value`.

    Returns
    -------
    The Gaussian density of `value`.
    """
    if theta is not None:
        # new cluster: centred on the observation, with the default precision
        mu, pi = value, theta
    else:
        # existing cluster: read its sufficient statistics from the
        # notebook-level `parameters_structure` so that the density depends on
        # the distance between the value and the cluster mean
        cluster_node = parameters_structure[cluster_idxs[1]]
        mu, pi = cluster_node["mu"], cluster_node["pi"]
    return norm.pdf(x=value, loc=mu, scale=1 / jnp.sqrt(pi))
``` | ||
|
||
```{code-cell} ipython3 | ||
# the default parameter(s)
# here only contains the precision of the base Gaussian distribution
# NOTE(review): `pdf_distribution` uses theta as a scalar precision, whereas
# `create_distribution` unpacks it as `mu, pi, omega = theta` - TODO reconcile.
theta = 1.0
``` | ||
|
||
```{code-cell} ipython3 | ||
# wrap the base distribution | ||
dirichlet_node = create_distribution, cluster_updates, pdf_distribution, theta | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
# create the Dirichlet node structure | ||
dirichlet_hgf = ( | ||
HGF(model_type=None) | ||
.add_input_node(kind="continuous") | ||
.add_dirichlet_parent(children_idx=0, dirichlet_node=dirichlet_node) | ||
) | ||
dirichlet_hgf.plot_network() | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
# An update sequence is a tuple of (node index, update function) pairs, like
# `cluster_updates`. The trailing comma is required: without it the outer
# parentheses only group the expression and the attribute is set to the bare
# pair instead of a one-step sequence.
dirichlet_hgf.update_sequence = (
    (1, dirichlet_node_update),
)
``` | ||
|
||
```{code-cell} ipython3 | ||
dirichlet_hgf.parameters_structure | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
dirichlet_hgf.node_structure | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
parameters_structure, node_structure = dirichlet_node_update( | ||
parameters_structure=dirichlet_hgf.parameters_structure, | ||
node_structure=dirichlet_hgf.node_structure, | ||
value=1, | ||
time_step=1.0, | ||
node_idx=1, | ||
) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
jnp.where( | ||
dirichlet_hgf.parameters_structure[1]["k"] == 0, | ||
jnp.array(1.0), | ||
jnp.array(2.0) | ||
) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
jnp.where(jnp.array(0) == 0, jnp.array(1.0), jnp.array(2.0)) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
dirichlet_hgf.node_structure[1] | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
# Feed every simulated observation to the Dirichlet node. The structures
# returned by one update are written back to the model so that the next
# observation starts from the updated network.
for observation in input_data:
    parameters_structure, node_structure = dirichlet_node_update(
        parameters_structure=dirichlet_hgf.parameters_structure,
        node_structure=dirichlet_hgf.node_structure,
        value=observation,  # pass the observation, not a constant
        time_step=1.0,
        node_idx=1,
    )
    dirichlet_hgf.parameters_structure = parameters_structure
    dirichlet_hgf.node_structure = node_structure
``` | ||
|
||
```{code-cell} ipython3 | ||
parameters_structure = dirichlet_hgf.parameters_structure | ||
node_structure = dirichlet_hgf.node_structure | ||
node_idx = 1 | ||
time_step = 1.0 | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
value = 1.2 | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
# NOTE(review): `n_total` and `value_vector` are not defined in any cell visible
# here; unless they are created elsewhere this cell raises a NameError - TODO
# confirm where they should be initialised.
n_total +=1
value_vector.append(value)
``` | ||
|
||
```{code-cell} ipython3 | ||
parameters_structure | ||
``` | ||
|
||
## Plotting results
|
||
```{code-cell} ipython3 | ||
# Draw the Gaussian density of every cluster, then show the observations as a
# jittered strip slightly below the x axis.
_, axs = plt.subplots(figsize=(12, 6))
for cluster in cluster_idx:
    # the node at position 1 of a cluster holds its mean and precision
    mu = parameters_structure[cluster[1]]["mu"]
    pi = parameters_structure[cluster[1]]["pi"]
    x_grid = np.linspace(-3, 8, 1000)
    density = norm.pdf(x=x_grid, loc=mu, scale=1/np.sqrt(pi))
    axs.plot(x_grid, density)
jitter = np.random.normal(scale=.002, size=len(value_vector))
axs.scatter(x=value_vector, y=jitter-0.1)
``` | ||
|
||
```{code-cell} ipython3 | ||
``` |