Skip to content

Commit

Permalink
CPC tutorial and documentation (#89)
Browse files Browse the repository at this point in the history
* cpc tutorial

* default parmeter values and merge mu and muhat and pi and pihat in the node creation

* smal fixes in the plotting functions
  • Loading branch information
LegrandNico authored Sep 6, 2023
1 parent c4d8288 commit c17f047
Show file tree
Hide file tree
Showing 8 changed files with 2,504 additions and 735 deletions.
128 changes: 73 additions & 55 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

27 changes: 21 additions & 6 deletions docs/source/notebooks/1.1-Binary_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ two_levels_hgf = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
initial_pi={"1": .0, "2": 1.0},
omega={"2": -3.0},
)
```
Expand Down Expand Up @@ -143,7 +143,7 @@ three_levels_hgf = HGF(
n_levels=3,
model_type="binary",
initial_mu={"1": .0, "2": .5, "3": 0.},
initial_pi={"1": .0, "2": 1e4, "3": 1e1},
initial_pi={"1": .0, "2": 1.0, "3": 1.0},
omega={"1": None, "2": -3.0, "3": -2.0},
rho={"1": None, "2": 0.0, "3": 0.0},
kappas={"1": None, "2": 1.0},
Expand Down Expand Up @@ -218,7 +218,7 @@ with pm.Model() as two_levels_binary_hgf:
rho_1=0.0,
rho_2=0.0,
pi_1=0.0,
pi_2=1e4,
pi_2=1.0,
mu_1=jnp.inf,
mu_2=0.5,
kappa_1=1.0,
Expand Down Expand Up @@ -261,7 +261,7 @@ hgf_mcmc = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": jnp.inf, "2": 0.5},
initial_pi={"1": 0.0, "2": 1e4},
initial_pi={"1": 0.0, "2": 1.0},
omega={"1": jnp.inf, "2": omega_2},
rho={"1": 0.0, "2": 0.0},
kappas={"1": 1.0}).input_data(
Expand All @@ -281,6 +281,11 @@ hgf_mcmc.surprise()
#### Creating the model

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
hgf_logp_op = HGFDistribution(
n_levels=3,
model_type="binary",
Expand All @@ -294,6 +299,11 @@ The data is passed to the distribution when the instance is created.
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
with pm.Model() as three_levels_binary_hgf:
omega_2 = pm.Uniform("omega_2", -4.0, 0.0)
Expand All @@ -310,8 +320,8 @@ with pm.Model() as three_levels_binary_hgf:
rho_2=0.0,
rho_3=0.0,
pi_1=0.0,
pi_2=1e4,
pi_3=1e1,
pi_2=1.0,
pi_3=1.0,
mu_1=jnp.inf,
mu_2=0.5,
mu_3=0.0,
Expand All @@ -330,6 +340,11 @@ pm.model_to_graphviz(three_levels_binary_hgf)
#### Sampling

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
with three_levels_binary_hgf:
three_level_hgf_idata = pm.sample(chains=2)
```
Expand Down
15 changes: 8 additions & 7 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,13 @@
"cell_type": "code",
"execution_count": 8,
"id": "6e27b3d9-1b10-4cf9-9074-725a9638345d",
"metadata": {},
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"categorical_hgf.input_data(input_data=input_data.T);"
Expand All @@ -683,9 +689,7 @@
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-cell"
]
"tags": []
},
"outputs": [
{
Expand Down Expand Up @@ -983,9 +987,6 @@
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
Expand Down
2,271 changes: 1,717 additions & 554 deletions docs/source/notebooks/Exercise_1_Using_the_HGF.ipynb

Large diffs are not rendered by default.

729 changes: 648 additions & 81 deletions docs/source/notebooks/Exercise_1_Using_the_HGF.md

Large diffs are not rendered by default.

37 changes: 13 additions & 24 deletions pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
"2": -3.0,
"3": -3.0,
},
kappas: Dict = {"1": 1.0, "2": 0.0},
kappas: Dict = {"1": 1.0, "2": 1.0},
eta0: Union[float, np.ndarray, ArrayLike] = 0.0,
eta1: Union[float, np.ndarray, ArrayLike] = 1.0,
binary_precision: Union[float, np.ndarray, ArrayLike] = jnp.inf,
Expand Down Expand Up @@ -575,7 +575,6 @@ def add_input_node(
"pi_3": 1.0,
"mu_1": 0.0,
"mu_2": -jnp.log(categorical_parameters["n_categories"] - 1),
"mu_hat_2": -jnp.log(categorical_parameters["n_categories"] - 1),
"mu_3": 0.0,
"omega_2": -4.0,
"omega_3": -4.0,
Expand All @@ -595,9 +594,7 @@ def add_value_parent(
children_idxs: Union[List, int],
value_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mu: Union[float, np.ndarray, ArrayLike] = 0.0,
mu_hat: Union[float, np.ndarray, ArrayLike] = 0.0,
pi: Union[float, np.ndarray, ArrayLike] = 1.0,
pi_hat: Union[float, np.ndarray, ArrayLike] = 1.0,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
Expand All @@ -612,14 +609,11 @@ def add_value_parent(
The value_coupling between the child and parent node. This is will be
appended to the `psis` parameters in the parent and child node(s).
mu :
The mean of the Gaussian distribution.
mu_hat :
The expected mean of the Gaussian distribution for the next observation.
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
pi :
The precision of the Gaussian distribution (inverse variance).
pi_hat :
The expected precision of the Gaussian distribution (inverse variance) for
the next observation.
The precision of the Gaussian distribution (inverse variance). This
value is passed both to the current and expected states.
omega :
The tonic part of the variance (the variance that is not inherited from
volatility parent(s)).
Expand All @@ -639,9 +633,9 @@ def add_value_parent(
# parent's parameter
node_parameters = {
"mu": mu,
"muhat": mu_hat,
"muhat": mu,
"pi": pi,
"pihat": pi_hat,
"pihat": pi,
"kappas_children": None,
"kappas_parents": None,
"psis_children": tuple(value_coupling for _ in range(len(children_idxs))),
Expand Down Expand Up @@ -699,9 +693,7 @@ def add_volatility_parent(
children_idxs: Union[List, int],
volatility_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mu: Union[float, np.ndarray, ArrayLike] = 0.0,
mu_hat: Union[float, np.ndarray, ArrayLike] = 0.0,
pi: Union[float, np.ndarray, ArrayLike] = 1.0,
pi_hat: Union[float, np.ndarray, ArrayLike] = 1.0,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
Expand All @@ -716,14 +708,11 @@ def add_volatility_parent(
The volatility coupling between the child and parent node. This is will be
appended to the `kappas` parameters in the parent and child node(s).
mu :
The mean of the Gaussian distribution.
mu_hat :
The expected mean of the Gaussian distribution for the next observation.
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
pi :
The precision of the Gaussian distribution (inverse variance).
pi_hat :
The expected precision of the Gaussian distribution (inverse variance) for
the next observation.
The precision of the Gaussian distribution (inverse variance). This
value is passed both to the current and expected states.
omega :
The tonic part of the variance (the variance that is not inherited from
volatility parent(s)).
Expand All @@ -743,9 +732,9 @@ def add_volatility_parent(
# parent's parameter
node_parameters = {
"mu": mu,
"muhat": mu_hat,
"muhat": mu,
"pi": pi,
"pihat": pi_hat,
"pihat": pi,
"kappas_children": tuple(
volatility_coupling for _ in range(len(children_idxs))
),
Expand Down
2 changes: 0 additions & 2 deletions pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def fill_categorical_state_node(
value_coupling=1.0,
pi=implied_binary_parameters["pi_1"],
mu=implied_binary_parameters["mu_1"],
mu_hat=implied_binary_parameters["mu_1"],
)

# add the continuous parent node
Expand All @@ -246,7 +245,6 @@ def fill_categorical_state_node(
+ implied_binary_parameters["n_categories"],
value_coupling=1.0,
mu=implied_binary_parameters["mu_2"],
mu_hat=implied_binary_parameters["mu_hat_2"],
pi=implied_binary_parameters["pi_2"],
omega=implied_binary_parameters["omega_2"],
)
Expand Down
30 changes: 24 additions & 6 deletions pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,8 @@ def plot_nodes(

# compute mu +/- sd at time t-1
# and use the sigmoid transform before plotting
mu_parent = np.insert(
trajectories_df[f"x_{parent_idx}_mu"][:-1], 0, np.nan
)
pi_parent = np.insert(
trajectories_df[f"x_{parent_idx}_pi"][:-1], 0, np.nan
)
mu_parent = trajectories_df[f"x_{parent_idx}_muhat"]
pi_parent = trajectories_df[f"x_{parent_idx}_pihat"]
sd = np.sqrt(1 / pi_parent)
y1 = 1 / (1 + np.exp(-mu_parent + sd))
y2 = 1 / (1 + np.exp(-mu_parent - sd))
Expand Down Expand Up @@ -577,6 +573,28 @@ def plot_nodes(
alpha=0.5,
color=input_colors[ii],
)
else:
child_idx = np.where(
np.array(hgf.input_nodes_idx.idx) == child_idx
)[0][0]
axs[i].scatter(
trajectories_df.time,
trajectories_df[f"observation_input_{child_idx}"],
s=3,
label=f"Value child node - {ii}",
alpha=0.3,
color=input_colors[ii],
edgecolors="grey",
)
axs[i].plot(
trajectories_df.time,
trajectories_df[f"observation_input_{child_idx}"],
linewidth=0.5,
linestyle="--",
alpha=0.3,
color=input_colors[ii],
)
axs[i].legend()

# plotting surprise
# -----------------
Expand Down

0 comments on commit c17f047

Please sign in to comment.