Skip to content

Commit

Permalink
Fix MyPy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Aug 22, 2022
1 parent 96f1b45 commit ceb243d
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install isort flake8 black==22.6.0 mypy==0.971
pip install isort==5.10.1 flake8 black==22.6.0 mypy==0.971
- name: Run tests and coverage
run: |
mypy ./ghgf/ --ignore-missing-imports
Expand Down
43 changes: 24 additions & 19 deletions ghgf/jax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Author: Nicolas Legrand <nicolas.legrand@cfin.au.dk>

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import jax.numpy as jnp
from jax import jit
Expand All @@ -11,8 +11,8 @@ def update_parents(
node_parameters: Dict[str, float],
value_parents: Optional[Tuple],
volatility_parents: Optional[Tuple],
old_time: int,
new_time: int,
old_time: Union[float, DeviceArray],
new_time: Union[float, DeviceArray],
) -> Tuple[Dict[str, float], Optional[Tuple], Optional[Tuple]]:
"""Update the value parents from a given node. If the node has value or volatility
parents, they will be updated recursively.
Expand Down Expand Up @@ -64,7 +64,7 @@ def update_parents(
return node_parameters, value_parents, volatility_parents

# Time interval
t = new_time - old_time
t = jnp.subtract(new_time, old_time)

pihat = node_parameters["pihat"]

Expand Down Expand Up @@ -259,7 +259,7 @@ def update_input_parents(
new_time: jnp.DeviceArray,
old_time: jnp.DeviceArray,
) -> Optional[
Union[
Tuple[
jnp.DeviceArray, Tuple[Dict[str, DeviceArray], Optional[Tuple], Optional[Tuple]]
]
]:
Expand All @@ -273,20 +273,22 @@ def update_input_parents(
Parameters
----------
input_node : tuple
The parameter, value parent and volatility parent of the input node.
The parameter, value parent and volatility parent of the input node. The
volatility and value parents contain their own value and volatility parents,
this structure being nested up to the higher level of the hierarchy.
value : DeviceArray
The new input value(s).
The new input value that is observed by the model at time t=`new_time`.
new_time : DeviceArray
The time point (float).
The current time point.
old_time : DeviceArray
The time point (float) of the previous observed value.
The time point of the previous observed value.
Returns
-------
surprise : jnp.DeviceArray
The gaussian surprise given the value(s) presented at time `new_time`.
The gaussian surprise given the value(s) presented at t=`new_time`.
new_input_node : tuple
The input node structure after recursive update.
The input node structure after recursively updating all the nodes.
See also
--------
Expand All @@ -299,11 +301,11 @@ def update_input_parents(
return None

# Time interval
t = new_time - old_time
t = jnp.subtract(new_time, old_time)

# Add a bias to the input value
if input_node_parameters["bias"] is not None:
value += input_node_parameters["bias"]
value = jnp.add(value, input_node_parameters["bias"])

lognoise = input_node_parameters["omega"]

Expand Down Expand Up @@ -426,7 +428,9 @@ def update_input_parents(
new_nu,
]

pi_vo_pa = pihat_vo_pa + 0.5 * input_node_parameters["kappas"] ** 2 * (1 + vope)
pi_vo_pa = pihat_vo_pa + 0.5 * jnp.square(input_node_parameters["kappas"]) * (
1 + vope
)
pi_vo_pa = jnp.where(pi_vo_pa <= 0, jnp.nan, pi_vo_pa)

# Compute new muhat
Expand All @@ -441,7 +445,10 @@ def update_input_parents(
driftrate += p * vo_pa_va_pa[0]["mu"]

muhat_vo_pa = vo_pa_node_parameters["mu"] + t * driftrate
mu_vo_pa = muhat_vo_pa + 0.5 * input_node_parameters["kappas"] / pi_vo_pa * vope
mu_vo_pa = (
muhat_vo_pa
+ jnp.multiply(0.5, input_node_parameters["kappas"]) / pi_vo_pa * vope
)

# Update node's parameters
vo_pa_node_parameters["pihat"] = pihat_vo_pa
Expand Down Expand Up @@ -501,14 +508,12 @@ def gaussian_surprise(
return jnp.array(0.5) * (
jnp.log(jnp.array(2.0) * jnp.pi)
- jnp.log(pihat)
+ pihat * jnp.square(x - muhat)
+ pihat * jnp.square(jnp.subtract(x, muhat))
)


@jit
def loop_inputs(
res: Tuple, el: Tuple
) -> Tuple[Tuple[List, Dict[str, DeviceArray]], DeviceArray]:
def loop_inputs(res: Tuple, el: Tuple) -> Tuple[Tuple, Tuple]:
"""The HGF function to be scanned by JAX. One time step updating node structure and
returning the new node structure with time, value and surprise.
Expand Down
28 changes: 19 additions & 9 deletions ghgf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict, Optional, Tuple

import jax.numpy as jnp
from jax.interpreters.xla import DeviceArray
import numpy as np
from jax.lax import scan

from ghgf.jax import loop_inputs, node_validation
Expand All @@ -20,7 +20,8 @@ class HGF(object):
verbose : bool
Verbosity level.
n_levels : int
The number of hierarchies in the model. Cannot be less than 2.
The number of hierarchies in the model, including the input vector. Cannot be
less than 2.
model_type : str
The model implemented (can be `"continous"` or `"binary"`).
nodes : tuple
Expand All @@ -29,8 +30,15 @@ class HGF(object):
After oberving the data using the `input_data` method, the output of the model
are stored in the `hgf_results` dictionary.
Examples
--------
Methods
-------
add_nodes(nodes: Tuple)
Add a custom node structure to the model.
input_data(input_data: np.ndarray)
Input data and update the node structure accordingly.
plot_trajectories(backend: str = "matplotlib", **kwargs)
Plot the trajectories of the different parameters once the data has been
observed.
"""

Expand All @@ -40,7 +48,7 @@ def __init__(
model_type: str = "continuous",
initial_mu: Dict[str, Optional[float]] = {"1": 0.0, "2": 0.0},
initial_pi: Dict[str, Optional[float]] = {"1": 1.0, "2": 1.0},
omega_input: DeviceArray = log(1e-4),
omega_input: float = log(1e-4),
omega: Dict[str, Optional[float]] = {"1": -3.0, "2": -3.0},
kappas: Dict[str, Optional[float]] = {"1": 1.0},
rho: Dict[str, Optional[float]] = {"1": 0.0, "2": 0.0},
Expand Down Expand Up @@ -74,7 +82,7 @@ def __init__(
`{"1": -10.0, "2": -10.0}` for a 2-levels model. This parameters only when
`model_type="GRW"`.
omega_input : float
Default value sets to `np.log(1e-4)`. Represents the noise associated with
Default value sets to `log(1e-4)`. Represents the noise associated with
the input.
rho : dict
Dictionary containing the initial values for the `rho` parameter at
Expand Down Expand Up @@ -156,7 +164,7 @@ def __init__(
"muhat": jnp.nan,
"pi": initial_pi["2"],
"pihat": jnp.nan,
"kappas": (kappas["2"],),
"kappas": (kappas["2"],), # type: ignore
"nu": jnp.nan,
"psis": None,
"omega": omega["2"],
Expand Down Expand Up @@ -199,9 +207,12 @@ def add_nodes(self, nodes: Tuple):

def input_data(
self,
input_data,
input_data: np.ndarray,
):

# Store the input data
self.data = input_data

# Initialise the first values
res_init = (
self.input_node,
Expand All @@ -221,7 +232,6 @@ def input_data(
self.hgf_results[
"final"
] = final # The commulative update of the nodes and results
self.hgf_results["data"] = input_data # The input data

return self

Expand Down
42 changes: 21 additions & 21 deletions ghgf/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,17 @@ def __init__(

def make_node(
self,
omega_1=jnp.array(-3.0),
omega_2=jnp.array(-3.0),
omega_1=np.array(-3.0),
omega_2=np.array(-3.0),
omega_input=np.log(1e-4),
rho_1=jnp.array(0.0),
rho_2=jnp.array(0.0),
pi_1=jnp.array(1e4),
pi_2=jnp.array(1e1),
mu_1=jnp.array(0.0),
mu_2=jnp.array(0.0),
kappa_1=jnp.array(1.0),
bias=jnp.array(0.0),
rho_1=np.array(0.0),
rho_2=np.array(0.0),
pi_1=np.array(1e4),
pi_2=np.array(1e1),
mu_1=np.array(0.0),
mu_2=np.array(0.0),
kappa_1=np.array(1.0),
bias=np.array(0.0),
):

# Convert our inputs to symbolic variables
Expand Down Expand Up @@ -281,17 +281,17 @@ def __init__(

def make_node(
self,
omega_1: float,
omega_2: float,
omega_input: float,
rho_1: float,
rho_2: float,
pi_1: float,
pi_2: float,
mu_1: float,
mu_2: float,
kappa_1: float,
bias: float,
omega_1,
omega_2,
omega_input,
rho_1,
rho_2,
pi_1,
pi_2,
mu_1,
mu_2,
kappa_1,
bias,
):

# Convert our inputs to symbolic variables
Expand Down
6 changes: 5 additions & 1 deletion ghgf/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def hrd_behaviors(
# just a dynamic array indexing.
# ----------------------------------------------------------------------------------
# First define a function to extract values for one trigger using dynamic_slice()
def extract(trigger: int, new_mu1=new_mu1, new_pi1=new_pi1):
def extract(
trigger: np.ndarray,
new_mu1: DeviceArray = new_mu1,
new_pi1: DeviceArray = new_pi1,
):
return (
dynamic_slice(new_mu1, (trigger,), (5000,)).mean(),
dynamic_slice(new_pi1, (trigger,), (5000,)).mean(),
Expand Down

0 comments on commit ceb243d

Please sign in to comment.