Commit
Add network plotting functionalities (#44)
LegrandNico authored Mar 16, 2023
1 parent 4086bdc commit 71b194a
Showing 15 changed files with 948 additions and 448 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -29,7 +29,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install isort==5.10.1 flake8 black==22.12.0 pydocstyle==6.3.0 mypy==1.1.1
pip install isort==5.10.1 flake8 black==23.1.0 pydocstyle==6.3.0 mypy==1.1.1
- name: Run linting
run: |
flake8 ./pyhgf/
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -28,6 +28,7 @@ jobs:
${{ runner.os }}-pip-
- name: Install dependencies
run: |
sudo apt-get install graphviz
pip install -r requirements-tests.txt
pip install ipykernel coverage pytest pytest-cov
python -m ipykernel install --user --name python3
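The workflow now installs the Graphviz system package that the new plotting function relies on; the Python bindings (the [graphviz](https://github.com/xflr6/graphviz) package linked from the notebooks below) are presumably pulled in through requirements-tests.txt, which is not shown in this excerpt. A quick local sanity check could look like the following hypothetical snippet, which is not part of the commit:

```python
# Hypothetical local check (not part of this commit) that both the Python
# graphviz bindings and the system `dot` executable are available.
import graphviz

print(graphviz.__version__)

# Piping a trivial graph raises an ExecutableNotFound error if the Graphviz
# binaries installed by the workflow above are missing.
graphviz.Digraph(comment="pyhgf-check").pipe(format="svg")
```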
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: isort
files: ^pyhgf/
- repo: https://github.com/ambv/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black
language_version: python3
1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -82,6 +82,7 @@ observing new data.

plot_trajectories
plot_correlations
plot_network

Response
--------
515 changes: 349 additions & 166 deletions docs/source/notebooks/1-Binary_HGF.ipynb

Large diffs are not rendered by default.

21 changes: 12 additions & 9 deletions docs/source/notebooks/1-Binary_HGF.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.1
jupytext_version: 1.14.5
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -68,18 +68,15 @@ two_levels_hgf = HGF(
model_type="binary",
initial_mu={"1": .0, "2": .5},
initial_pi={"1": .0, "2": 1e4},
omega={"1": None, "2": -3.0},
rho={"1": None, "2": 0.0},
kappas={"1": None},
eta0=0.0,
eta1=1.0,
pihat = jnp.inf,
omega={"2": -3.0},
)
```

This function creates an instance of an HGF model automatically parametrized for a two-level binary structure, so we do not have to worry about creating the node structure ourselves. This class also embeds functions to add new observations and plot results that we are going to use below.
This function creates an instance of an HGF model automatically parametrized for a two-level binary structure, so we do not have to worry about creating the node structure ourselves. This class also embeds functions to add new observations and plot results that we are going to use below. We can have a look at the node structure itself using the {ref}`pyhgf.plots.plot_network` function. This function will automatically draw the provided node structure using [Graphviz](https://github.com/xflr6/graphviz).

+++
```{code-cell} ipython3
two_levels_hgf.plot_network()
```

#### Add data

@@ -132,6 +129,12 @@ three_levels_hgf = HGF(
)
```

The node structure now includes a volatility parent at the third level.

```{code-cell} ipython3
three_levels_hgf.plot_network()
```

#### Add data

```{code-cell} ipython3
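Taken together, the notebook changes above amount to the following usage. This is a minimal sketch that repeats the two-level binary parametrization from the diff, assumes Graphviz and the Python graphviz bindings are installed, and assumes that `plot_network` returns a graphviz graph object that can be rendered to disk (the return type is not visible in this diff):

```python
# Minimal sketch of the new plotting entry points (assumes plot_network
# returns a graphviz graph object; the return type is not shown in this diff).
from pyhgf.model import HGF
from pyhgf.plots import plot_network

two_levels_hgf = HGF(
    n_levels=2,
    model_type="binary",
    initial_mu={"1": .0, "2": .5},
    initial_pi={"1": .0, "2": 1e4},
    omega={"2": -3.0},
)

# Method form used in the notebook cells above ...
graph = two_levels_hgf.plot_network()

# ... and the equivalent functional form, matching the plot_network(hgf=self)
# call added to pyhgf/model.py further down.
graph = plot_network(hgf=two_levels_hgf)

# Hypothetical: persist the figure, assuming a graphviz.Digraph/Source return value.
graph.render("binary_hgf_network", format="png", cleanup=True)
```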
541 changes: 364 additions & 177 deletions docs/source/notebooks/2-Continuous_HGF.ipynb

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions docs/source/notebooks/2-Continuous_HGF.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.1
jupytext_version: 1.14.5
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -67,8 +67,6 @@ The response function used is the [sum of the Gaussian surprise](pyhgf.response.
```

```{code-cell} ipython3
:tags: []
two_levels_continuous_hgf = HGF(
n_levels=2,
model_type="continuous",
@@ -79,9 +77,11 @@ two_levels_continuous_hgf = HGF(
kappas={"1": 1.0})
```

This function creates an instance of an HGF model automatically parametrized for a two-level continuous structure, so we do not have to worry about creating the node structure ourselves. This class also embeds functions to add new observations and plot results that we are going to use below.
This function creates an instance of an HGF model automatically parametrized for a two-level continuous structure, so we do not have to worry about creating the node structure ourselves. This class also embeds functions to add new observations and plot results that we are going to use below. We can have a look at the node structure itself using the {ref}`pyhgf.plots.plot_network` function. This function will automatically draw the provided node structure using [Graphviz](https://github.com/xflr6/graphviz).

+++
```{code-cell} ipython3
two_levels_continuous_hgf.plot_network()
```

#### Add data

@@ -138,6 +138,12 @@ three_levels_continuous_hgf = HGF(
kappas={"1": 1.0, "2": 1.0})
```

The node structure now includes a volatility parent at the third level.

```{code-cell} ipython3
three_levels_continuous_hgf.plot_network()
```

#### Add data

```{code-cell} ipython3
@@ -388,7 +394,7 @@ hgf_mcmc = HGF(
```

```{code-cell} ipython3
hgf_mcmc.plot_trajectories(ci=False)
hgf_mcmc.plot_trajectories(ci=True)
```

```{code-cell} ipython3
6 changes: 0 additions & 6 deletions pyhgf/continuous.py
@@ -73,12 +73,10 @@ def continuous_node_update(
# Update the continuous value parents #
#######################################
if value_parents_idx is not None:

vape = node_parameters["mu"] - node_parameters["muhat"]
psis = node_parameters["psis"]

for va_pa_idx, psi in zip(value_parents_idx, psis):

# unpack the current parent's parameters with value and volatility parents
va_pa_node_parameters = parameters_structure[va_pa_idx]
va_pa_value_parents_idx = node_structure[va_pa_idx].value_parents
@@ -125,7 +123,6 @@ def continuous_node_update(
# Update volatility parents #
#############################
if volatility_parents_idx is not None:

nu = node_parameters["nu"]
kappas = node_parameters["kappas"]
vope = (
@@ -134,7 +131,6 @@
) * node_parameters["pihat"] - 1

for vo_pa_idx, kappa in zip(volatility_parents_idx, kappas):

# unpack the current parent's parameters with value and volatility parents
vo_pa_node_parameters = parameters_structure[vo_pa_idx]
vo_pa_value_parents_idx = node_structure[vo_pa_idx].value_parents
@@ -248,7 +244,6 @@ def continuous_input_update(
# Update value parents #
########################
if value_parents_idx is not None:

# unpack the current parent's parameters with value and volatility parents
va_pa_node_parameters = parameters_structure[value_parents_idx[0]]
va_pa_value_parents_idx = node_structure[value_parents_idx[0]].value_parents
@@ -306,7 +301,6 @@ def continuous_input_update(
# Update volatility parents #
#############################
if volatility_parents_idx is not None:

# unpack the current parent's parameters with value and volatility parents
vo_pa_node_parameters = parameters_structure[volatility_parents_idx[0]]
vo_pa_value_parents_idx = node_structure[
1 change: 0 additions & 1 deletion pyhgf/distribution.py
@@ -179,7 +179,6 @@ def hgf_logp(

# Fitting n HGF models to the n datasets
for i in range(n):

# Format HGF parameters
initial_mu: Dict = {
"1": _mu_1[i],
54 changes: 45 additions & 9 deletions pyhgf/model.py
@@ -16,10 +16,10 @@
continuous_node_update,
gaussian_surprise,
)
from pyhgf.plots import plot_correlations, plot_trajectories
from pyhgf.plots import plot_correlations, plot_network, plot_trajectories
from pyhgf.response import total_binary_surprise, total_gaussian_surprise
from pyhgf.structure import loop_inputs
from pyhgf.typing import Indexes, NodeStructure
from pyhgf.typing import Indexes, InputIndexes, NodeStructure


class HGF(object):
@@ -30,6 +30,9 @@ class HGF(object):
Attributes
----------
input_nodes_idx :
Indexes of the input nodes. Defaults to `(0,)` if the network only has one input
node.
model_type :
The model implemented (can be `"continuous"`, `"binary"` or `"custom"`).
n_levels :
@@ -38,7 +41,8 @@
node_structure :
A tuple of :py:class:`pyhgf.typing.Indexes` representing the nodes hierarchy.
node_trajectories :
The node structure updated at each new observation.
The parameter structure that includes the consecutive updates at each new
observation.
parameters_structure :
The structure of nodes' parameters. Each parameter is a dictionary with the
following parameters: `"pihat", "pi", "muhat", "mu", "nu", "psis", "omega"` for
@@ -140,6 +144,7 @@ def __init__(
self.node_structure: NodeStructure
self.node_trajectories: Dict
self.parameters_structure: Dict
self.input_nodes_idx: Tuple[InputIndexes, ...]

if model_type in ["continuous", "binary"]:
if self.verbose:
@@ -166,8 +171,8 @@
value_coupling=1.0,
mu=initial_mu["1"],
pi=initial_pi["1"],
omega=omega["1"],
rho=rho["1"],
omega=omega["1"] if self.model_type != "binary" else np.nan,
rho=rho["1"] if self.model_type != "binary" else np.nan,
)

#########
@@ -258,6 +263,10 @@ def plot_correlations(self):
"""Plot the heatmap of cross-trajectories correlation."""
return plot_correlations(hgf=self)

def plot_network(self):
"""Visualization of node network using GraphViz."""
return plot_network(hgf=self)

def surprise(
self,
response_function: Optional[Callable] = None,
@@ -333,12 +342,31 @@ def to_pandas(self) -> pd.DataFrame:
def add_input_node(
self,
kind: str,
input_idx: int = 0,
omega_input: Union[float, np.ndarray, ArrayLike] = log(1e-4),
pihat: Union[float, np.ndarray, ArrayLike] = jnp.inf,
eta0: Union[float, np.ndarray, ArrayLike] = 0.0,
eta1: Union[float, np.ndarray, ArrayLike] = 1.0,
):
"""Create an input node."""
"""Create an input node.
Parameters
----------
kind :
The kind of input that should be created (can be `"continuous"` or
`"binary"`).
input_idx :
The index of the new input (defaults to `0`).
omega_input :
The input precision (only relevant if `kind="continuous"`).
pihat :
The input precision (only relevant if `kind="binary"`).
eta0 :
The lower bound of the binary process (only relevant if `kind="binary"`).
eta1 :
The higher bound of the binary process (only relevant if `kind="binary"`).
"""
if kind == "continuous":
input_node_parameters = {
"kappas": None,
Expand All @@ -356,8 +384,16 @@ def add_input_node(
"time_step": jnp.nan,
"value": jnp.nan,
}
self.parameters_structure = {0: input_node_parameters}
self.node_structure = (Indexes(None, None),)
if input_idx == 0:
# this is the first node, create the node structure
self.parameters_structure = {input_idx: input_node_parameters}
self.node_structure = (Indexes(None, None),)
self.input_nodes_idx = (InputIndexes(input_idx, kind),)
else:
# update the node structure
self.parameters_structure[input_idx] = input_node_parameters
self.node_structure += (Indexes(None, None),)
self.input_nodes_idx += (InputIndexes(input_idx, kind),)
return self

def add_value_parent(
@@ -458,7 +494,7 @@ def add_value_parent(
def add_volatility_parent(
self,
children_idxs: List,
volatility_coupling: Union[float, np.ndarray, ArrayLike],
volatility_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mu: Union[float, np.ndarray, ArrayLike] = 0.0,
mu_hat: Union[float, np.ndarray, ArrayLike] = jnp.nan,
pi: Union[float, np.ndarray, ArrayLike] = 1.0,
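The refactored `add_input_node` now supports more than one input node: `input_idx=0` creates `parameters_structure`, `node_structure` and `input_nodes_idx` from scratch, while any later index appends to them. The sketch below illustrates that bookkeeping only; it assumes the standard two-level binary network occupies node indexes 0-2 so that index 3 is free, and neither the index layout nor this usage appears anywhere in the commit itself.

```python
# Hypothetical sketch of the new append branch in add_input_node. Assumptions:
# the two-level binary network uses node indexes 0-2 (so index 3 is unused),
# and mixing a continuous input into a binary model is done here purely to
# illustrate the bookkeeping, not as a meaningful model.
from pyhgf.model import HGF

hgf = HGF(
    n_levels=2,
    model_type="binary",
    initial_mu={"1": .0, "2": .5},
    initial_pi={"1": .0, "2": 1e4},
    omega={"2": -3.0},
)

# A non-zero index takes the `else` branch above and extends the existing
# parameters_structure, node_structure and input_nodes_idx rather than
# resetting them.
hgf.add_input_node(kind="continuous", input_idx=3)

print(hgf.input_nodes_idx)  # should now hold two InputIndexes entries
hgf.plot_network()          # the extra input should show up as an unconnected node
```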