JAX implementation of phase response curve
prax is built on top of jax, so please install jax first. See the JAX page for installation instructions.
After installing jax, prax can be installed with pip directly from GitHub using the following command:
pip install git+https://github.com/yonesuke/prax.git
Here is an example of how to use this package with the Van der Pol oscillator.
First, import packages:
import jax.numpy as jnp
from prax import Oscillator
# Enable 64-bit floating point in JAX. The config object is imported from the
# top-level ``jax`` package because the ``jax.config`` submodule import path
# is not available in recent JAX releases.
from jax import config
config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
Create an oscillator class by inheriting from the `Oscillator` class:
class VanderPol(Oscillator):
    """Van der Pol oscillator: x' = y, y' = mu * (1 - x^2) * y - x."""

    def __init__(self, mu, dt=0.01, eps=10**-5):
        """Store ``mu`` and initialise the two-dimensional base oscillator.

        Args:
            mu: Coefficient of the nonlinear damping term in ``forward``.
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=2, dt=dt, eps=eps)
        self.mu = mu

    def forward(self, state):
        """Return the vector field ``[x', y']`` evaluated at ``state = (x, y)``."""
        x, y = state
        vx = y
        vy = self.mu * (1.0 - x * x) * y - x
        return jnp.array([vx, vy])


model = VanderPol(mu=0.2)
Find the periodic orbit (choose `init_val` so that the trajectory starting from it converges to the periodic orbit):
# Initial state handed to the periodic-orbit search.
init_val = jnp.array([0.1, 0.2])
model.find_periodic_orbit(init_val)
# NOTE(review): period, ts and periodic_orbit are presumably populated by
# find_periodic_orbit above; confirm against the prax Oscillator class.
print(model.period) # 6.3088767
plt.plot(model.ts, model.periodic_orbit)
Calculate phase response curve:
# NOTE(review): phase_response_curve is presumably populated by
# calc_phase_response; confirm against the prax Oscillator class.
model.calc_phase_response()
# Plot the phase response curve against the same time grid used for the orbit.
plt.plot(model.ts, model.phase_response_curve)
See examples directory!!
-
Van der Pol equation [code]
class VanderPol(Oscillator):
    """Van der Pol oscillator: x' = y, y' = mu * (1 - x^2) * y - x."""

    def __init__(self, mu, dt=0.01, eps=10**-5):
        """Store ``mu`` and initialise the two-dimensional base oscillator.

        Args:
            mu: Coefficient of the nonlinear damping term in ``forward``.
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=2, dt=dt, eps=eps)
        self.mu = mu

    def forward(self, state):
        """Return the vector field ``[x', y']`` evaluated at ``state = (x, y)``."""
        x, y = state
        vx = y
        vy = self.mu * (1.0 - x * x) * y - x
        return jnp.array([vx, vy])


model = VanderPol(mu=0.2)
-
Stuart Landau equation [code]
class StuartLandau(Oscillator):
    """Stuart-Landau oscillator in Cartesian coordinates.

    x' = x - y - x * (x^2 + y^2)
    y' = x + y - y * (x^2 + y^2)
    """

    def __init__(self, dt=0.01, eps=10**-5):
        """Initialise the two-dimensional base oscillator.

        Args:
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=2, dt=dt, eps=eps)

    def forward(self, state):
        """Return the vector field ``[x', y']`` evaluated at ``state = (x, y)``."""
        x, y = state
        # Squared radius, shared by both components of the cubic term.
        r2 = x * x + y * y
        vx = x - y - x * r2
        vy = x + y - y * r2
        return jnp.array([vx, vy])


model = StuartLandau()
-
FitzHugh-Nagumo equation [code]
class FitzHughNagumo(Oscillator):
    """FitzHugh-Nagumo model.

    x' = c * (x - x^3 - y)
    y' = x - b * y + a
    """

    def __init__(self, params, dt=0.01, eps=10**-5):
        """Unpack the model parameters and initialise the base oscillator.

        Args:
            params: Sequence of exactly three values ``(a, b, c)``, in that
                order, used as the coefficients in ``forward``.
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=2, dt=dt, eps=eps)
        self.a, self.b, self.c = params

    def forward(self, state):
        """Return the vector field ``[x', y']`` evaluated at ``state = (x, y)``."""
        x, y = state
        vx = self.c * (x - x ** 3 - y)
        vy = x - self.b * y + self.a
        return jnp.array([vx, vy])


model = FitzHughNagumo(params=[0.2, 0.5, 10.0])
-
Brusselator equation [code]
class Brusselator(Oscillator):
    """Brusselator model.

    x' = a - (b + 1) * x + x^2 * y
    y' = b * x - x^2 * y
    """

    def __init__(self, params, dt=0.01, eps=10**-5):
        """Unpack the model parameters and initialise the base oscillator.

        Args:
            params: Sequence of exactly two values ``(a, b)``, in that order,
                used as the coefficients in ``forward``.
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=2, dt=dt, eps=eps)
        self.a, self.b = params

    def forward(self, state):
        """Return the vector field ``[x', y']`` evaluated at ``state = (x, y)``."""
        x, y = state
        # The cubic term enters x' and y' with opposite signs.
        xxy = x * x * y
        vx = self.a - (self.b + 1.0) * x + xxy
        vy = self.b * x - xxy
        return jnp.array([vx, vy])


model = Brusselator(params=[1.0, 3.0])
-
Hodgkin Huxley equation [code]
class HodgkinHuxley(Oscillator):
    """Hodgkin-Huxley model with state ``(V, m, h, n)``.

    ``forward`` sums a sodium term ``G_Na * m^3 * h * (E_Na - V)``, a potassium
    term ``G_K * n^4 * (E_K - V)``, a leak term ``G_L * (E_L - V)`` and the
    constant ``input_current``, divided by ``C``; each gate variable ``g``
    follows ``g' = alpha_g(V) * (1 - g) - beta_g(V) * g``.
    """

    def __init__(self, input_current, C=1.0, G_Na=120.0, G_K=36.0, G_L=0.3,
                 E_Na=50.0, E_K=-77.0, E_L=-54.4, dt=0.01, eps=10**-5):
        """Store the model constants and initialise the 4-D base oscillator.

        Args:
            input_current: Constant term added to the summed currents.
            C: Divisor applied to the summed currents (presumably the
                membrane capacitance; confirm units).
            G_Na, G_K, G_L: Coefficients of the sodium, potassium and leak
                terms (presumably maximal conductances; confirm units).
            E_Na, E_K, E_L: Voltage offsets of the sodium, potassium and leak
                terms (presumably reversal potentials; confirm units).
            dt: Forwarded unchanged to ``Oscillator.__init__``
                (presumably the integration time step; confirm in prax).
            eps: Forwarded unchanged to ``Oscillator.__init__``
                (presumably a convergence tolerance; confirm in prax).
        """
        super().__init__(n_dim=4, dt=dt, eps=eps)
        self.input_current = input_current
        self.C = C
        self.G_Na = G_Na
        self.G_K = G_K
        self.G_L = G_L
        self.E_Na = E_Na
        self.E_K = E_K
        self.E_L = E_L

    def alpha_m(self, V):
        """Rate multiplying ``(1 - m)`` in the ``m`` gate equation."""
        # NOTE: numerator and denominator are both zero at V == -40.0, so the
        # expression evaluates to 0/0 exactly at that voltage.
        return 0.1 * (V + 40.0) / (1.0 - jnp.exp(-(V + 40.0) / 10.0))

    def beta_m(self, V):
        """Rate multiplying ``m`` in the ``m`` gate equation."""
        return 4.0 * jnp.exp(-(V + 65.0) / 18.0)

    def alpha_h(self, V):
        """Rate multiplying ``(1 - h)`` in the ``h`` gate equation."""
        return 0.07 * jnp.exp(-(V + 65.0) / 20.0)

    def beta_h(self, V):
        """Rate multiplying ``h`` in the ``h`` gate equation."""
        return 1.0 / (1.0 + jnp.exp(-(V + 35.0) / 10.0))

    def alpha_n(self, V):
        """Rate multiplying ``(1 - n)`` in the ``n`` gate equation."""
        # NOTE: numerator and denominator are both zero at V == -55.0, so the
        # expression evaluates to 0/0 exactly at that voltage.
        return 0.01 * (V + 55.0) / (1.0 - jnp.exp(-(V + 55.0) / 10.0))

    def beta_n(self, V):
        """Rate multiplying ``n`` in the ``n`` gate equation."""
        return 0.125 * jnp.exp(-(V + 65.0) / 80.0)

    def forward(self, state):
        """Return ``[V', m', h', n']`` evaluated at ``state = (V, m, h, n)``."""
        V, m, h, n = state
        # Individual current terms, summed in the same order as written in the
        # class docstring.
        i_na = self.G_Na * (m ** 3) * h * (self.E_Na - V)
        i_k = self.G_K * (n ** 4) * (self.E_K - V)
        i_l = self.G_L * (self.E_L - V)
        dVdt = i_na + i_k + i_l + self.input_current
        dVdt /= self.C
        dmdt = self.alpha_m(V) * (1.0 - m) - self.beta_m(V) * m
        dhdt = self.alpha_h(V) * (1.0 - h) - self.beta_h(V) * h
        dndt = self.alpha_n(V) * (1.0 - n) - self.beta_n(V) * n
        return jnp.array([dVdt, dmdt, dhdt, dndt])


model = HodgkinHuxley(input_current=30.0)