Skip to content

Commit

Permalink
remove control over verbosity (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico authored Aug 23, 2024
1 parent eb7485a commit 1345689
Show file tree
Hide file tree
Showing 12 changed files with 244 additions and 320 deletions.
28 changes: 14 additions & 14 deletions docs/source/notebooks/0.1-Theory.ipynb

Large diffs are not rendered by default.

40 changes: 16 additions & 24 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb

Large diffs are not rendered by default.

258 changes: 128 additions & 130 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion docs/source/notebooks/3-Multilevel_HGF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@
"# create just one default network - we will simply change the values of interest before fitting to save time\n",
"agent = HGF(\n",
" n_levels=2,\n",
" verbose=False,\n",
" model_type=\"binary\",\n",
" initial_mean={\"1\": 0.5, \"2\": 0.0},\n",
")"
Expand Down
1 change: 0 additions & 1 deletion docs/source/notebooks/4-Parameter_recovery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@
"# create just one default network - we will simply change the values of interest before fitting to save time\n",
"agent = HGF(\n",
" n_levels=2,\n",
" verbose=False,\n",
" model_type=\"binary\",\n",
" initial_mean={\"1\": 0.5, \"2\": 0.0},\n",
")"
Expand Down
157 changes: 62 additions & 95 deletions docs/source/notebooks/Example_3_Multi_armed_bandit.ipynb

Large diffs are not rendered by default.

17 changes: 1 addition & 16 deletions src/pyhgf/model/hgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class HGF(Network):
strength with children and parents (i.e. `"value_coupling_parents"`,
`"value_coupling_children"`, `"volatility_coupling_parents"`,
`"volatility_coupling_children"`).
verbose : bool
Verbosity level.
"""

Expand Down Expand Up @@ -62,7 +60,6 @@ def __init__(
"2": 0.0,
"3": 0.0,
},
verbose: bool = True,
) -> None:
r"""Parameterization of the HGF model.
Expand Down Expand Up @@ -107,27 +104,15 @@ def __init__(
A dictionary containing the initial values for the tonic drift
at different levels of the hierarchy. This represents the drift of the
random walk. Defaults set all entries to `0.0` (no drift).
verbose :
The verbosity of the methods for model creation and fitting. Defaults to
`True`.
"""
self.verbose = verbose
Network.__init__(self)
self.model_type = model_type
self.n_levels = n_levels

if model_type not in ["continuous", "binary"]:
if self.verbose:
print("Initializing a network with custom node structure.")
raise ValueError("Invalid model type.")
else:
if self.verbose:
print(
(
f"Creating a {self.model_type} Hierarchical Gaussian Filter "
f"with {self.n_levels} levels."
)
)
if model_type == "continuous":
# Input
self.add_nodes(
Expand Down
16 changes: 0 additions & 16 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def __init__(self) -> None:
self.update_sequence: Optional[UpdateSequence] = None
self.scan_fn: Optional[Callable] = None
self.inputs: Inputs
self.verbose: bool = False

def create_belief_propagation_fn(self, overwrite: bool = True) -> "Network":
"""Create the belief propagation function.
Expand All @@ -88,8 +87,6 @@ def create_belief_propagation_fn(self, overwrite: bool = True) -> "Network":
# create the update sequence if it does not already exist
if self.update_sequence is None:
self.set_update_sequence()
if self.verbose:
print("... Create the update sequence from the network structure.")

# create the belief propagation function
# this function is used by scan to loop over observations
Expand All @@ -99,20 +96,13 @@ def create_belief_propagation_fn(self, overwrite: bool = True) -> "Network":
update_sequence=self.update_sequence,
structure=self.structure,
)
if self.verbose:
print("... Create the belief propagation function.")
else:
if overwrite:
self.scan_fn = Partial(
beliefs_propagation,
update_sequence=self.update_sequence,
structure=self.structure,
)
if self.verbose:
print("... Create the belief propagation function (overwrite).")
else:
if self.verbose:
print("... The belief propagation function is already defined.")

return self

Expand All @@ -138,8 +128,6 @@ def cache_belief_propagation_fn(self) -> "Network":
jnp.ones((1, 1)),
),
)
if self.verbose:
print("... Cache the belief propagation function.")

return self

Expand Down Expand Up @@ -175,8 +163,6 @@ def input_data(
"""
if self.scan_fn is None:
self = self.create_belief_propagation_fn()
if self.verbose:
print((f"Adding {len(input_data)} new observations."))
if time_steps is None:
time_steps = np.ones((len(input_data), 1)) # time steps vector
else:
Expand Down Expand Up @@ -247,8 +233,6 @@ def input_custom_sequence(
missing in the event log, or rejected trials).
"""
if self.verbose:
print((f"Adding {len(input_data)} new observations."))
if time_steps is None:
time_steps = np.ones(len(input_data)) # time steps vector

Expand Down
4 changes: 2 additions & 2 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from pyhgf.model import HGF
from pyhgf.model import Network


def test_categorical_state_node():
Expand All @@ -13,7 +13,7 @@ def test_categorical_state_node():
input_data = np.vstack([[0.0] * input_data.shape[1], input_data])

# create the categorical HGF
categorical_hgf = HGF(model_type=None, verbose=False).add_nodes(
categorical_hgf = Network().add_nodes(
kind="categorical-input",
node_parameters={
"n_categories": 3,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from pyhgf import load_data
from pyhgf.model import HGF
from pyhgf.model import HGF, Network
from pyhgf.response import total_gaussian_surprise


Expand All @@ -16,7 +16,7 @@ def test_HGF():
#####################

custom_hgf = (
HGF(model_type=None)
Network()
.add_nodes(kind="continuous-input")
.add_nodes(kind="binary-input")
.add_nodes(value_children=0)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from pyhgf import load_data
from pyhgf.model import HGF
from pyhgf.model import HGF, Network


def test_plotting_functions():
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_plotting_functions():
input_data = np.vstack([[0.0] * input_data.shape[1], input_data])

# create the categorical HGF
categorical_hgf = HGF(model_type=None, verbose=False).add_nodes(
categorical_hgf = Network().add_nodes(
kind="categorical-input",
node_parameters={
"n_categories": 3,
Expand Down

0 comments on commit 1345689

Please sign in to comment.