Skip to content

Commit

Permalink
fix bifurcation analysis bug and #287 (#289)
Browse files Browse the repository at this point in the history
fix bifurcation analysis bug and #287
  • Loading branch information
chaoming0625 authored Nov 8, 2022
2 parents 147d3e8 + 6ff313f commit 1e9c8c2
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 29 deletions.
34 changes: 19 additions & 15 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-

__version__ = "2.2.3.5"

__version__ = "2.2.3.6"

try:
import jaxlib

del jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError(
Expand Down Expand Up @@ -34,21 +34,17 @@
''') from None


# fundamental modules
from . import errors, tools, check, modes


# "base" module
from . import base
from .base.base import Base
from .base.collector import Collector, TensorCollector


# math foundation
from . import math


# toolboxes
from . import (
connect, # synaptic connection
Expand All @@ -61,7 +57,6 @@
algorithms, # online or offline training algorithms
)


# numerical integrators
from . import integrators
from .integrators import ode
Expand All @@ -72,7 +67,6 @@
from .integrators.fde import fdeint
from .integrators.joint_eq import JointEq


# dynamics simulation
from . import dyn
from .dyn import (
Expand All @@ -82,10 +76,10 @@
neurons, # neuron groups
rates, # rate models
synapses, # synaptic dynamics
synouts, # synaptic output
synouts, # synaptic output
synplast, # synaptic plasticity
)
from brainpy.dyn.base import (
from .dyn.base import (
DynamicalSystem,
Container,
Sequential,
Expand All @@ -101,23 +95,33 @@
)
from .dyn.runners import *


# dynamics training
from . import train

from .train import (
DSTrainer,
OnlineTrainer, ForceTrainer,
OfflineTrainer, RidgeTrainer,
BPFF,
BPTT,
OnlineBPTT,
)

# automatic dynamics analysis
from . import analysis

from .analysis import (
DSAnalyzer,
PhasePlane1D, PhasePlane2D,
Bifurcation1D, Bifurcation2D,
FastSlow1D, FastSlow2D,
SlowPointFinder,
)

# running
from . import running


# "visualization" module, will be removed soon
from .visualization import visualize


# convenient access
conn = connect
init = initialize
Expand Down
6 changes: 3 additions & 3 deletions brainpy/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def __init__(
target_pars = dict()
if not isinstance(target_pars, dict):
raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.')
for key in target_pars.keys():
for key, value in target_pars.items():
if key not in self.model.parameters:
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
value = self.target_vars[key]
if value[0] > value[1]:
raise errors.AnalyzerError(f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
raise errors.AnalyzerError(
f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')

self.target_pars = Collector(target_pars)
self.target_par_names = list(self.target_pars.keys()) # list of target_pars
Expand Down
11 changes: 7 additions & 4 deletions brainpy/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
from jax import vmap
import numpy as np
from copy import deepcopy

import brainpy.math as bm
from brainpy import errors
Expand Down Expand Up @@ -79,7 +80,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(self.x_var)
Expand Down Expand Up @@ -107,11 +108,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
xs = points['p0']
ys = points['p1']
zs = points['x']
plot_style.pop('linestyle')
plot_style['s'] = plot_style.pop('markersize', None)
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down Expand Up @@ -299,7 +301,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(var)
Expand Down Expand Up @@ -331,11 +333,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
xs = points['p0']
ys = points['p1']
zs = points[var]
plot_style.pop('linestyle')
plot_style['s'] = plot_style.pop('markersize', None)
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down
5 changes: 3 additions & 2 deletions brainpy/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from jax import vmap

from copy import deepcopy
import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, plotstyle, constants as C, utils
Expand Down Expand Up @@ -107,7 +108,7 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
if with_plot:
for fp_type, points in container.items():
if len(points):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
pyplot.legend()
if show:
Expand Down Expand Up @@ -349,7 +350,7 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
if with_plot:
for fp_type, points in container.items():
if len(points['x']):
plot_style = plotstyle.plot_schema[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
pyplot.legend()
if show:
Expand Down
89 changes: 89 additions & 0 deletions brainpy/analysis/lowdim/tests/test_bifurcation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-


import pytest
pytest.skip('Test cannot pass in github action.', allow_module_level=True)
import unittest

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

block = False


class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
def __init__(self, method='exp_auto'):
super(FitzHughNagumoModel, self).__init__()

# parameters
self.a = 0.7
self.b = 0.8
self.tau = 12.5

# variables
self.V = bm.Variable(bm.zeros(1))
self.w = bm.Variable(bm.zeros(1))
self.Iext = bm.Variable(bm.zeros(1))

# functions
def dV(V, t, w, Iext=0.):
dV = V - V * V * V / 3 - w + Iext
return dV

def dw(w, t, V, a=0.7, b=0.8):
dw = (V + a - b * w) / self.tau
return dw

self.int_V = bp.odeint(dV, method=method)
self.int_w = bp.odeint(dw, method=method)

def update(self, tdi):
t, dt = tdi['t'], tdi['dt']
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
self.Iext[:] = 0.


class TestBifurcation1D(unittest.TestCase):
def test_bifurcation_1d(self):
bp.math.enable_x64()

@bp.odeint
def int_x(x, t, a=1., b=1.):
return bp.math.sin(a * x) + bp.math.cos(b * x)

pp = bp.analysis.PhasePlane1D(
model=int_x,
target_vars={'x': [-bp.math.pi, bp.math.pi]},
resolutions=0.1
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)

bf = bp.analysis.Bifurcation1D(
model=int_x,
target_vars={'x': [-bp.math.pi, bp.math.pi]},
target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]},
resolutions={'a': 0.1, 'b': 0.1}
)
bf.plot_bifurcation(show=False)
plt.show(block=block)
plt.close()
bp.math.disable_x64()

def test_bifurcation_2d(self):
bp.math.enable_x64()

model = FitzHughNagumoModel()
bif = bp.analysis.Bifurcation2D(
model=model,
target_vars={'V': [-3., 3.], 'w': [-1, 3.]},
target_pars={'Iext': [0., 1.]},
resolutions={'Iext': 0.1}
)
bif.plot_bifurcation()
bif.plot_limit_cycle_by_sim()
plt.show(block=block)

# bp.math.disable_x64()
4 changes: 1 addition & 3 deletions brainpy/analysis/lowdim/tests/test_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import unittest

import brainpy as bp
import matplotlib.pyplot as plt

block = False


class TestPhasePlane(unittest.TestCase):
def test_1d(self):
import matplotlib.pyplot as plt
bp.math.enable_x64()

@bp.odeint
Expand All @@ -30,8 +30,6 @@ def int_x(x, t, Iext):
bp.math.disable_x64()

def test_2d_decision_making_model(self):
import matplotlib.pyplot as plt

bp.math.enable_x64()
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
Expand Down
2 changes: 1 addition & 1 deletion brainpy/analysis/plotstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)


_markersize = 20
_markersize = 10

plot_schema = {}

Expand Down
3 changes: 2 additions & 1 deletion brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def register_delay(
elif delay.num_delay_step - 1 < max_delay_step:
self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
else:
self.global_delay_data[identifier] = (None, delay_target)
if identifier not in self.global_delay_data:
self.global_delay_data[identifier] = (None, delay_target)
self.register_implicit_nodes(self.local_delay_vars)
return delay_step

Expand Down
20 changes: 20 additions & 0 deletions brainpy/dyn/tests/test_base_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

import unittest

import brainpy as bp


class TestDynamicalSystem(unittest.TestCase):
def test_delay(self):
A = bp.neurons.LIF(1)
B = bp.neurons.LIF(1)
C = bp.neurons.LIF(1)
A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1)
A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None)
net = bp.Network(A, B, C, A2B, A2C)

runner = bp.DSRunner(net,)
runner.run(10.)


0 comments on commit 1e9c8c2

Please sign in to comment.