Skip to content

Commit

Permalink
set default paraneters to the HGF distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 22, 2023
1 parent 0c4e3e3 commit 0d63429
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 110 deletions.
218 changes: 110 additions & 108 deletions src/pyhgf/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,123 +14,130 @@


def hgf_logp(
omega_1: Union[np.ndarray, ArrayLike, float],
omega_2: Union[np.ndarray, ArrayLike, float],
omega_3: Union[np.ndarray, ArrayLike, float],
continuous_precision: Union[np.ndarray, ArrayLike, float],
binary_precision: Union[np.ndarray, ArrayLike, float],
rho_1: Union[np.ndarray, ArrayLike, float],
rho_2: Union[np.ndarray, ArrayLike, float],
rho_3: Union[np.ndarray, ArrayLike, float],
pi_1: Union[np.ndarray, ArrayLike, float],
pi_2: Union[np.ndarray, ArrayLike, float],
pi_3: Union[np.ndarray, ArrayLike, float],
mu_1: Union[np.ndarray, ArrayLike, float],
mu_2: Union[np.ndarray, ArrayLike, float],
mu_3: Union[np.ndarray, ArrayLike, float],
kappa_1: Union[np.ndarray, ArrayLike, float],
kappa_2: Union[np.ndarray, ArrayLike, float],
input_data: List[np.ndarray],
response_function: Callable,
model_type: str,
n_levels: int,
omega_1: Union[np.ndarray, ArrayLike, float] = -3.0,
omega_2: Union[np.ndarray, ArrayLike, float] = -3.0,
omega_3: Union[np.ndarray, ArrayLike, float] = -3.0,
continuous_precision: Union[np.ndarray, ArrayLike, float] = 1e4,
binary_precision: Union[np.ndarray, ArrayLike, float] = np.inf,
rho_1: Union[np.ndarray, ArrayLike, float] = 0.0,
rho_2: Union[np.ndarray, ArrayLike, float] = 0.0,
rho_3: Union[np.ndarray, ArrayLike, float] = 0.0,
pi_1: Union[np.ndarray, ArrayLike, float] = 1.0,
pi_2: Union[np.ndarray, ArrayLike, float] = 1.0,
pi_3: Union[np.ndarray, ArrayLike, float] = 1.0,
mu_1: Union[np.ndarray, ArrayLike, float] = 0.0,
mu_2: Union[np.ndarray, ArrayLike, float] = 0.0,
mu_3: Union[np.ndarray, ArrayLike, float] = 0.0,
kappa_1: Union[np.ndarray, ArrayLike, float] = 1.0,
kappa_2: Union[np.ndarray, ArrayLike, float] = 1.0,
input_data: List[np.ndarray] = [np.nan],
response_function: Optional[Callable] = None,
model_type: str = "continuous",
n_levels: int = 2,
response_function_parameters: List[Tuple] = [()],
time_steps: Optional[List] = None,
) -> float:
r"""Log probability from HGF model(s) given input data and parameter(s).
This function support broadcasting along the first axis:
- If the input data contains many time series, the function will automatically
create the corresponding number of HGF models and fit them separately.
- If a single input data is provided but some parameters have array-like inputs, the
number of HGF models will match the length of the arrays, using the value *i*
for the *i* th model. When floats are provided for some parameters, the same value
will be used for all HGF models.
- If multiple input data are provided with array-like inputs for some parameter, the
function will create and fit the models separately using the value *i* for the
*i* th model.
r"""HGF log-probability given input data, response function and parameters.
.. note::
This function support broadcasting along the first axis, which means that it can
fit multiple HGF when multiple data points are provided:
* If the input data contains many time series, the function will automatically
create the corresponding number of HGF models and fit them separately.
* If a single input data is provided but some parameters have array-like inputs,
the number of HGF models will match the length of the arrays, using the value
*i* for the *i* th model. When floats are provided for some parameters, the same
value will be used for all HGF models.
* If multiple input data are provided with array-like inputs for some parameter,
the function will create and fit the models separately using the value *i* for
the *i* th model.
Parameters
----------
omega_1 :
The :math:`\omega` parameter, or *evolution rate*, at the first level of the
The :math:`\omega_1` parameter, or *evolution rate*, at the first level of the
HGF. This parameter represents the tonic part of the variance (the part that is
not inherited from parents nodes).
not inherited from parents nodes). Defaults to `-3.0`.
omega_2 :
The :math:`\omega` parameter, or *evolution rate*, at the second level of the
The :math:`\omega_1` parameter, or *evolution rate*, at the second level of the
HGF. This parameter represents the tonic part of the variance (the part that is
not inherited from parents nodes).
not inherited from parents nodes). Defaults to `-3.0`.
omega_3 :
The :math:`\omega` parameter, or *evolution rate*, at the third level of the
The :math:`\omega_3` parameter, or *evolution rate*, at the third level of the
HGF. This parameter represents the tonic part of the variance (the part that is
not inherited from parents nodes). The value of this parameter will be ignored
when using a two-level HGF (`n_levels=2`).
when using a two-level HGF (`n_levels=2`). Defaults to `-3.0`.
continuous_precision :
The expected precision associated with the continuous input.
The expected precision associated with the continuous input. Defaults to `1e4`.
binary_precision :
The expected precision associated with the binary input.
The expected precision associated with the binary input. Defaults to np.inf`.
rho_1 :
The :math:`\rho` parameter at the first level of the HGF. This parameter
represents the drift of the random walk.
The :math:`\rho_1` parameter at the first level of the HGF. This parameter
represents the drift of the random walk. Defaults to `0.0`.
rho_2 :
The :math:`\rho` parameter at the second level of the HGF. This parameter
represents the drift of the random walk.
The :math:`\rho_2` parameter at the second level of the HGF. This parameter
represents the drift of the random walk. Defaults to `0.0`.
rho_3 :
The :math:`\rho` parameter at the first level of the HGF. This parameter
The :math:`\rho_3` parameter at the first level of the HGF. This parameter
represents the drift of the random walk. The value of this parameter will be
ignored when using a two-level HGF (`n_levels=2`).
ignored when using a two-level HGF (`n_levels=2`). Defaults to `0.0`.
pi_1 :
The :math:`\pi` parameter, or *precision*, at the first level of the HGF.
The :math:`\pi_1` parameter, or *precision*, at the first level of the HGF.
Defaults to `1.0`.
pi_2 :
The :math:`\pi` parameter, or *precision*, at the second level of the HGF.
The :math:`\pi_2` parameter, or *precision*, at the second level of the HGF.
Defaults to `1.0`.
pi_3 :
The :math:`\pi` parameter, or *precision*, at the third level of the HGF. The
The :math:`\pi_3` parameter, or *precision*, at the third level of the HGF. The
value of this parameter will be ignored when using a two-level HGF
(`n_levels=2`).
(`n_levels=2`). Defaults to `1.0`.
mu_1 :
The :math:`\mu` parameter, or *mean*, at the first level of the HGF.
The :math:`\mu_1` parameter, or *mean*, at the first level of the HGF. Defaults
to `0.0`.
mu_2 :
The :math:`\mu` parameter, or *mean*, at the second level of the HGF.
The :math:`\mu_2` parameter, or *mean*, at the second level of the HGF. Defaults
to `0.0`.
mu_3 :
The :math:`\mu` parameter, or *mean*, at the third level of the HGF. The value
The :math:`\mu_3` parameter, or *mean*, at the third level of the HGF. The value
of this parameter will be ignored when using a two-level HGF (`n_levels=2`).
Defaults to `0.0`.
kappa_1 :
The value of the :math:`\\kappa` parameter at the first level of the HGF. Kappa
The value of the :math:`\kappa_1` parameter at the first level of the HGF. Kappa
represents the phasic part of the variance (the part that is affected by the
parent nodes) and will define the strength of the connection between the node
and the parent node. Often fixed to `1`.
and the parent node. Defaults to `1.0`.
kappa_2 :
The value of the :math:`\\kappa` parameter at the second level of the HGF. Kappa
represents the phasic part of the variance (the part that is affected by the
parent nodes) and will define the strength of the connection between the node
and the parent node. Often fixed to `1`. The value of this parameter will be
ignored when using a two-level HGF (`n_levels=2`).
The value of the :math:`\kappa_2` parameter at the second level of the HGF.
Kappa represents the phasic part of the variance (the part that is affected by
the parent nodes) and will define the strength of the connection between the
node and the parent node. The value of this parameter will be ignored when
using a two-level HGF (`n_levels=2`). Defaults to `1.0`.
input_data :
List of input data. When `n` models should be fitted, the list contains `n` 1d
Numpy arrays. By default, the associated time vector is the integers vector
starting at `0`. A different time vector can be passed to the `time` argument.
Numpy arrays.
response_function :
The response function to use to compute the model surprise.
model_type :
The model type to use (can be "continuous" or "binary").
The model type to use (can be "continuous" or "binary"). Defaults to
`"continuous"`.
n_levels :
The number of hierarchies in the perceptual model (can be `2` or `3`). If
`None`, the nodes hierarchy is not created and might be provided afterwards
using :py:meth:`pyhgf.model.HGF.add_nodes`.
The number of hierarchies in the perceptual model (can be `2` or `3`). Defaults
to `2`.
response_function_parameters :
A list of tuples with the same length as the number of models. Each tuple
contains additional data and parameters that can be accessible to the response
functions.
time_steps :
List of 1d Numpy arrays containing the time vectors for each input time series.
If one of the list items is `None`, or if `None` is provided instead, the time
vector will default to the unit vector.
If one of the list items is `None`, or if `None` is provided instead. By default
all time steps are set to `1.0`.
Returns
-------
log_prob :
The sum of the log probabilities (or negative surprise).
The log probability of the HGF model (or negative surprise) given by the
response function.
"""
# number of models
Expand Down Expand Up @@ -256,17 +263,15 @@ def __init__(
vector. A different time vector can be passed to the `time_steps` argument.
time_steps :
List of 1d Numpy arrays containing the time_steps vectors for each input
time series. If one of the list items is `None`, or if `None` is provided
instead, the time_steps vector will default to an integers vector starting
at 0.
time series. By defaults, all time steps are set to `1.0`
model_type :
The model type to use (can be "continuous" or "binary").
The model type to use (can be "continuous" or "binary"). Defaults to
`"continuous"`.
n_levels :
The number of hierarchies in the perceptual model (can be `2` or `3`). If
`None`, the nodes hierarchy is not created and might be provided afterward
using `add_nodes()`.
The number of hierarchies in the perceptual model (can be `2` or `3`).
Defaults to `2`.
response_function :
The response function to use to compute the model surprise.
The response function to use to compute the model's surprise.
response_function_parameters :
A list of tuples with the same length as the number of models. Each tuple
contains additional data and parameters that can be accessible to the
Expand All @@ -292,20 +297,20 @@ def make_node(
self,
omega_1=np.array(-3.0),
omega_2=np.array(-3.0),
omega_3=np.array(0.0),
omega_3=np.array(-0.0),
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=np.array(0.0),
rho_2=np.array(0.0),
rho_3=np.array(0.0),
pi_1=np.array(1e4),
pi_2=np.array(1e1),
pi_3=np.array(0.0),
pi_1=np.array(1.0),
pi_2=np.array(1.0),
pi_3=np.array(1.0),
mu_1=np.array(0.0),
mu_2=np.array(0.0),
mu_3=np.array(0.0),
kappa_1=np.array(1.0),
kappa_2=np.array(0.0),
kappa_2=np.array(1.0),
):
"""Initialize node structure."""
# Convert our inputs to symbolic variables
Expand Down Expand Up @@ -457,19 +462,16 @@ def __init__(
----------
input_data :
List of input data. When `n` models should be fitted, the list contains `n`
1d Numpy arrays. By default, the associated time vector is the unit
vector starting at `0`. A different time_steps vector can be passed to
the `time_steps` argument.
1d Numpy arrays.
time_steps :
List of 1d Numpy arrays containing the time_steps vectors for each input
time series. If one of the list items is `None`, or if `None` is provided
instead, the time vector will default to an integers vector starting at 0.
time series. By default, all time steps are set to `1.0`.
model_type :
The model type to use (can be "continuous" or "binary").
The model type to use (can be "continuous" or "binary"). Defaults to
`"continuous"`.
n_levels :
The number of hierarchies in the perceptual model (can be `2` or `3`). If
`None`, the nodes hierarchy is not created and might be provided afterwards
using `add_nodes()`.
The number of hierarchies in the perceptual model (can be `2` or `3`).
Defaults to `2`.
response_function :
The response function to use to compute the model surprise.
response_function_parameters :
Expand Down Expand Up @@ -503,22 +505,22 @@ def __init__(

def make_node(
self,
omega_1,
omega_2,
omega_3,
continuous_precision,
binary_precision,
rho_1,
rho_2,
rho_3,
pi_1,
pi_2,
pi_3,
mu_1,
mu_2,
mu_3,
kappa_1,
kappa_2,
omega_1=np.array(-3.0),
omega_2=np.array(-3.0),
omega_3=np.array(-0.0),
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=np.array(0.0),
rho_2=np.array(0.0),
rho_3=np.array(0.0),
pi_1=np.array(1.0),
pi_2=np.array(1.0),
pi_3=np.array(1.0),
mu_1=np.array(0.0),
mu_2=np.array(0.0),
mu_3=np.array(0.0),
kappa_1=np.array(1.0),
kappa_2=np.array(1.0),
):
"""Convert inputs to symbolic variables."""
inputs = [
Expand Down
4 changes: 2 additions & 2 deletions src/pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def __init__(
The model type to use (can be `"continuous"` or `"binary"`).
n_levels :
The number of hierarchies in the perceptual model (can be `2` or `3`). If
`None`, the nodes hierarchy is not created and might be provided afterward
using `add_nodes()`. Defaults to `2` for a 2-level HGF.
`None`, the nodes hierarchy is not created and might be provided afterward.
Defaults to `2` for a 2-level HGF.
omega :
A dictionary containing the initial values for the :math:`\\omega` parameter
at different levels of the hierarchy. :math:`\\omega` represents the tonic
Expand Down

0 comments on commit 0d63429

Please sign in to comment.