Skip to content

Commit

Permalink
- add binary softmax (#88)
Browse files Browse the repository at this point in the history
add binary softmax response function
- add binary response dataset
- use pkgutil for data import
  • Loading branch information
LegrandNico authored Sep 1, 2023
1 parent c8c7dbd commit c4d8288
Show file tree
Hide file tree
Showing 19 changed files with 1,574 additions and 1,139 deletions.
5 changes: 3 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include pyhgf/data/usdchf.dat
include pyhgf/data/binary_input.dat
include pyhgf/data/usdchf.txt
include pyhgf/data/binary_input.txt
include pyhgf/data/binary_response.txt
218 changes: 104 additions & 114 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

50 changes: 36 additions & 14 deletions docs/source/notebooks/1.1-Binary_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.7
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -46,10 +46,15 @@ In this example, we will use data from a decision-making task where the outcome
+++

## Imports
We import a time series of binary responses from the decision task described in {cite:p}`2013:iglesias`.
We import a time series of binary observations from the decision task described in {cite:p}`Iglesias2021`.

```{code-cell} ipython3
timeserie = load_data("binary")
---
editable: true
slideshow:
slide_type: ''
---
u, _ = load_data("binary")
```

## Fitting the binary HGF with fixed parameters
Expand All @@ -63,6 +68,11 @@ The response function used is the binary surprise at each time point ({py:func}`
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
two_levels_hgf = HGF(
n_levels=2,
model_type="binary",
Expand All @@ -75,24 +85,36 @@ two_levels_hgf = HGF(
This function creates an instance of a HGF model automatically parametrized for a two-level binary structure, so we do not have to worry about creating the node structure ourselves. This class also embed function to add new observations and plot results that we are going to use below. We can have a look at the node structure itself using the {ref}`pyhgf.plots.plot_network` function. This function will automatically dray the provided node structure using [Graphviz](https://github.com/xflr6/graphviz).

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
two_levels_hgf.plot_network()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

#### Add data

```{code-cell} ipython3
# Provide new observations
two_levels_hgf = two_levels_hgf.input_data(input_data=timeserie)
two_levels_hgf = two_levels_hgf.input_data(input_data=u)
```

#### Plot trajectories

+++
+++ {"editable": true, "slideshow": {"slide_type": ""}}

A Hierarchical Gaussian Filter is acting as a Bayesian filter when presented with new observation, and by running the update equation forward, we can observe the trajectories of the parameters of the node that are being updated after each new observation (i.e. the mean $\mu$ and the precision $\pi$). The `plot_trajectories` function automatically extracts the relevant parameters given the model structure and will plot their evolution together with the input data.

```{code-cell} ipython3
two_levels_hgf.plot_trajectories()
---
editable: true
slideshow:
slide_type: ''
---
two_levels_hgf.plot_trajectories();
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}
Expand Down Expand Up @@ -140,13 +162,13 @@ three_levels_hgf.plot_network()
#### Add data

```{code-cell} ipython3
three_levels_hgf = three_levels_hgf.input_data(input_data=timeserie)
three_levels_hgf = three_levels_hgf.input_data(input_data=u)
```

#### Plot trajectories

```{code-cell} ipython3
three_levels_hgf.plot_trajectories()
three_levels_hgf.plot_trajectories();
```

## Learning parameters with MCMC sampling
Expand All @@ -168,7 +190,7 @@ from pyhgf.response import first_level_binary_surprise
hgf_logp_op = HGFDistribution(
n_levels=2,
model_type="binary",
input_data=[timeserie],
input_data=[u],
response_function=first_level_binary_surprise,
)
```
Expand Down Expand Up @@ -243,12 +265,12 @@ hgf_mcmc = HGF(
omega={"1": jnp.inf, "2": omega_2},
rho={"1": 0.0, "2": 0.0},
kappas={"1": 1.0}).input_data(
input_data=timeserie
input_data=u
)
```

```{code-cell} ipython3
hgf_mcmc.plot_trajectories()
hgf_mcmc.plot_trajectories();
```

```{code-cell} ipython3
Expand All @@ -262,7 +284,7 @@ hgf_mcmc.surprise()
hgf_logp_op = HGFDistribution(
n_levels=3,
model_type="binary",
input_data=[timeserie],
input_data=[u],
response_function=first_level_binary_surprise,
)
```
Expand Down Expand Up @@ -334,12 +356,12 @@ hgf_mcmc = HGF(
omega={"1": jnp.inf, "2": omega_2, "3": omega_3},
rho={"1": 0.0, "2": 0.0, "3": 0.0},
kappas={"1": 1.0, "2": 1.0}).input_data(
input_data=timeserie
input_data=u
)
```

```{code-cell} ipython3
hgf_mcmc.plot_trajectories()
hgf_mcmc.plot_trajectories();
```

```{code-cell} ipython3
Expand Down
44 changes: 26 additions & 18 deletions docs/source/notebooks/2-Using_custom_response_functions.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions docs/source/notebooks/2-Using_custom_response_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.7
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -67,7 +67,7 @@ editable: true
slideshow:
slide_type: ''
---
observations = load_data("binary")
u, _ = load_data("binary")
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}
Expand All @@ -89,7 +89,7 @@ agent = HGF(
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
omega={"2": -4.0},
).input_data(input_data=observations)
).input_data(input_data=u)
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}
Expand Down Expand Up @@ -137,7 +137,7 @@ tags: [hide-input]
---
plt.figure(figsize=(12, 3))
jitter = responses * .1 + (1-responses) * -.1
plt.scatter(np.arange(len(observations)), observations, label="Observations", color="#4c72b0", edgecolor="k", alpha=.2)
plt.scatter(np.arange(len(u)), u, label="Observations", color="#4c72b0", edgecolor="k", alpha=.2)
plt.scatter(np.arange(len(responses)), responses + jitter, label="Responses", color="#c44e52", alpha=.2, edgecolor="k")
plt.plot(agent.node_trajectories[1]["muhat"], label="Beliefs", linestyle="--")
plt.legend()
Expand Down Expand Up @@ -210,6 +210,8 @@ def response_function(hgf, response_function_parameters):
return jnp.sum(jnp.where(responses, -jnp.log(beliefs), -jnp.log(1.0 - beliefs)))
```

This function takes the expected probability from the binary node and uses it to predict the participant's decision. The surprise is computed using the binary surprise (see {py:func}`pyhgf.update.binary.binary_surprise`). This corresponds to the standard binary softmax response function that is also accessible from the {py:func}`pyhgf.response.binary_softmax` function.

+++ {"editable": true, "slideshow": {"slide_type": ""}}

```{note}
Expand Down Expand Up @@ -252,7 +254,7 @@ import arviz as az
hgf_logp_op = HGFDistribution(
n_levels=2,
model_type="binary",
input_data=[observations],
input_data=[u],
response_function=response_function,
response_function_parameters=[(responses, )]
)
Expand Down
26 changes: 11 additions & 15 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,17 @@ @article{2015:kanai
year = {2015}
}

@article{2013:iglesias,
title = {Hierarchical Prediction Errors in Midbrain and Basal Forebrain during Sensory Learning},
journal = {Neuron},
volume = {80},
number = {2},
pages = {519-530},
year = {2013},
issn = {0896-6273},
doi = {https://doi.org/10.1016/j.neuron.2013.09.009},
url = {https://www.sciencedirect.com/science/article/pii/S0896627313008076},
author = {Sandra Iglesias and Christoph Mathys and Kay H. Brodersen and Lars Kasper and Marco Piccirelli and Hanneke E.M. den Ouden and Klaas E. Stephan},
abstract = {Summary
In Bayesian brain theories, hierarchically related prediction errors (PEs) play a central role for predicting sensory inputs and inferring their underlying causes, e.g., the probabilistic structure of the environment and its volatility. Notably, PEs at different hierarchical levels may be encoded by different neuromodulatory transmitters. Here, we tested this possibility in computational fMRI studies of audio-visual learning. Using a hierarchical Bayesian model, we found that low-level PEs about visual stimulus outcome were reflected by widespread activity in visual and supramodal areas but also in the midbrain. In contrast, high-level PEs about stimulus probabilities were encoded by the basal forebrain. These findings were replicated in two groups of healthy volunteers. While our fMRI measures do not reveal the exact neuron types activated in midbrain and basal forebrain, they suggest a dichotomy between neuromodulatory systems, linking dopamine to low-level PEs about stimulus outcome and acetylcholine to more abstract PEs about stimulus probabilities.
Video Abstract
}
@article{Iglesias2021,
doi = {10.1016/j.neuroimage.2020.117590},
url = {https://doi.org/10.1016/j.neuroimage.2020.117590},
year = {2021},
month = feb,
publisher = {Elsevier {BV}},
volume = {226},
pages = {117590},
author = {Sandra Iglesias and Lars Kasper and Samuel J. Harrison and Robert Manka and Christoph Mathys and Klaas E. Stephan},
title = {Cholinergic and dopaminergic effects on prediction error and uncertainty responses during sensory associative learning},
journal = {{NeuroImage}}
}

@book{2014:lee,
Expand Down
46 changes: 34 additions & 12 deletions pyhgf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

import pkg_resources # type: ignore
from numpy import loadtxt
import pkgutil
from io import BytesIO
from typing import Tuple, Union

import numpy as np
import pandas as pd

__version__ = "0.0.6"


def load_data(dataset: str):
def load_data(dataset: str) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
"""Load dataset for continuous or binary HGF.
Parameters
Expand All @@ -22,22 +26,40 @@ def load_data(dataset: str):
Notes
-----
The continuous time series is the standard USD-CHF conversion rates over time used
in the Matlab examples. The binary dataset is from Iglesias et al. (2013) [#].
in the Matlab examples.
The binary dataset is from Iglesias et al. (2013) [#] (see the full dataset
`here <https://www.research-collection.ethz.ch/handle/20.500.11850/454711)>`_. The
binary set consist of one vector *u*, the observations, and one vector *y*, the
decisions.
References
----------
.. [#] Iglesias, S., Mathys, C., Brodersen, K. H., Kasper, L., Piccirelli, M., den
Ouden, H. E. M., & Stephan, K. E. (2013). Hierarchical Prediction Errors in Midbrain
and Basal Forebrain during Sensory Learning. In Neuron (Vol. 80, Issue 2, pp.
519–530). Elsevier BV. https://doi.org/10.1016/j.neuron.2013.09.009
.. [#] Iglesias, S., Kasper, L., Harrison, S. J., Manka, R., Mathys, C., & Stephan,
K. E. (2021). Cholinergic and dopaminergic effects on prediction error and
uncertainty responses during sensory associative learning. In NeuroImage (Vol.
226, p. 117590). Elsevier BV. https://doi.org/10.1016/j.neuroimage.2020.117590
"""
if dataset == "continuous":
data = loadtxt(pkg_resources.resource_filename("pyhgf", "/data/usdchf.dat"))
data = pd.read_csv(
BytesIO(pkgutil.get_data(__name__, "data/usdchf.txt")), # type: ignore
names=["x"],
).x.to_numpy()
elif dataset == "binary":
data = loadtxt(
pkg_resources.resource_filename("pyhgf", "/data/binary_input.dat")
)
u = pd.read_csv(
BytesIO(
pkgutil.get_data(__name__, "data/binary_input.txt") # type: ignore
),
names=["x"],
).x.to_numpy()
y = pd.read_csv(
BytesIO(
pkgutil.get_data(__name__, "data/binary_response.txt") # type: ignore
),
names=["x"],
).x.to_numpy()
data = (u, y)
else:
raise ValueError("Invalid dataset argument. Should be 'continous' or 'binary'.")

Expand Down
Loading

0 comments on commit c4d8288

Please sign in to comment.