Skip to content

Commit

Permalink
add binray_precision as an argument of HGFDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Sep 21, 2023
1 parent 2864a2d commit 4fe7bbc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
41 changes: 27 additions & 14 deletions src/pyhgf/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def hgf_logp(
omega_2: Union[np.ndarray, ArrayLike, float],
omega_3: Union[np.ndarray, ArrayLike, float],
continuous_precision: Union[np.ndarray, ArrayLike, float],
binary_precision: Union[np.ndarray, ArrayLike, float],
rho_1: Union[np.ndarray, ArrayLike, float],
rho_2: Union[np.ndarray, ArrayLike, float],
rho_3: Union[np.ndarray, ArrayLike, float],
Expand Down Expand Up @@ -66,7 +67,9 @@ def hgf_logp(
not inherited from parents nodes). The value of this parameter will be ignored
when using a two-level HGF (`n_levels=2`).
continuous_precision :
Represent the expected precision associated with the continuous input.
The expected precision associated with the continuous input.
binary_precision :
The expected precision associated with the binary input.
rho_1 :
The :math:`\rho` parameter at the first level of the HGF. This parameter
represents the drift of the random walk.
Expand Down Expand Up @@ -139,6 +142,7 @@ def hgf_logp(
_omega_2,
_omega_3,
_continuous_precision,
_binary_precision,
_rho_1,
_rho_2,
_rho_3,
Expand All @@ -156,6 +160,7 @@ def hgf_logp(
omega_2,
omega_3,
continuous_precision,
binary_precision,
rho_1,
rho_2,
rho_3,
Expand Down Expand Up @@ -209,13 +214,13 @@ def hgf_logp(
initial_pi=initial_pi,
omega=omega,
continuous_precision=_continuous_precision[i],
binary_precision=_binary_precision[i],
rho=rho,
kappas=kappas,
model_type=model_type,
n_levels=n_levels,
eta0=0.0,
eta1=1.0,
binary_precision=jnp.inf,
verbose=False,
)
.input_data(input_data=input_data[i], time_steps=time_steps[i])
Expand Down Expand Up @@ -279,7 +284,7 @@ def __init__(
model_type=model_type,
response_function_parameters=response_function_parameters,
),
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
)
)

Expand All @@ -289,6 +294,7 @@ def make_node(
omega_2=np.array(-3.0),
omega_3=np.array(0.0),
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=np.array(0.0),
rho_2=np.array(0.0),
rho_3=np.array(0.0),
Expand All @@ -308,6 +314,7 @@ def make_node(
pt.as_tensor_variable(omega_2),
pt.as_tensor_variable(omega_3),
pt.as_tensor_variable(continuous_precision),
pt.as_tensor_variable(binary_precision),
pt.as_tensor_variable(rho_1),
pt.as_tensor_variable(rho_2),
pt.as_tensor_variable(rho_3),
Expand All @@ -334,6 +341,7 @@ def perform(self, node, inputs, outputs):
grad_omega_2,
grad_omega_3,
grad_continuous_precision,
grad_binary_precision,
grad_rho_1,
grad_rho_2,
grad_rho_3,
Expand All @@ -353,17 +361,18 @@ def perform(self, node, inputs, outputs):
outputs[3][0] = np.asarray(
grad_continuous_precision, dtype=node.outputs[3].dtype
)
outputs[4][0] = np.asarray(grad_rho_1, dtype=node.outputs[4].dtype)
outputs[5][0] = np.asarray(grad_rho_2, dtype=node.outputs[5].dtype)
outputs[6][0] = np.asarray(grad_rho_3, dtype=node.outputs[6].dtype)
outputs[7][0] = np.asarray(grad_pi_1, dtype=node.outputs[7].dtype)
outputs[8][0] = np.asarray(grad_pi_2, dtype=node.outputs[8].dtype)
outputs[9][0] = np.asarray(grad_pi_3, dtype=node.outputs[9].dtype)
outputs[10][0] = np.asarray(grad_mu_1, dtype=node.outputs[10].dtype)
outputs[11][0] = np.asarray(grad_mu_2, dtype=node.outputs[11].dtype)
outputs[12][0] = np.asarray(grad_mu_3, dtype=node.outputs[12].dtype)
outputs[13][0] = np.asarray(grad_kappa_1, dtype=node.outputs[13].dtype)
outputs[14][0] = np.asarray(grad_kappa_2, dtype=node.outputs[14].dtype)
outputs[4][0] = np.asarray(grad_binary_precision, dtype=node.outputs[4].dtype)
outputs[5][0] = np.asarray(grad_rho_1, dtype=node.outputs[5].dtype)
outputs[6][0] = np.asarray(grad_rho_2, dtype=node.outputs[6].dtype)
outputs[7][0] = np.asarray(grad_rho_3, dtype=node.outputs[7].dtype)
outputs[8][0] = np.asarray(grad_pi_1, dtype=node.outputs[8].dtype)
outputs[9][0] = np.asarray(grad_pi_2, dtype=node.outputs[9].dtype)
outputs[10][0] = np.asarray(grad_pi_3, dtype=node.outputs[10].dtype)
outputs[11][0] = np.asarray(grad_mu_1, dtype=node.outputs[11].dtype)
outputs[12][0] = np.asarray(grad_mu_2, dtype=node.outputs[12].dtype)
outputs[13][0] = np.asarray(grad_mu_3, dtype=node.outputs[13].dtype)
outputs[14][0] = np.asarray(grad_kappa_1, dtype=node.outputs[14].dtype)
outputs[15][0] = np.asarray(grad_kappa_2, dtype=node.outputs[15].dtype)


class HGFDistribution(Op):
Expand Down Expand Up @@ -497,6 +506,7 @@ def make_node(
omega_2,
omega_3,
continuous_precision,
binary_precision,
rho_1,
rho_2,
rho_3,
Expand All @@ -515,6 +525,7 @@ def make_node(
pt.as_tensor_variable(omega_2),
pt.as_tensor_variable(omega_3),
pt.as_tensor_variable(continuous_precision),
pt.as_tensor_variable(binary_precision),
pt.as_tensor_variable(rho_1),
pt.as_tensor_variable(rho_2),
pt.as_tensor_variable(rho_3),
Expand Down Expand Up @@ -543,6 +554,7 @@ def grad(self, inputs, output_gradients):
grad_omega_2,
grad_omega_3,
grad_continuous_precision,
grad_binary_precision,
grad_rho_1,
grad_rho_2,
grad_rho_3,
Expand All @@ -564,6 +576,7 @@ def grad(self, inputs, output_gradients):
output_gradient * grad_omega_2,
output_gradient * grad_omega_3,
output_gradient * grad_continuous_precision,
output_gradient * grad_binary_precision,
output_gradient * grad_rho_1,
output_gradient * grad_rho_2,
output_gradient * grad_rho_3,
Expand Down
20 changes: 16 additions & 4 deletions tests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_hgf_logp(self):
omega_2=-3.0,
omega_3=jnp.nan,
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=0.0,
rho_2=0.0,
rho_3=jnp.nan,
Expand Down Expand Up @@ -76,6 +77,7 @@ def test_hgf_logp(self):
omega_2=jnp.array(-6.0),
omega_3=jnp.nan,
continuous_precision=jnp.nan,
binary_precision=jnp.inf,
rho_1=jnp.array(0.0),
rho_2=jnp.array(0.0),
rho_3=jnp.nan,
Expand Down Expand Up @@ -107,7 +109,7 @@ def test_grad_logp(self):
model_type="continuous",
response_function_parameters=None,
),
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
),
)

Expand All @@ -116,6 +118,7 @@ def test_grad_logp(self):
omega_2,
omega_3,
continuous_precision,
binary_precision,
rho_1,
rho_2,
rho_3,
Expand All @@ -132,6 +135,7 @@ def test_grad_logp(self):
np.array(-3.0),
np.array(0.0),
np.array(1e4),
np.nan,
np.array(0.0),
np.array(0.0),
np.array(0.0),
Expand Down Expand Up @@ -164,15 +168,16 @@ def test_grad_logp(self):
model_type="binary",
response_function_parameters=None,
),
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
argnums=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
),
)

(
omega_1,
omega_2,
omega_3,
input_precision,
continuous_precision,
binary_precision,
rho_1,
rho_2,
rho_3,
Expand All @@ -188,6 +193,7 @@ def test_grad_logp(self):
np.array(0.0),
np.array(-2.0),
np.array(0.0),
np.array(1e4),
np.inf,
np.array(0.0),
np.array(0.0),
Expand Down Expand Up @@ -227,6 +233,7 @@ def test_aesara_logp(self):
omega_2=np.array(-3.0),
omega_3=np.array(0.0),
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=np.array(0.0),
rho_2=np.array(0.0),
rho_3=np.array(0.0),
Expand Down Expand Up @@ -261,7 +268,8 @@ def test_aesara_logp(self):
omega_1=np.inf,
omega_2=-6.0,
omega_3=np.inf,
continuous_precision=np.inf,
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=0.0,
rho_2=0.0,
rho_3=np.inf,
Expand Down Expand Up @@ -299,6 +307,7 @@ def test_aesara_grad_logp(self):
omega_1=-3.0,
omega_2=-3.0,
continuous_precision=np.array(1e4),
binary_precision=np.nan,
rho_1=0.0,
rho_2=0.0,
pi_1=1e4,
Expand Down Expand Up @@ -329,6 +338,7 @@ def test_aesara_grad_logp(self):
omega_1=jnp.inf,
omega_2=-6.0,
continuous_precision=jnp.nan,
binary_precision=jnp.inf,
rho_1=0.0,
rho_2=0.0,
pi_1=0.0,
Expand Down Expand Up @@ -366,6 +376,7 @@ def test_pymc_sampling(self):
omega_1=np.array(0.0),
omega_2=omega_2,
continuous_precision=np.array(1e4),
binary_precision=np.inf,
rho_1=np.array(0.0),
rho_2=np.array(0.0),
pi_1=np.array(1e4),
Expand Down Expand Up @@ -417,6 +428,7 @@ def test_pymc_sampling(self):
omega_1=np.inf,
omega_2=omega_2,
continuous_precision=np.nan,
binary_precision=np.inf,
rho_1=0.0,
rho_2=0.0,
pi_1=0.0,
Expand Down

0 comments on commit 4fe7bbc

Please sign in to comment.