Skip to content

Commit

Permalink
default parmeter values and merge mu and muhat and pi and pihat in th…
Browse files Browse the repository at this point in the history
…e node creation
  • Loading branch information
LegrandNico committed Sep 6, 2023
1 parent 717a15f commit b6a6516
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 26 deletions.
37 changes: 13 additions & 24 deletions pyhgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
"2": -3.0,
"3": -3.0,
},
kappas: Dict = {"1": 1.0, "2": 0.0},
kappas: Dict = {"1": 1.0, "2": 1.0},
eta0: Union[float, np.ndarray, ArrayLike] = 0.0,
eta1: Union[float, np.ndarray, ArrayLike] = 1.0,
binary_precision: Union[float, np.ndarray, ArrayLike] = jnp.inf,
Expand Down Expand Up @@ -575,7 +575,6 @@ def add_input_node(
"pi_3": 1.0,
"mu_1": 0.0,
"mu_2": -jnp.log(categorical_parameters["n_categories"] - 1),
"mu_hat_2": -jnp.log(categorical_parameters["n_categories"] - 1),
"mu_3": 0.0,
"omega_2": -4.0,
"omega_3": -4.0,
Expand All @@ -595,9 +594,7 @@ def add_value_parent(
children_idxs: Union[List, int],
value_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mu: Union[float, np.ndarray, ArrayLike] = 0.0,
mu_hat: Union[float, np.ndarray, ArrayLike] = 0.0,
pi: Union[float, np.ndarray, ArrayLike] = 1.0,
pi_hat: Union[float, np.ndarray, ArrayLike] = 1.0,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
Expand All @@ -612,14 +609,11 @@ def add_value_parent(
The value_coupling between the child and parent node. This is will be
appended to the `psis` parameters in the parent and child node(s).
mu :
The mean of the Gaussian distribution.
mu_hat :
The expected mean of the Gaussian distribution for the next observation.
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
pi :
The precision of the Gaussian distribution (inverse variance).
pi_hat :
The expected precision of the Gaussian distribution (inverse variance) for
the next observation.
The precision of the Gaussian distribution (inverse variance). This
value is passed both to the current and expected states.
omega :
The tonic part of the variance (the variance that is not inherited from
volatility parent(s)).
Expand All @@ -639,9 +633,9 @@ def add_value_parent(
# parent's parameter
node_parameters = {
"mu": mu,
"muhat": mu_hat,
"muhat": mu,
"pi": pi,
"pihat": pi_hat,
"pihat": pi,
"kappas_children": None,
"kappas_parents": None,
"psis_children": tuple(value_coupling for _ in range(len(children_idxs))),
Expand Down Expand Up @@ -699,9 +693,7 @@ def add_volatility_parent(
children_idxs: Union[List, int],
volatility_coupling: Union[float, np.ndarray, ArrayLike] = 1.0,
mu: Union[float, np.ndarray, ArrayLike] = 0.0,
mu_hat: Union[float, np.ndarray, ArrayLike] = 0.0,
pi: Union[float, np.ndarray, ArrayLike] = 1.0,
pi_hat: Union[float, np.ndarray, ArrayLike] = 1.0,
omega: Union[float, np.ndarray, ArrayLike] = -4.0,
rho: Union[float, np.ndarray, ArrayLike] = 0.0,
additional_parameters: Optional[Dict] = None,
Expand All @@ -716,14 +708,11 @@ def add_volatility_parent(
The volatility coupling between the child and parent node. This is will be
appended to the `kappas` parameters in the parent and child node(s).
mu :
The mean of the Gaussian distribution.
mu_hat :
The expected mean of the Gaussian distribution for the next observation.
The mean of the Gaussian distribution. This value is passed both to the
current and expected states.
pi :
The precision of the Gaussian distribution (inverse variance).
pi_hat :
The expected precision of the Gaussian distribution (inverse variance) for
the next observation.
The precision of the Gaussian distribution (inverse variance). This
value is passed both to the current and expected states.
omega :
The tonic part of the variance (the variance that is not inherited from
volatility parent(s)).
Expand All @@ -743,9 +732,9 @@ def add_volatility_parent(
# parent's parameter
node_parameters = {
"mu": mu,
"muhat": mu_hat,
"muhat": mu,
"pi": pi,
"pihat": pi_hat,
"pihat": pi,
"kappas_children": tuple(
volatility_coupling for _ in range(len(children_idxs))
),
Expand Down
2 changes: 0 additions & 2 deletions pyhgf/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def fill_categorical_state_node(
value_coupling=1.0,
pi=implied_binary_parameters["pi_1"],
mu=implied_binary_parameters["mu_1"],
mu_hat=implied_binary_parameters["mu_1"],
)

# add the continuous parent node
Expand All @@ -246,7 +245,6 @@ def fill_categorical_state_node(
+ implied_binary_parameters["n_categories"],
value_coupling=1.0,
mu=implied_binary_parameters["mu_2"],
mu_hat=implied_binary_parameters["mu_hat_2"],
pi=implied_binary_parameters["pi_2"],
omega=implied_binary_parameters["omega_2"],
)
Expand Down

0 comments on commit b6a6516

Please sign in to comment.