Skip to content

Commit

Permalink
hodgkin huxley model
Browse files Browse the repository at this point in the history
  • Loading branch information
djpasseyjr committed Feb 6, 2024
1 parent 6ebc5fe commit 5a31f82
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 18 deletions.
1 change: 1 addition & 0 deletions interfere/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .geometric_brownian_motion import GeometricBrownianMotion
from .kuramoto import Kuramoto, KuramotoSakaguchi, StuartLandauKuramoto
from .lotka_voltera import LotkaVoltera, LotkaVolteraSDE
from .neuroscience import HodgkinHuxleyPyclustering, LEGIONPyclustering
from .ornstein_uhlenbeck import OrnsteinUhlenbeck
from .quadratic_sdes import Belozyorov3DQuad, Liping3DQuadFinance
from .simple_linear_sdes import (
Expand Down
9 changes: 7 additions & 2 deletions interfere/dynamics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def simulate(
X_do[i] = intervention(X_do[i], time_points[i])

# Compute next state
X_do[i+1] = self.step(X_do[i], rng)
X_do[i+1] = self.step(X_do[i], t=time_points[i], rng=rng)

# After the loop, apply interention to the last step
if intervention is not None:
Expand All @@ -296,7 +296,12 @@ def simulate(
return X_do

@abstractmethod
def step(self, x: np.ndarray, rng: np.random.mtrand.RandomState):
def step(
self,
x: np.ndarray,
t: float = None,
rng: np.random.mtrand.RandomState = None
):
"""Uses the current state to compute the next state of the system.
Args:
Expand Down
6 changes: 3 additions & 3 deletions interfere/dynamics/coupled_map_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
super().__init__(self.adjacency_matrix.shape[0], measurement_noise_std)


def step(self, x: np.ndarray, rng: np.random.mtrand.RandomState):
def step(self, x: np.ndarray, t: float, rng: np.random.mtrand.RandomState):
"""One step forward in time for a coupled map lattice
A coupled map lattice where coupling is determined by
Expand Down Expand Up @@ -248,7 +248,7 @@ def __init__(



def step(self, x: np.ndarray, rng: np.random.mtrand.RandomState):
def step(self, x: np.ndarray, t: float, rng: np.random.mtrand.RandomState):
"""One step forward in time for a stochastic coupled map lattice.
A stochastic coupled map lattice where coupling is determined by
Expand All @@ -259,7 +259,7 @@ def step(self, x: np.ndarray, rng: np.random.mtrand.RandomState):
where w[n] ~ N(0, sigma) and x_i is constrained to be in the interval
(self.x_min, self.x_max).
"""
x_next = super().step(x, rng)
x_next = super().step(x, t, rng)

# This check enables sigma == 0.0 to generate deterministic dynamics.
if self.sigma != 0.0:
Expand Down
19 changes: 6 additions & 13 deletions interfere/dynamics/kuramoto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,13 @@
from typing import Optional, Callable

import numpy as np
from pyclustering.nnet import conn_type
from pyclustering.nnet.fsync import fsync_network

from .base import StochasticDifferentialEquation, DEFAULT_RANGE
from .pyclustering_utils import CONN_TYPE_MAP
from ..utils import copy_doc


# Maps string arguments to pyclustering arguments
CONN_TYPE_MAP = {
"all_to_all": conn_type.ALL_TO_ALL,
"grid_four": conn_type.GRID_FOUR,
"grid_eight": conn_type.GRID_EIGHT,
"list_bdir": conn_type.LIST_BIDIR,
"dynamic": conn_type.DYNAMIC
}


def kuramoto_intervention_wrapper(
intervention: Callable[[np.ndarray, float], np.ndarray]
) -> Callable[[np.ndarray, float], np.ndarray]:
Expand Down Expand Up @@ -259,10 +249,13 @@ def __init__(
self.omega = omega
self.rho = rho
self.K = K
self.Sigma = sigma * np.diag(np.ones(dim))
self.sigma = sigma
self.type_conn = type_conn
self.convert_to_real = convert_to_real


# Make independent noise matrix.
self.Sigma = sigma * np.diag(np.ones(dim))

# Initialize the pyclustering model.
self.pyclustering_model = fsync_network(
dim, omega, rho, K, CONN_TYPE_MAP[type_conn])
Expand Down
241 changes: 241 additions & 0 deletions interfere/dynamics/neuroscience.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from typing import Callable, Optional

import numpy as np
from pyclustering.nnet.hhn import hhn_network, hhn_parameters
from pyclustering.nnet.legion import legion_network, legion_parameters

from .base import (
StochasticDifferentialEquation, DEFAULT_RANGE, DiscreteTimeDynamics
)
from .pyclustering_utils import CONN_TYPE_MAP
from ..utils import copy_doc

# Default Hodgkin Huxley neural net parameters.
DEFAULT_HHN_PARAMS = hhn_parameters()
DEFAULT_HHN_PARAMS.deltah = 400

# Default LEGION parameters
DEFAULT_LEGION_PARAMETERS = legion_parameters()


class HodgkinHuxleyPyclustering(StochasticDifferentialEquation):

def __init__(
self,
stimulus: np.array,
sigma: float = 1,
parameters: hhn_parameters = DEFAULT_HHN_PARAMS,
type_conn: str = "all_to_all",
measurement_noise_std: Optional[np.ndarray] = None
):
"""
Args:
stimulus (np.ndarray): Array of stimulus for oscillators, number of
stimulus. Length equal to number of oscillators.
sigma (float): Scale of the independent stochastic noise added to
the system.
parameters (hhn_parameters): A pyclustering.nnet.hhn.hhn_paramerers
object.
type_conn (str): Type of connection between oscillators. One
of ["all_to_all", "grid_four", "grid_eight", "list_bdir",
"dynamic"]. See pyclustering.nnet.__init__::conn_type for
details.
measurement_noise_std (ndarray): None, or a vector with shape (n,)
where each entry corresponds to the standard deviation of the
measurement noise for that particular dimension of the dynamic
model. For example, if the dynamic model had two variables x1
and x2 and measurement_noise_std = [1, 10], then independent
gaussian noise with standard deviation 1 and 10 will be added to
x1 and x2 respectively at each point in time.
"""
dim = len(stimulus)
self.stimulus = stimulus
self.sigma = sigma
self.parameters = parameters
self.type_conn = type_conn

# Make independent noise matrix.
self.Sigma = sigma * np.diag(np.ones(dim))

super().__init__(dim, measurement_noise_std)

@copy_doc(StochasticDifferentialEquation.simulate)
def simulate(
self,
initial_condition: np.ndarray,
time_points: np.ndarray,
intervention: Optional[Callable[[np.ndarray, float], np.ndarray]]= None,
rng: np.random.mtrand.RandomState = DEFAULT_RANGE,
dW: Optional[np.ndarray] = None,
) -> np.ndarray:

# Initialize pyclustering model.
self.hhn_model = hhn_network(
self.dim,
self.stimulus,
self.parameters,
CONN_TYPE_MAP[self.type_conn],
ccore=False
)
# Overwrite pyclustering initial noise generation with noise
# controllable via the passed random state.
self.hhn_model._noise = [
rng.random() * 2.0 - 1.0
for i in range(self.hhn_model._num_osc)
]

# Allocate array to hold observed states.
m = len(time_points)
X_do = np.zeros((m, self.dim), dtype=initial_condition.dtype)

# Optionally apply intervention to initial condition
if intervention is not None:
initial_condition = intervention(
initial_condition.copy(),
time_points[0]
)
X_do[0, :] = initial_condition

# Asign initial condition to internal model.
self.hhn_model._membrane_potential = list(initial_condition)

# Compute timestep size.
dt = (time_points[-1] - time_points[0]) / m

if dW is None:
# Generate sequence of weiner increments
dW = rng.normal(0.0, np.sqrt(dt), (m - 1, self.dim))

# Since each neuron has one observed state and three unobserved, we
# create a matrix to house the current state of the model. Additionally
# The HH model contains neurons that are not observed. We allocate space
# for these too.
num_neurons = self.hhn_model._num_osc + len(
self.hhn_model._central_element)
N = np.zeros((num_neurons, 4))

for i, t in zip(range(m - 1), time_points):
# Current state of the model.
x = X_do[i, :]

# Noise differential.
dw = self.noise(x, t) @ dW[i, :]

# Deterministic change in neuron states.
dN = self.drift(N, t) * dt

# Add noise (to membrane potential only).
dN[:self.dim, 0] += dw

# Next state of the model via Euler-Marayama update.
next_N = N + dN

# Optionally apply the intervention (to membrane potential only).
if intervention is not None:
next_N[:self.dim, :] = intervention(
next_N[:self.dim, :], time_points[i + 1])

# Intervene on pyclustering model internal potential
self.hhn_model._membrane_potential = list(next_N[:self.dim, 0])

# Store membrane potential only.
X_do[i + 1, :] = next_N[:self.dim, 0]

# Update internal model neuron states
self.step(next_N, t, dt, rng)

# Update neuron state array.
N = next_N

if self.measurement_noise_std is not None:
# Don't add measurement noise to initial condition
X_do[1:, :] = self.add_measurement_noise(X_do[1:, :], rng)

return X_do


def step(
self, N: np.ndarray, t: float, dt: float, rng: np.random.RandomState):
"""Discrete time dynamics, to be computed after continuous time updates.
Args:
N (np.ndarray): 2D array. Dimensions = (num_neurons x 4). Contains
the current state of the model. Each row represents a neuron
and the columns contain, membrane potential, active sodium
channels, inactive sodium channels and active potassium
channels respectively.
t (float): Current time.
dt (float): Time step size.
rng (np.random.RandomState)
"""
# Adapted from pyclustering.nnet.hhn_network._calculate_states().
num_periph = self.hhn_model._num_osc

# Noise generation. I copied it don't judge me.
self.hhn_model._noise = [
1.0 + 0.01 * (rng.random() * 2.0 - 1.0)
for i in range(num_periph)
]

# Updating states of peripheral neurons
self.hhn_model._hhn_network__update_peripheral_neurons(
t, dt, *N[:num_periph, :].T)

# Updation states of central neurons
self.hhn_model._hhn_network__update_central_neurons(
t, *N[num_periph:, :].T)


def drift(self, N, t):
"""Computes the deterministic derivative of the model."""

num_neurons = self.hhn_model._num_osc + len(
self.hhn_model._central_element)

# We initialize an array of derivatives. The dimensions are
# (num_neurons x 4) because each neuron has four states: membrane
# potential, active sodium channels, inactive sodium channels and
# active potassium channels.
dN = np.zeros((num_neurons, 4))

# Peripheral neuron derivatives.
for i in range(self.hhn_model._num_osc):

# Collect peripheral neuron state into a list.
neuron_state = [
self.hhn_model._membrane_potential[i],
self.hhn_model._active_cond_sodium[i],
self.hhn_model._inactive_cond_sodium[i],
self.hhn_model._active_cond_potassium[i]
]

# Compute the derivative of each state.
dN[i] = self.hhn_model.hnn_state(neuron_state, t, i)

# Central neuron derivatives.
for i in range(len(self.hhn_model._central_element)):

# Collect central neuron state into a list.
central_neuron_state = [
self.hhn_model._central_element[i].membrane_potential,
self.hhn_model._central_element[i].active_cond_sodium,
self.hhn_model._central_element[i].inactive_cond_sodium,
self.hhn_model._central_element[i].active_cond_potassium
]

# Compute the derivative of each state.
dN[self.hhn_model._num_osc + i] = self.hhn_model.hnn_state(
central_neuron_state, t, self.hhn_model._num_osc + i
)

return dN


def noise(self, x: np.ndarray, t: float):
return self.Sigma


class LEGIONPyclustering(DiscreteTimeDynamics):
pass

10 changes: 10 additions & 0 deletions interfere/dynamics/pyclustering_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pyclustering.nnet import conn_type

# Maps string arguments to pyclustering arguments
CONN_TYPE_MAP = {
"all_to_all": conn_type.ALL_TO_ALL,
"grid_four": conn_type.GRID_FOUR,
"grid_eight": conn_type.GRID_EIGHT,
"list_bdir": conn_type.LIST_BIDIR,
"dynamic": conn_type.DYNAMIC
}
7 changes: 7 additions & 0 deletions tests/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,11 @@ def test_kuramoto():
model = interfere.dynamics.KuramotoSakaguchi(omega, K, A, A, sigma)
check_simulate_method(model)
model = interfere.dynamics.StuartLandauKuramoto(omega, rho, K, sigma)
check_simulate_method(model)


def test_hodgkin_huxley():
stimulus = [0, 0, 0, 15, 15, 15, 25, 25, 25, 40]
sigma = 0.1
model = interfere.dynamics.HodgkinHuxleyPyclustering(stimulus, sigma)
check_simulate_method(model)

0 comments on commit 5a31f82

Please sign in to comment.