Skip to content

Commit

Permalink
remonve input nodes - binary hgf
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 19, 2024
1 parent ff704e7 commit 0da432a
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 440 deletions.
336 changes: 83 additions & 253 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

125 changes: 80 additions & 45 deletions src/pyhgf/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,54 +47,66 @@ def logp(
Parameters
----------
mean_1 :
The mean at the first level of the HGF.
The mean at the first level of the HGF. For the continuous HGF, this is the mean
of the first value parent (x_1). For the binary HGF this is the mean of the
binary state node (x_0).
mean_2 :
The mean at the second level of the HGF.
The mean at the second level of the HGF. For the continuous HGF, this is the
mean of the first volatility parent (x_2). For the binary HGF this is the mean
of the first continuous state node (x_1).
mean_3 :
The mean at the third level of the HGF. The value of this parameter will be
ignored when using a two-level HGF (`n_levels=2`).
ignored when using a two-level HGF (`n_levels=2`). For the continuous HGF, this
is the mean of the second volatility parent (x_3). For the binary HGF this is
the mean of the first volatility parent (x_2).
precision_1 :
The precision at the first level of the HGF.
The precision at the first level of the HGF. For the continuous HGF, this is the
precision of the first value parent (x_1). For the binary HGF this is the
precision of the binary state node (x_0).
precision_2 :
The precision at the second level of the HGF.
The precision at the second level of the HGF. For the continuous HGF, this is
the precision of the first volatility parent (x_2). For the binary HGF this is
the precision of the first continuous state node (x_1).
precision_3 :
The precision at the third level of the HGF. The value of this parameter will
be ignored when using a two-level HGF (`n_levels=2`).
be ignored when using a two-level HGF (`n_levels=2`). For the continuous HGF,
this is the precision of the second volatility parent (x_3). For the binary HGF
this is the precision of the first volatility parent (x_2).
tonic_volatility_1 :
The tonic volatility at the first level of the HGF. This parameter represents
the tonic part of the variance (the part that is not inherited from parent
nodes).
The tonic volatility at the first level (x_1 for the continuous HGF, x_2 for the
binary HGF). This parameter represents the tonic part of the variance (the part
that is not inherited from parent nodes).
tonic_volatility_2 :
The tonic volatility at the second level of the HGF. This parameter represents
the tonic part of the variance (the part that is not inherited from parent
nodes).
The tonic volatility at the second level (x_2 for the continuous HGF, x_3 for
the binary HGF). This parameter represents the tonic part of the variance (the
part that is not inherited from parent nodes).
tonic_volatility_3 :
The tonic volatility at the third level of the HGF. This parameter represents
the tonic part of the variance (the part that is not inherited from parent
nodes). The value of this parameter will be ignored when using a two-level HGF
(`n_levels=2`).
nodes). This parameter is only used for a three-level continuous HGF.
tonic_drift_1 :
The tonic drift at the first level of the HGF. This parameter represents the
drift of the random walk.
The tonic drift at the first level of the HGF (x_1 for the continuous HGF,
x_2 for the binary HGF). This parameter represents the drift of the random walk.
tonic_drift_2 :
The tonic drift at the second level of the HGF. This parameter represents the
drift of the random walk.
The tonic drift at the second level of the HGF (x_2 for the continuous HGF,
x_3 for the binary HGF). This parameter represents the drift of the random walk.
tonic_drift_3 :
The tonic drift at the third level of the HGF. This parameter represents the
drift of the random walk. The value of this parameter will be ignored when
using a two-level HGF (`n_levels=2`).
drift of the random walk. This parameter is only used for a three-level
continuous HGF.
volatility_coupling_1 :
The volatility coupling between the first and second levels of the HGF. This
represents the phasic part of the variance (the part affected by the
The volatility coupling between the first and second levels of the HGF (between
x_1 and x_2 for a continuous HGF, and between x_2 and x_3 for a binary HGF).
This represents the phasic part of the variance (the part affected by the
parent nodes). Defaults to `1.0` (full connectivity).
volatility_coupling_2 :
The volatility coupling between the second and third levels of the HGF. This
represents the phasic part of the variance (the part affected by the
parent nodes). Defaults to `1.0` (full connectivity). The value of this
parameter will be ignored when using a two-level HGF (`n_levels=2`).
The volatility coupling between the second and third levels of the HGF (x_2 and
x_2 for a continuous HGF, not applicable to a binary HGF). This represents the
phasic part of the variance (the part affected by the parent nodes). Defaults
to `1.0` (full connectivity). The value of this parameter will be ignored when
using a two-level HGF (`n_levels=2`).
input_precision :
The expected precision associated with the continuous or binary input, depending
on the model type.
The expected precision associated with the continuous input.
response_function_parameters :
An array of additional parameters that will be passed to the response function
to compute the surprise. This can include values over which inference is
Expand All @@ -119,29 +131,52 @@ def logp(
The log-probability (negative surprise).
"""
# update this network's attributes
hgf.attributes[0]["expected_precision"] = input_precision
if hgf.model_type == "continuous":

# update this network's attributes
hgf.attributes[0]["expected_precision"] = input_precision

hgf.attributes[1]["mean"] = mean_1
hgf.attributes[2]["mean"] = mean_2

hgf.attributes[1]["precision"] = precision_1
hgf.attributes[2]["precision"] = precision_2

hgf.attributes[1]["tonic_volatility"] = tonic_volatility_1
hgf.attributes[2]["tonic_volatility"] = tonic_volatility_2

hgf.attributes[1]["tonic_drift"] = tonic_drift_1
hgf.attributes[2]["tonic_drift"] = tonic_drift_2

hgf.attributes[2]["volatility_coupling"] = (volatility_coupling_1,)

if hgf.n_levels == 3:
hgf.attributes[3]["mean"] = mean_3
hgf.attributes[3]["precision"] = precision_3
hgf.attributes[3]["tonic_volatility"] = tonic_volatility_3
hgf.attributes[3]["tonic_drift"] = tonic_drift_3
hgf.attributes[3]["volatility_coupling"] = (volatility_coupling_2,)

hgf.attributes[1]["mean"] = mean_1
hgf.attributes[2]["mean"] = mean_2
elif hgf.model_type == "binary":

hgf.attributes[1]["precision"] = precision_1
hgf.attributes[2]["precision"] = precision_2
# update this network's attributes
hgf.attributes[0]["mean"] = mean_1
hgf.attributes[1]["mean"] = mean_2

hgf.attributes[1]["tonic_volatility"] = tonic_volatility_1
hgf.attributes[2]["tonic_volatility"] = tonic_volatility_2
hgf.attributes[0]["precision"] = precision_1
hgf.attributes[1]["precision"] = precision_2

hgf.attributes[1]["tonic_drift"] = tonic_drift_1
hgf.attributes[2]["tonic_drift"] = tonic_drift_2
hgf.attributes[1]["tonic_volatility"] = tonic_volatility_2

hgf.attributes[2]["volatility_coupling"] = (volatility_coupling_1,)
hgf.attributes[1]["tonic_drift"] = tonic_drift_1

if hgf.n_levels == 3:
hgf.attributes[3]["mean"] = mean_3
hgf.attributes[3]["precision"] = precision_3
hgf.attributes[3]["tonic_volatility"] = tonic_volatility_3
hgf.attributes[3]["tonic_drift"] = tonic_drift_3
hgf.attributes[3]["volatility_coupling"] = (volatility_coupling_2,)
if hgf.n_levels == 3:
hgf.attributes[2]["mean"] = mean_3
hgf.attributes[2]["precision"] = precision_3
hgf.attributes[2]["tonic_volatility"] = tonic_volatility_3
hgf.attributes[2]["tonic_drift"] = tonic_drift_3
hgf.attributes[1]["volatility_coupling_parents"] = (volatility_coupling_2,)
hgf.attributes[2]["volatility_coupling_children"] = (volatility_coupling_2,)

surprise = hgf.input_data(input_data=input_data, time_steps=time_steps).surprise(
response_function=response_function,
Expand Down
79 changes: 42 additions & 37 deletions src/pyhgf/model/hgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def __init__(
raise ValueError("Invalid model type.")
else:
if model_type == "continuous":
# Input
# X - 0
self.add_nodes(
precision=continuous_precision,
)

# X-1
# X - 1
self.add_nodes(
value_children=([0], [1.0]),
node_parameters={
Expand All @@ -130,7 +130,7 @@ def __init__(
},
)

# X-2
# X - 2
self.add_nodes(
volatility_children=([1], [volatility_coupling["1"]]),
node_parameters={
Expand All @@ -141,31 +141,36 @@ def __init__(
},
)

elif model_type == "binary":
# Input
self.add_nodes(
kind="binary-input",
node_parameters={
"eta0": eta0,
"eta1": eta1,
"binary_precision": binary_precision,
},
)
#########################
# Meta volatility level #
#########################
if self.n_levels == 3:
self.add_nodes(
volatility_children=([2], [volatility_coupling["2"]]),
node_parameters={
"mean": initial_mean["3"],
"precision": initial_precision["3"],
"tonic_volatility": tonic_volatility["3"],
"tonic_drift": tonic_drift["3"],
},
)

# X -1
self.add_nodes(
kind="binary-state",
value_children=([0], [1.0]),
node_parameters={
"mean": initial_mean["1"],
"precision": initial_precision["1"],
},
)
elif model_type == "binary":

# X -2
if binary_precision == jnp.inf:
# X - 0
self.add_nodes(
kind="binary-state",
node_parameters={
"mean": initial_mean["1"],
"precision": initial_precision["1"],
},
)

# X - 1
self.add_nodes(
kind="continuous-state",
value_children=([1], [1.0]),
value_children=([0], [1.0]),
node_parameters={
"mean": initial_mean["2"],
"precision": initial_precision["2"],
Expand All @@ -174,19 +179,19 @@ def __init__(
},
)

#########
# x - 3 #
#########
if self.n_levels == 3:
self.add_nodes(
volatility_children=([2], [volatility_coupling["2"]]),
node_parameters={
"mean": initial_mean["3"],
"precision": initial_precision["3"],
"tonic_volatility": tonic_volatility["3"],
"tonic_drift": tonic_drift["3"],
},
)
#########################
# Meta volatility level #
#########################
if self.n_levels == 3:
self.add_nodes(
volatility_children=([1], [volatility_coupling["2"]]),
node_parameters={
"mean": initial_mean["3"],
"precision": initial_precision["3"],
"tonic_volatility": tonic_volatility["3"],
"tonic_drift": tonic_drift["3"],
},
)

# initialize the model so it is ready to receive new observations
self.create_belief_propagation_fn()
Expand Down
15 changes: 3 additions & 12 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,23 +465,14 @@ def add_nodes(
}
elif kind == "binary-state":
default_parameters = {
"mean": 0.0,
"expected_mean": 0.0,
"observed": 1,
"mean": 0.5,
"expected_mean": 0.5,
"precision": 1.0,
"expected_precision": 1.0,
"volatility_coupling_children": volatility_children[1],
"volatility_coupling_parents": volatility_parents[1],
"value_coupling_children": value_children[1],
"value_coupling_parents": value_parents[1],
"tonic_volatility": 0.0,
"tonic_drift": 0.0,
"autoconnection_strength": 1.0,
"observed": 1,
"binary_expected_precision": 0.0,
"temp": {
"effective_precision": 0.0,
"value_prediction_error": 0.0,
"volatility_prediction_error": 0.0,
},
}
elif kind == "generic-state":
Expand Down
Loading

0 comments on commit 0da432a

Please sign in to comment.