Skip to content

Commit

Permalink
notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 19, 2023
1 parent 647931d commit 6ff25f8
Show file tree
Hide file tree
Showing 2 changed files with 1,059 additions and 0 deletions.
747 changes: 747 additions & 0 deletions docs/source/notebooks/Example_2_Dirichlet_process.ipynb

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions docs/source/notebooks/Example_2_Dirichlet_process.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.5
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

(example_2)=
# Dirichlet nodes

```{code-cell} ipython3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pyhgf.model import HGF
from scipy.stats import multinomial
import jax.numpy as jnp
```

Cite https://arxiv.org/pdf/2205.10787.pdf

```{code-cell} ipython3
# simulate a dataset
n_trials = 1000
def get_value():
cat = multinomial(n=1, p=[.2, .2, .6]).rvs()
vals = np.random.normal([0, 2, 4], .2)
return (cat * vals).sum()
input_data = []
for i in range(n_trials):
input_data.append(get_value())
```

```{code-cell} ipython3
sns.histplot(input_data, alpha=.2, bins=40)
```

## Using the standard HGF

```{code-cell} ipython3
standard_hgf = HGF(
n_levels=2,
model_type="continuous",
initial_mu={"1": 0.0, "2": 0.0},
initial_pi={"1": 1e4, "2": 1e1},
omega={"1": -4.0, "2": -4.0}
)
```

```{code-cell} ipython3
standard_hgf.input_data(input_data=input_data)
```

```{code-cell} ipython3
standard_hgf.plot_trajectories();
```

## Using a Dirichlet node

A Dirichlet node has the following parameters:
- $\sigma_\xi$, the default cluster precision.
- $\alpha$, the concentration parameter.
- $\pi_i = \frac{n_i}{\alpha + k - 1}$
- $k$ the current number of clusters
- $n_i$ the number of observations in cluster $C_i$

+++

A Dirichlet process over a measurable set $S$ is specifided by a base distribution $H$ and a concentration parameter $\alpha$. In pyhgf, the base distribution is a tree probabilistic neural network that can operate over $S$. When a new set is observed, the Dirichlet node will chose between the following alternatives:

1. Create a new cluster $C_n$ with probability:
- $\frac{\alpha}{\alpha + n - 1}p(x|\mu_n, \sigma_n)$
- with $\mu_n = x$ and $\sigma_n = \sigma_\xi$
2. Merge two cluster $j$ and $k$ into a new cluster $l$ if:
- $p(x|\mu_l, \sigma_l)\pi_l < p(x|\mu_k, \sigma_k)\pi_k + p(x|\mu_j, \sigma_j)\pi_j$
- $\mu_l = \frac{\mu_j + \mu_k}{2}$
- $\sigma_l = \frac{\sigma_j + \sigma_k}{2}$
- $\pi_l = \frac{n_j + n_k}{\alpha + k - 1}$
3. Otherwise, sample from cluster $C_i$ where:
- $p(x|\mu_i, \sigma_i)\pi_i \le p(x|\mu_j, \sigma_j)\pi_j \forall j \in C $.

+++

First we start by defining the base distribution $H$. This requires use to detail how the node can create a new distribution, and how to update this distribution. Those steps are declared in the function `base_distribution` and the variable `cluster_updates` (respectively).

```{code-cell} ipython3
from pyhgf.dirichlet import dirichlet_node_update
from pyhgf.structure import add_input_node, add_value_parent, add_volatility_parent, apply_sequence
from jax.scipy.stats import norm
from pyhgf.continuous import continuous_node_update, continuous_input_update
from pyhgf.typing import UpdateSequence, NodeStructure, StandardNode
from typing import List, Tuple
```

```{code-cell} ipython3
def create_distribution(
dirichlet_idx: int,
node_structure: NodeStructure,
parameters_structure,
cluster_idx: List,
theta: Tuple
):
"""Create a new distribution as a branch of the probabilistic network.
Here we create a two-level continuous HGF.
"""
# theta is a list of parameters defining this distribution
# here, the mean and precision of the first node
mu, pi, omega = theta
val_idx = len(node_structure) # the first value parent (input node)
val2_idx = len(node_structure) + 1 # the second value parent (x1)
vol_idx = val_idx + 2 # the volatility parent
# update the cluster list
cluster_idx.append((val_idx, val2_idx, vol_idx))
# add a continuous input
node_structure, parameters_structure = add_input_node(
kind="continuous",
node_structure=node_structure,
parameters_structure=parameters_structure
)
# manually set this continuous input as value parent of the Dirichlet node
structure_as_list = list(node_structure)
if structure_as_list[dirichlet_idx].value_parents is None:
new_value_parents = (val_idx,)
else:
new_value_parents = structure_as_list[dirichlet_idx].value_parents
new_value_parents += (val_idx,)
structure_as_list[dirichlet_idx] = Indexes(new_value_parents, None)
node_structure = tuple(structure_as_list)
# add a value parent
node_structure, parameters_structure = add_value_parent(
children_idxs=[val_idx],
mu=mu,
pi=pi,
omega=omega,
node_structure=node_structure,
parameters_structure=parameters_structure
)
# add a volatility parent
node_structure, parameters_structure = add_volatility_parent(
children_idxs=[val2_idx],
node_structure=node_structure,
parameters_structure=parameters_structure,
omega=omega
)
return node_structure, parameters_structure, cluster_idx
```

Next, we describe the update sequence (how we want to propagate the prediction error when a branch of the Dirichlet node observe a new value). Here, we create a *generic* update sequence as the indexes of the nodes are not fixed, but will depends on the branch that needs to be updated. The indexes of the nodes is saved separately in the `cluster_idx` variable (see above).

```{code-cell} ipython3
cluster_updates: UpdateSequence = (
(None, continuous_input_update),
(None, continuous_node_update),
(None, continuous_node_update)
)
```

```{code-cell} ipython3
def pdf_distribution(value, cluster_idxs=None, theta=None):
if theta is not None:
pi = theta
else:
pi = parameters_structure[cluster_idxs[1]]["pi"]
"""Likelihood function for a value under the default distribution"""
return norm.pdf(x=value, loc=value, scale=1/jnp.sqrt(pi))
```

```{code-cell} ipython3
# the default parameter(s)
# here only contains the precision of the base Gaussian distribution
theta = 1.0
```

```{code-cell} ipython3
# wrap the base distribution
dirichlet_node = create_distribution, cluster_updates, pdf_distribution, theta
```

```{code-cell} ipython3
# create the Dirichlet node structure
dirichlet_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous")
.add_dirichlet_parent(children_idx=0, dirichlet_node=dirichlet_node)
)
dirichlet_hgf.plot_network()
```

```{code-cell} ipython3
dirichlet_hgf.update_sequence = (
(1, dirichlet_node_update)
)
```

```{code-cell} ipython3
dirichlet_hgf.parameters_structure
```

```{code-cell} ipython3
dirichlet_hgf.node_structure
```

```{code-cell} ipython3
parameters_structure, node_structure = dirichlet_node_update(
parameters_structure=dirichlet_hgf.parameters_structure,
node_structure=dirichlet_hgf.node_structure,
value=1,
time_step=1.0,
node_idx=1,
)
```

```{code-cell} ipython3
jnp.where(
dirichlet_hgf.parameters_structure[1]["k"] == 0,
jnp.array(1.0),
jnp.array(2.0)
)
```

```{code-cell} ipython3
```

```{code-cell} ipython3
```

```{code-cell} ipython3
```

```{code-cell} ipython3
jnp.where(jnp.array(0) == 0, jnp.array(1.0), jnp.array(2.0))
```

```{code-cell} ipython3
dirichlet_hgf.node_structure[1]
```

```{code-cell} ipython3
for i in input_data:
parameters_structure, node_structure = dirichlet_node_update(
parameters_structure=dirichlet_hgf.parameters_structure,
node_structure=dirichlet_hgf.node_structure,
value=1,
time_step=1.0,
node_idx=1,
)
dirichlet_hgf.parameters_structure = parameters_structure
dirichlet_hgf.node_structure = node_structure
```

```{code-cell} ipython3
parameters_structure = dirichlet_hgf.parameters_structure
node_structure = dirichlet_hgf.node_structure
node_idx = 1
time_step = 1.0
```

```{code-cell} ipython3
value = 1.2
```

```{code-cell} ipython3
n_total +=1
value_vector.append(value)
```

```{code-cell} ipython3
parameters_structure
```

## Ploting results

```{code-cell} ipython3
_, axs = plt.subplots(figsize=(12, 6))
for cluster in cluster_idx:
mu = parameters_structure[cluster[1]]["mu"]
pi = parameters_structure[cluster[1]]["pi"]
axs.plot(np.linspace(-3, 8, 1000), norm.pdf(x=np.linspace(-3, 8, 1000), loc=mu, scale=1/np.sqrt(pi)))
axs.scatter(x=value_vector, y=np.random.normal(scale=.002, size=len(value_vector))-0.1)
```

```{code-cell} ipython3
```

0 comments on commit 6ff25f8

Please sign in to comment.