Skip to content

Commit

Permalink
rename omega, pi, mu, rho and kappas
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 2, 2023
1 parent c876298 commit 04d1e7a
Show file tree
Hide file tree
Showing 29 changed files with 988 additions and 977 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ u, _ = load_data("binary")
hgf = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
omega={"2": -3.0},
initial_mean={"1": .0, "2": .5},
initial_precision={"1": .0, "2": 1e4},
tonic_volatility={"2": -3.0},
)

# add new observations
Expand Down
6 changes: 3 additions & 3 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ u, _ = load_data("binary")
hgf = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
omega={"2": -3.0},
initial_mean={"1": .0, "2": .5},
initial_precision={"1": .0, "2": 1e4},
tonic_volatility={"2": -3.0},
)

# add new observations
Expand Down
36 changes: 18 additions & 18 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@
"two_levels_hgf = HGF(\n",
" n_levels=2,\n",
" model_type=\"binary\",\n",
" initial_mu={\"1\": .0, \"2\": .5},\n",
" initial_pi={\"1\": .0, \"2\": 1.0},\n",
" omega={\"2\": -3.0},\n",
" initial_mean={\"1\": .0, \"2\": .5},\n",
" initial_precision={\"1\": .0, \"2\": 1.0},\n",
" tonic_volatility={\"2\": -3.0},\n",
")"
]
},
Expand Down Expand Up @@ -379,11 +379,11 @@
"three_levels_hgf = HGF(\n",
" n_levels=3,\n",
" model_type=\"binary\",\n",
" initial_mu={\"1\": .0, \"2\": .5, \"3\": 0.},\n",
" initial_pi={\"1\": .0, \"2\": 1.0, \"3\": 1.0},\n",
" omega={\"1\": None, \"2\": -3.0, \"3\": -2.0},\n",
" rho={\"1\": None, \"2\": 0.0, \"3\": 0.0},\n",
" kappas={\"1\": None, \"2\": 1.0},\n",
" initial_mean={\"1\": .0, \"2\": .5, \"3\": 0.},\n",
" initial_precision={\"1\": .0, \"2\": 1.0, \"3\": 1.0},\n",
" tonic_volatility={\"1\": None, \"2\": -3.0, \"3\": -2.0},\n",
" tonic_drift={\"1\": None, \"2\": 0.0, \"3\": 0.0},\n",
" volatility_coupling={\"1\": None, \"2\": 1.0},\n",
" eta0=0.0,\n",
" eta1=1.0,\n",
" binary_precision=jnp.inf,\n",
Expand Down Expand Up @@ -822,11 +822,11 @@
"hgf_mcmc = HGF(\n",
" n_levels=2,\n",
" model_type=\"binary\",\n",
" initial_mu={\"1\": jnp.inf, \"2\": 0.5},\n",
" initial_pi={\"1\": 0.0, \"2\": 1.0},\n",
" omega={\"1\": jnp.inf, \"2\": omega_2},\n",
" rho={\"1\": 0.0, \"2\": 0.0},\n",
" kappas={\"1\": 1.0}).input_data(\n",
" initial_mean={\"1\": jnp.inf, \"2\": 0.5},\n",
" initial_precision={\"1\": 0.0, \"2\": 1.0},\n",
" tonic_volatility={\"1\": jnp.inf, \"2\": omega_2},\n",
" tonic_drift={\"1\": 0.0, \"2\": 0.0},\n",
" volatility_coupling={\"1\": 1.0}).input_data(\n",
" input_data=u\n",
" )"
]
Expand Down Expand Up @@ -1158,11 +1158,11 @@
"hgf_mcmc = HGF(\n",
" n_levels=3,\n",
" model_type=\"binary\",\n",
" initial_mu={\"1\": jnp.inf, \"2\": 0.5, \"3\": 0.0},\n",
" initial_pi={\"1\": 0.0, \"2\": 1e4, \"3\": 1e1},\n",
" omega={\"1\": jnp.inf, \"2\": omega_2, \"3\": omega_3},\n",
" rho={\"1\": 0.0, \"2\": 0.0, \"3\": 0.0},\n",
" kappas={\"1\": 1.0, \"2\": 1.0}).input_data(\n",
" initial_mean={\"1\": jnp.inf, \"2\": 0.5, \"3\": 0.0},\n",
" initial_precision={\"1\": 0.0, \"2\": 1e4, \"3\": 1e1},\n",
" tonic_volatility={\"1\": jnp.inf, \"2\": omega_2, \"3\": omega_3},\n",
" tonic_drift={\"1\": 0.0, \"2\": 0.0, \"3\": 0.0},\n",
" volatility_coupling={\"1\": 1.0, \"2\": 1.0}).input_data(\n",
" input_data=u\n",
" )"
]
Expand Down
36 changes: 18 additions & 18 deletions docs/source/notebooks/1.1-Binary_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ slideshow:
two_levels_hgf = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1.0},
omega={"2": -3.0},
initial_mean={"1": .0, "2": .5},
initial_precision={"1": .0, "2": 1.0},
tonic_volatility={"2": -3.0},
)
```

Expand Down Expand Up @@ -142,11 +142,11 @@ Here, we create a new {py:class}`pyhgf.model.HGF` instance, setting the number o
three_levels_hgf = HGF(
n_levels=3,
model_type="binary",
initial_mu={"1": .0, "2": .5, "3": 0.},
initial_pi={"1": .0, "2": 1.0, "3": 1.0},
omega={"1": None, "2": -3.0, "3": -2.0},
rho={"1": None, "2": 0.0, "3": 0.0},
kappas={"1": None, "2": 1.0},
initial_mean={"1": .0, "2": .5, "3": 0.},
initial_precision={"1": .0, "2": 1.0, "3": 1.0},
tonic_volatility={"1": None, "2": -3.0, "3": -2.0},
tonic_drift={"1": None, "2": 0.0, "3": 0.0},
volatility_coupling={"1": None, "2": 1.0},
eta0=0.0,
eta1=1.0,
binary_precision=jnp.inf,
Expand Down Expand Up @@ -244,11 +244,11 @@ omega_2 = az.summary(two_level_hgf_idata)["mean"]["omega_2"]
hgf_mcmc = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": jnp.inf, "2": 0.5},
initial_pi={"1": 0.0, "2": 1.0},
omega={"1": jnp.inf, "2": omega_2},
rho={"1": 0.0, "2": 0.0},
kappas={"1": 1.0}).input_data(
initial_mean={"1": jnp.inf, "2": 0.5},
initial_precision={"1": 0.0, "2": 1.0},
tonic_volatility={"1": jnp.inf, "2": omega_2},
tonic_drift={"1": 0.0, "2": 0.0},
volatility_coupling={"1": 1.0}).input_data(
input_data=u
)
```
Expand Down Expand Up @@ -330,11 +330,11 @@ omega_3 = az.summary(three_level_hgf_idata)["mean"]["omega_3"]
hgf_mcmc = HGF(
n_levels=3,
model_type="binary",
initial_mu={"1": jnp.inf, "2": 0.5, "3": 0.0},
initial_pi={"1": 0.0, "2": 1e4, "3": 1e1},
omega={"1": jnp.inf, "2": omega_2, "3": omega_3},
rho={"1": 0.0, "2": 0.0, "3": 0.0},
kappas={"1": 1.0, "2": 1.0}).input_data(
initial_mean={"1": jnp.inf, "2": 0.5, "3": 0.0},
initial_precision={"1": 0.0, "2": 1e4, "3": 1e1},
tonic_volatility={"1": jnp.inf, "2": omega_2, "3": omega_3},
tonic_drift={"1": 0.0, "2": 0.0, "3": 0.0},
volatility_coupling={"1": 1.0, "2": 1.0}).input_data(
input_data=u
)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/source/notebooks/1.2-Categorical_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ tags: [hide-cell]
---
fig, axs = plt.subplots(nrows=5, figsize=(12, 9), sharex=True)
plot_nodes(categorical_hgf, node_idxs=31, axs=axs[0])
axs[1].imshow(categorical_hgf.node_trajectories[0]["mu"].T, interpolation="none", aspect="auto")
axs[1].imshow(categorical_hgf.node_trajectories[0]["mean"].T, interpolation="none", aspect="auto")
axs[1].set_title("Mean of the implied Dirichlet distribution", loc="left")
axs[1].set_ylabel("Categories")
Expand Down Expand Up @@ -191,7 +191,7 @@ def categorical_surprise(omega_2, hgf, input_data):
for va_pa in hgf.edges[0].value_parents:
for va_pa_va_pa in hgf.edges[va_pa].value_parents:
for va_pa_va_pa_va_pa in hgf.edges[va_pa_va_pa].value_parents:
hgf.attributes[va_pa_va_pa_va_pa]["omega"] = omega_2
hgf.attributes[va_pa_va_pa_va_pa]["tonic_volatility"] = omega_2
# fit the model to new data
hgf.input_data(input_data=input_data.T)
Expand Down
30 changes: 15 additions & 15 deletions docs/source/notebooks/1.3-Continuous_HGF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@
"two_levels_continuous_hgf = HGF(\n",
" n_levels=2,\n",
" model_type=\"continuous\",\n",
" initial_mu={\"1\": timeserie[0], \"2\": 0.0},\n",
" initial_pi={\"1\": 1e4, \"2\": 1e1},\n",
" omega={\"1\": -13.0, \"2\": -2.0}\n",
" initial_mean={\"1\": timeserie[0], \"2\": 0.0},\n",
" initial_precision={\"1\": 1e4, \"2\": 1e1},\n",
" tonic_volatility={\"1\": -13.0, \"2\": -2.0}\n",
")"
]
},
Expand Down Expand Up @@ -365,9 +365,9 @@
"three_levels_continuous_hgf = HGF(\n",
" n_levels=3,\n",
" model_type=\"continuous\",\n",
" initial_mu={\"1\": 1.04, \"2\": 0.0, \"3\": 0.0},\n",
" initial_pi={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" omega={\"1\": -13.0, \"2\": -2.0, \"3\": -2.0}\n",
" initial_mean={\"1\": 1.04, \"2\": 0.0, \"3\": 0.0},\n",
" initial_precision={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" tonic_volatility={\"1\": -13.0, \"2\": -2.0, \"3\": -2.0}\n",
")"
]
},
Expand Down Expand Up @@ -574,9 +574,9 @@
"three_levels_continuous_hgf_bis = HGF(\n",
" n_levels=3,\n",
" model_type=\"continuous\",\n",
" initial_mu={\"1\": 1.04, \"2\": 0.0, \"3\": 0.0},\n",
" initial_pi={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" omega={\"1\": -13.0, \"2\": -1.0, \"3\": -2.0},\n",
" initial_mean={\"1\": 1.04, \"2\": 0.0, \"3\": 0.0},\n",
" initial_precision={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" tonic_volatility={\"1\": -13.0, \"2\": -1.0, \"3\": -2.0},\n",
").input_data(input_data=timeserie)"
]
},
Expand Down Expand Up @@ -938,9 +938,9 @@
"hgf_mcmc = HGF(\n",
" n_levels=2,\n",
" model_type=\"continuous\",\n",
" initial_mu={\"1\": timeserie[0], \"2\": 0.0},\n",
" initial_pi={\"1\": 1e4, \"2\": 1e1},\n",
" omega={\"1\": omega_1, \"2\": -2.0}).input_data(\n",
" initial_mean={\"1\": timeserie[0], \"2\": 0.0},\n",
" initial_precision={\"1\": 1e4, \"2\": 1e1},\n",
" tonic_volatility={\"1\": omega_1, \"2\": -2.0}).input_data(\n",
" input_data=timeserie\n",
" )"
]
Expand Down Expand Up @@ -1249,9 +1249,9 @@
"hgf_mcmc = HGF(\n",
" n_levels=3,\n",
" model_type=\"continuous\",\n",
" initial_mu={\"1\": timeserie[0], \"2\": 0.0, \"3\": 0.0},\n",
" initial_pi={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" omega={\"1\": omega_1, \"2\": -2.0, \"3\": -2.0}).input_data(\n",
" initial_mean={\"1\": timeserie[0], \"2\": 0.0, \"3\": 0.0},\n",
" initial_precision={\"1\": 1e4, \"2\": 1e1, \"3\": 1e1},\n",
" tonic_volatility={\"1\": omega_1, \"2\": -2.0, \"3\": -2.0}).input_data(\n",
" input_data=timeserie\n",
" )"
]
Expand Down
30 changes: 15 additions & 15 deletions docs/source/notebooks/1.3-Continuous_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ The default response function for a continuous HGF is the [sum of the Gaussian s
two_levels_continuous_hgf = HGF(
n_levels=2,
model_type="continuous",
initial_mu={"1": timeserie[0], "2": 0.0},
initial_pi={"1": 1e4, "2": 1e1},
omega={"1": -13.0, "2": -2.0}
initial_mean={"1": timeserie[0], "2": 0.0},
initial_precision={"1": 1e4, "2": 1e1},
tonic_volatility={"1": -13.0, "2": -2.0}
)
```

Expand Down Expand Up @@ -121,9 +121,9 @@ The three-level HGF can add a meta-volatility layer to the model. This can be us
three_levels_continuous_hgf = HGF(
n_levels=3,
model_type="continuous",
initial_mu={"1": 1.04, "2": 0.0, "3": 0.0},
initial_pi={"1": 1e4, "2": 1e1, "3": 1e1},
omega={"1": -13.0, "2": -2.0, "3": -2.0}
initial_mean={"1": 1.04, "2": 0.0, "3": 0.0},
initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}
)
```

Expand Down Expand Up @@ -160,9 +160,9 @@ The overall amount of surprise returned by the three-level HGF is quite similar
three_levels_continuous_hgf_bis = HGF(
n_levels=3,
model_type="continuous",
initial_mu={"1": 1.04, "2": 0.0, "3": 0.0},
initial_pi={"1": 1e4, "2": 1e1, "3": 1e1},
omega={"1": -13.0, "2": -1.0, "3": -2.0},
initial_mean={"1": 1.04, "2": 0.0, "3": 0.0},
initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
tonic_volatility={"1": -13.0, "2": -1.0, "3": -2.0},
).input_data(input_data=timeserie)
```

Expand Down Expand Up @@ -261,9 +261,9 @@ omega_1 = az.summary(two_level_hgf_idata)["mean"]["omega_1"]
hgf_mcmc = HGF(
n_levels=2,
model_type="continuous",
initial_mu={"1": timeserie[0], "2": 0.0},
initial_pi={"1": 1e4, "2": 1e1},
omega={"1": omega_1, "2": -2.0}).input_data(
initial_mean={"1": timeserie[0], "2": 0.0},
initial_precision={"1": 1e4, "2": 1e1},
tonic_volatility={"1": omega_1, "2": -2.0}).input_data(
input_data=timeserie
)
```
Expand Down Expand Up @@ -327,9 +327,9 @@ omega_1 = az.summary(three_level_hgf_idata)["mean"]["omega_1"]
hgf_mcmc = HGF(
n_levels=3,
model_type="continuous",
initial_mu={"1": timeserie[0], "2": 0.0, "3": 0.0},
initial_pi={"1": 1e4, "2": 1e1, "3": 1e1},
omega={"1": omega_1, "2": -2.0, "3": -2.0}).input_data(
initial_mean={"1": timeserie[0], "2": 0.0, "3": 0.0},
initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
tonic_volatility={"1": omega_1, "2": -2.0, "3": -2.0}).input_data(
input_data=timeserie
)
```
Expand Down
6 changes: 3 additions & 3 deletions index.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ u, _ = pyhgf.updates.prediction_error.continuous
hgf = HGF(
n_levels=2,
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
omega={"2": -3.0},
initial_mean={"1": .0, "2": .5},
initial_precision={"1": .0, "2": 1e4},
tonic_volatility={"2": -3.0},
)

# add new observations
Expand Down
Loading

0 comments on commit 04d1e7a

Please sign in to comment.