diff --git a/README.md b/README.md index f0c18b336..9123362dc 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu - **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming. - **[brainpylib](https://github.com/brainpy/brainpylib)**: Efficient operators for the sparse and event-driven computation. -- **[BrainPyExamples](https://github.com/brainpy/BrainPyExamples)**: Comprehensive examples of BrainPy computation. -- **[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale)**: One solution for the large-scale brain modeling. +- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation. +- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling. diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 67073bee0..2b3c621dc 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -58,15 +58,16 @@ synapses, # synaptic dynamics synouts, # synaptic output synplast, # synaptic plasticity - syn, + experimental, ) -from brainpy._src.dyn.base import not_pass_sha +from brainpy._src.dyn.base import not_pass_shared from brainpy._src.dyn.base import (DynamicalSystem, DynamicalSystemNS, Container as Container, Sequential as Sequential, Network as Network, NeuGroup as NeuGroup, + NeuGroupNS as NeuGroupNS, SynConn as SynConn, SynOut as SynOut, SynSTP as SynSTP, @@ -77,8 +78,7 @@ from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations LoopOverTime as LoopOverTime,) from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner -from brainpy._src.dyn.context import share -from brainpy._src.dyn.delay import Delay +from brainpy._src.dyn.context import share, Delay # Part 4: Training # diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py index 74388c04e..0a1a48e82 100644 --- a/brainpy/_src/analysis/highdim/slow_points.py +++ b/brainpy/_src/analysis/highdim/slow_points.py @@ -349,6 +349,7 @@ def f_loss(): def train(idx): gradients, loss = grad_f() optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients}) + optimizer.lr.step_epoch() return loss def batch_train(start_i, n_batch): diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index 078a0aba6..fb76ce000 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -50,7 +50,7 @@ SLICE_VARS = 'slice_vars' -def not_pass_sha(func: Callable): +def not_pass_shared(func: Callable): """Label the update function as the one without passing shared arguments. The original update function explicitly requires shared arguments at the first place:: @@ -610,7 +610,8 @@ def __repr__(self): entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules)) return f'{self.__class__.__name__}(\n{entries}\n)' - def update(self, s, x) -> ArrayType: + @not_pass_shared + def update(self, x) -> ArrayType: """Update function of a sequential model. Parameters @@ -626,12 +627,7 @@ def update(self, s, x) -> ArrayType: The output tensor. """ for m in self._modules: - if isinstance(m, DynamicalSystemNS): - x = m(x) - elif isinstance(m, DynamicalSystem): - x = m(s, x) - else: - x = m(x) + x = m(x) return x @@ -665,7 +661,7 @@ def __init__( mode=mode, **ds_dict) - @not_pass_sha + @not_pass_shared def update(self, *args, **kwargs): """Step function of a network. @@ -807,6 +803,11 @@ def __getitem__(self, item): return NeuGroupView(target=self, index=item) +class NeuGroupNS(NeuGroup): + """Base class for neuron group without shared arguments passed.""" + pass_shared = False + + class SynConn(DynamicalSystem): """Base class to model two-end synaptic connections. diff --git a/brainpy/_src/dyn/context.py b/brainpy/_src/dyn/context.py index 9293add7d..8773df2de 100644 --- a/brainpy/_src/dyn/context.py +++ b/brainpy/_src/dyn/context.py @@ -4,19 +4,305 @@ This context defines all shared data used in all modules in a computation. """ -from typing import Dict, Any, Union +from typing import Any +from typing import Union, Callable, Optional, Dict -from brainpy._src.tools.dicts import DotDict -from brainpy._src.dyn.delay import Delay +import jax +import jax.numpy as jnp +import numpy as np +from brainpy import check +from brainpy import math as bm +from brainpy._src.dyn.base import DynamicalSystemNS +from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE from brainpy._src.math.environment import get_dt -from brainpy._src.math.object_transform.base import BrainPyObject, dyn_dict +from brainpy._src.math.object_transform.base import dyn_dict +from brainpy._src.tools.dicts import DotDict +from brainpy.check import is_integer, jit_error_checking +from jax.lax import stop_gradient __all__ = [ + 'Delay', 'share', ] -class _ShareContext(BrainPyObject): +class Delay(DynamicalSystemNS): + """Delay variable which has a fixed delay length. + + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + Parameters + ---------- + target: Variable + The initial delay data. + length: int + The delay data length. + before_t0: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + method: str + The method used for updating delay. + + """ + + data: Optional[bm.Variable] + length: int + + def __init__( + self, + target: bm.Variable, + length: int = 0, + before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, + entries: Optional[Dict] = None, + name: str = None, + method: str = ROTATE_UPDATE, + ): + + super().__init__(name=name) + if method is None: + if self.mode.is_a(bm.NonBatchingMode): + method = ROTATE_UPDATE + elif self.mode.is_parent_of(bm.TrainingMode): + method = CONCAT_UPDATE + else: + method = ROTATE_UPDATE + assert method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.method = method + + # target + self.target = target + if not isinstance(target, bm.Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') + + # delay length + self.length = is_integer(length, allow_none=False, min_bound=0) + + # delay data + if before_t0 is not None: + assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable)) + self._before_t0 = before_t0 + if length > 0: + self._init_data(length) + else: + self.data = None + + # other info + self._access_to_step = dict() + for entry, value in entries.items(): + self.register_entry(entry, value) + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[Union[int, bm.Array, Callable]] = None, + ) -> 'Delay': + """Register an entry to access the data. + + Args: + entry (str): The entry to access the delay data. + delay_step: The delay step of the entry (must be an integer, denoting the delay step). + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + if entry in self._access_to_step: + raise KeyError(f'Entry {entry} has been registered.') + + if delay_time is not None: + if delay_step is not None: + raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') + if callable(delay_time): + delay_time = bm.as_jax(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) + elif isinstance(delay_time, float): + delay_step = int(delay_time / bm.get_dt()) + else: + delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) + + # delay steps + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): + if delay_step.size == 1 and delay_step.ndim == 0: + delay_type = 'homo' + else: + delay_type = 'heter' + delay_step = bm.Array(delay_step) + elif callable(delay_step): + delay_step = delay_step(self.delay_target_shape) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [jnp.int32, jnp.int64]: + raise ValueError('Only support delay steps of int32, int64. If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if self.delay_target_shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') + if delay_type == 'heter': + max_delay_step = int(max(delay_step)) + elif delay_type == 'homo': + max_delay_step = delay_step + else: + max_delay_step = None + + # delay variable + if max_delay_step is not None: + if self.length < max_delay_step: + self._init_data(max_delay_step) + self.length = max_delay_step + self._access_to_step[entry] = delay_step + return self + + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry (str): The entry to access the data. + *indices: + + Returns: + The data. + """ + assert isinstance(entry, str) + if entry not in self._access_to_step: + raise KeyError(f'Does not find delay entry "{entry}".') + delay_step = self._access_to_step[entry] + if delay_step is None: + return self.target.value + else: + if self.data is None: + return self.target.value + else: + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.target.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.target.shape + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.length}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.method})') + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.length}. ' + f'But we got {delay_len}') + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. + + Parameters + ---------- + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + assert delay_step is not None + if check.is_checking(): + jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) + + if self.method == ROTATE_UPDATE: + i = share.load('i') + delay_idx = (i + delay_step) % (self.length + 1) + delay_idx = stop_gradient(delay_idx) + + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step + + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: + """Update delay variable with the new data. + """ + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + i = share.load('i') + idx = bm.as_jax((i - 1) % (self.length + 1)) + self.data[idx] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + # initialize delay data + if self.data is not None: + self._init_data(self.length, batch_size) + + def _init_data(self, length, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. ') + + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 + self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), + batch_axis=batch_axis) + # update delay data + self.data[0] = self.target.value + if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)): + self.data[1:] = self._before_t0 + elif callable(self._before_t0): + self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype) + + +class _ShareContext(DynamicalSystemNS): def __init__(self): super().__init__() diff --git a/brainpy/_src/dyn/delay.py b/brainpy/_src/dyn/delay.py deleted file mode 100644 index 35c6d33cc..000000000 --- a/brainpy/_src/dyn/delay.py +++ /dev/null @@ -1,297 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Callable, Optional, Dict - -import jax -import jax.numpy as jnp -import numpy as np -from jax.lax import stop_gradient - -from brainpy import check -from brainpy import math as bm -from brainpy._src.dyn.base import DynamicalSystemNS -from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE -from brainpy.check import is_integer, jit_error_checking - - -class Delay(DynamicalSystemNS): - """Delay variable which has a fixed delay length. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Parameters - ---------- - target: Variable - The initial delay data. - length: int - The delay data length. - before_t0: Any - The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - method: str - The method used for updating delay. - - """ - - data: Optional[bm.Variable] - idx: Optional[bm.Variable] - length: int - - def __init__( - self, - target: bm.Variable, - length: int = 0, - before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, - entries: Optional[Dict] = None, - name: str = None, - method: str = ROTATE_UPDATE, - ): - - super().__init__(name=name) - if method is None: - if self.mode.is_a(bm.NonBatchingMode): - method = ROTATE_UPDATE - elif self.mode.is_parent_of(bm.TrainingMode): - method = CONCAT_UPDATE - else: - method = ROTATE_UPDATE - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.method = method - - # target - self.target = target - if not isinstance(target, bm.Variable): - raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') - - # delay length - self.length = is_integer(length, allow_none=False, min_bound=0) - - # delay data - if before_t0 is not None: - assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable)) - self._before_t0 = before_t0 - if length > 0: - self._init_data(length) - else: - self.data = None - - # time variables - if self.method == ROTATE_UPDATE: - self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) - - # other info - self._access_to_step = dict() - for entry, value in entries.items(): - self.register_entry(entry, value) - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]] = None, - delay_step: Optional[Union[int, bm.Array, Callable]] = None, - ) -> 'Delay': - """Register an entry to access the data. - - Args: - entry (str): The entry to access the delay data. - delay_step: The delay step of the entry (must be an integer, denoting the delay step). - delay_time: The delay time of the entry (can be a float). - - Returns: - Return the self. - """ - if entry in self._access_to_step: - raise KeyError(f'Entry {entry} has been registered.') - - if delay_time is not None: - if delay_step is not None: - raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') - if callable(delay_time): - delay_time = bm.as_jax(delay_time(self.delay_target_shape)) - delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) - elif isinstance(delay_time, float): - delay_step = int(delay_time / bm.get_dt()) - else: - delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) - - # delay steps - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): - if delay_step.size == 1 and delay_step.ndim == 0: - delay_type = 'homo' - else: - delay_type = 'heter' - delay_step = bm.Array(delay_step) - elif callable(delay_step): - delay_step = delay_step(self.delay_target_shape) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [jnp.int32, jnp.int64]: - raise ValueError('Only support delay steps of int32, int64. If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if self.delay_target_shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') - if delay_type == 'heter': - max_delay_step = int(max(delay_step)) - elif delay_type == 'homo': - max_delay_step = delay_step - else: - max_delay_step = None - - # delay variable - if max_delay_step is not None: - if self.length < max_delay_step: - self._init_data(max_delay_step) - self.length = max_delay_step - self._access_to_step[entry] = delay_step - return self - - def at(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. - - Args: - entry (str): The entry to access the data. - *indices: - - Returns: - The data. - """ - assert isinstance(entry, str) - if entry not in self._access_to_step: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._access_to_step[entry] - if delay_step is None: - return self.target.value - else: - if self.data is None: - return self.target.value - else: - if isinstance(delay_step, slice): - return self.retrieve(delay_step, *indices) - elif np.ndim(delay_step) == 0: - return self.retrieve(delay_step, *indices) - else: - if len(indices) == 0 and len(delay_step) == self.target.shape[0]: - indices = (jnp.arange(delay_step.size),) - return self.retrieve(delay_step, *indices) - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.target.shape - - def __repr__(self): - name = self.__class__.__name__ - return (f'{name}(num_delay_step={self.length}, ' - f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.method})') - - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.length}. ' - f'But we got {delay_len}') - - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. - - Parameters - ---------- - delay_step: int, ArrayType - The delay length used to retrieve the data. - """ - assert delay_step is not None - if check.is_checking(): - jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) - - if self.method == ROTATE_UPDATE: - delay_idx = (self.idx.value + delay_step) % (self.length + 1) - delay_idx = stop_gradient(delay_idx) - - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - - # the delay data - return self.data[indices] - - def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: - """Update delay variable with the new data. - """ - if self.data is not None: - # get the latest target value - if latest_value is None: - latest_value = self.target.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) - self.data[self.idx.value] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.length >= 2: - self.data.value = bm.vstack([latest_value, self.data[1:]]) - else: - self.data[0] = latest_value - - def reset_state(self, batch_size: int = None): - """Reset the delay data. - """ - # initialize delay data - if self.data is not None: - self._init_data(self.length, batch_size) - - # time variables - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) - - def _init_data(self, length, batch_size: int = None): - if batch_size is not None: - if self.target.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.target.batch_size} != {batch_size}). ' - 'Please reset the target variable first, because delay data ' - 'depends on the target variable. ') - - if self.target.batch_axis is None: - batch_axis = None - else: - batch_axis = self.target.batch_axis + 1 - self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), - batch_axis=batch_axis) - # update delay data - self.data[0] = self.target.value - if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)): - self.data[1:] = self._before_t0 - elif callable(self._before_t0): - self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype) diff --git a/brainpy/_src/dyn/layers/conv.py b/brainpy/_src/dyn/layers/conv.py index 5fbf393fb..67e00b056 100644 --- a/brainpy/_src/dyn/layers/conv.py +++ b/brainpy/_src/dyn/layers/conv.py @@ -5,7 +5,7 @@ from jax import lax from brainpy import math as bm, tools, check -from brainpy._src.dyn.base import not_pass_sha +from brainpy._src.dyn.base import not_pass_shared from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter from brainpy.types import ArrayType from .base import Layer @@ -154,7 +154,7 @@ def _check_input_dim(self, x): raise ValueError(f"input channels={x.shape[-1]} needs to have " f"the same size as in_channels={self.in_channels}.") - @not_pass_sha + @not_pass_shared def update(self, x): self._check_input_dim(x) w = self.w.value @@ -526,7 +526,7 @@ def __init__( def _check_input_dim(self, x): raise NotImplementedError - @not_pass_sha + @not_pass_shared def update(self, x): self._check_input_dim(x) diff --git a/brainpy/_src/dyn/layers/normalization.py b/brainpy/_src/dyn/layers/normalization.py index 6751e2bbe..79811ff93 100644 --- a/brainpy/_src/dyn/layers/normalization.py +++ b/brainpy/_src/dyn/layers/normalization.py @@ -135,7 +135,8 @@ def update(self, x): if self.axis_name is not None: mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]), axis_name=self.axis_name, - axis_index_groups=self.axis_index_groups), 2) + axis_index_groups=self.axis_index_groups), + 2) var = jnp.maximum(0., mean_of_square - _square(mean)) self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) diff --git a/brainpy/_src/dyn/layers/nvar.py b/brainpy/_src/dyn/layers/nvar.py index 43fc5c66f..cf9b55b9f 100644 --- a/brainpy/_src/dyn/layers/nvar.py +++ b/brainpy/_src/dyn/layers/nvar.py @@ -9,7 +9,7 @@ import brainpy.math as bm from brainpy import check from .base import Layer -from brainpy._src.dyn.base import not_pass_sha +from brainpy._src.dyn.base import not_pass_shared __all__ = [ 'NVAR' diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/dyn/neurons/biological_models.py index 6238557a6..fefe9253d 100644 --- a/brainpy/_src/dyn/neurons/biological_models.py +++ b/brainpy/_src/dyn/neurons/biological_models.py @@ -4,8 +4,8 @@ import brainpy.math as bm from brainpy import check +from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.dyn.context import share -from brainpy._src.dyn.base import NeuGroup, not_pass_sha from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_ from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint @@ -20,7 +20,7 @@ ] -class HH(NeuGroup): +class HH(NeuGroupNS): r"""Hodgkin–Huxley neuron model. **Model Descriptions** @@ -211,6 +211,7 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, + input_var: bool = True, # training parameter mode: bm.Mode = None, @@ -234,6 +235,7 @@ def __init__( self.C = parameter(C, self.varshape, allow_none=False) self.V_th = parameter(V_th, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=4) + self.input_var = input_var # initializers check.is_initializer(m_initializer, 'm_initializer', allow_none=True) @@ -286,36 +288,44 @@ def reset_state(self, batch_size=None): self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis) else: self.n = variable_(self._n_initializer, self.varshape, batch_size) - self.input = variable_(bm.zeros, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - def dV(self, V, t, m, h, n): + def dV(self, V, t, m, h, n, I): I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) I_K = (self.gK * n ** 4.0) * (V - self.EK) I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + self.input) / self.C + dVdt = (- I_Na - I_K - I_leak + I) / self.C return dVdt @property def derivative(self): return JointEq(self.dV, self.dm, self.dh, self.dn) - @not_pass_sha def update(self, x=None): - s = share.get_shargs() - if x is not None: self.input += x - V, m, h, n = self.integral(self.V, self.m, self.h, self.n, s['t'], s['dt']) + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, x, dt) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.m.value = m self.h.value = h self.n.value = n + return self.spike.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. -class MorrisLecar(NeuGroup): +class MorrisLecar(NeuGroupNS): r"""The Morris-Lecar neuron model. **Model Descriptions** @@ -411,6 +421,7 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, + input_var: bool = True, # training parameter mode: bm.Mode = None, @@ -437,12 +448,11 @@ def __init__( self.phi = parameter(phi, self.varshape, allow_none=False) self.V_th = parameter(V_th, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=2) + self.input_var = input_var # initializers - check.is_initializer(V_initializer, 'V_initializer', allow_none=False) - check.is_initializer(W_initializer, 'W_initializer', allow_none=False) - self._W_initializer = W_initializer - self._V_initializer = V_initializer + self._W_initializer = check.is_initializer(V_initializer, allow_none=False) + self._V_initializer = check.is_initializer(W_initializer, allow_none=False) # variables self.reset_state(self.mode) @@ -456,8 +466,9 @@ def __init__( def reset_state(self, batch_size=None): self.W = variable_(self._W_initializer, self.varshape, batch_size) self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.input = variable_(bm.zeros, self.varshape, batch_size) self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) def dV(self, V, t, W, I_ext): M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) @@ -477,19 +488,28 @@ def dW(self, W, t, V): def derivative(self): return JointEq([self.dV, self.dW]) - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] - if x is not None: self.input += x - V, self.W.value = self.integral(self.V, self.W, t, self.input, dt) + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, W = self.integral(self.V, self.W, t, x, dt) spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V + self.W.value = W self.spike.value = spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. -class PinskyRinzelModel(NeuGroup): +class PinskyRinzelModel(NeuGroupNS): r"""The Pinsky and Rinsel (1994) model. The Pinsky and Rinsel (1994) model [7]_ is a 2-compartment (soma and dendrite), @@ -736,8 +756,9 @@ def reset_state(self, batch_size=None): self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis) self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis) - self.Id.value = variable_(bm.zeros, self.varshape, batch_size) - self.Is.value = variable_(bm.zeros, self.varshape, batch_size) + if self.input_var: + self.Id.value = variable_(bm.zeros, self.varshape, batch_size) + self.Is.value = variable_(bm.zeros, self.varshape, batch_size) # self.spike[:] = False def dCa(self, Ca, t, s, Vd): @@ -876,7 +897,7 @@ def inf_q(self, Ca): return alpha / (alpha + beta) -class WangBuzsakiModel(NeuGroup): +class WangBuzsakiModel(NeuGroupNS): r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. Each model is described by a single compartment and obeys the current balance equation: @@ -979,12 +1000,13 @@ def __init__( n_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.32), noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', + input_var: bool = True, name: str = None, mode: bm.Mode = None, ): # initialization super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode) - check.is_subclass(self.mode, (bm.BatchingMode, bm.NonBatchingMode), self.__class__) + check.is_subclass(self.mode, (bm.BatchingMode, bm.NonBatchingMode)) # parameters self.ENa = parameter(ENa, self.varshape, allow_none=False) @@ -997,6 +1019,7 @@ def __init__( self.phi = parameter(phi, self.varshape, allow_none=False) self.V_th = parameter(V_th, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=3) + self.input_var = input_var # initializers check.is_initializer(h_initializer, 'h_initializer', allow_none=False) @@ -1010,8 +1033,9 @@ def __init__( self.h = variable_(self._h_initializer, self.varshape, self.mode) self.n = variable_(self._n_initializer, self.varshape, self.mode) self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, self.mode) # integral if self.noise is None: @@ -1023,8 +1047,9 @@ def reset_state(self, batch_size=None): self.h.value = variable_(self._h_initializer, self.varshape, batch_size) self.n.value = variable_(self._n_initializer, self.varshape, batch_size) self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + if self.input_var: + self.input.value = variable_(bm.zeros, self.varshape, batch_size) def m_inf(self, V): alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) @@ -1052,16 +1077,24 @@ def dV(self, V, t, h, n, I_ext): @property def derivative(self): - return JointEq([self.dV, self.dh, self.dn]) + return JointEq(self.dV, self.dh, self.dn) - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] - if x is not None: self.input += x - V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt) + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, h, n = self.integral(self.V, self.h, self.n, t, x, dt) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.h.value = h self.n.value = n + return self.spike.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. diff --git a/brainpy/_src/dyn/neurons/fractional_models.py b/brainpy/_src/dyn/neurons/fractional_models.py index b2b7908af..4c9eae00f 100644 --- a/brainpy/_src/dyn/neurons/fractional_models.py +++ b/brainpy/_src/dyn/neurons/fractional_models.py @@ -4,7 +4,8 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroupNS +from brainpy._src.dyn.context import share from brainpy._src.initialize import ZeroInit, OneInit, Initializer, parameter from brainpy._src.integrators.fde import CaputoL1Schema from brainpy._src.integrators.fde import GLShortMemory @@ -19,7 +20,7 @@ ] -class FractionalNeuron(NeuGroup): +class FractionalNeuron(NeuGroupNS): """Fractional-order neuron model.""" pass @@ -93,6 +94,7 @@ def __init__( V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.5), w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), y_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + input_var: bool = True, name: str = None, keep_size: bool = False, ): @@ -110,6 +112,7 @@ def __init__( self.mu = parameter(mu, self.varshape, allow_none=False) self.Vth = parameter(Vth, self.varshape, allow_none=False) self.delta = parameter(delta, self.varshape, allow_none=False) + self.input_var = input_var # initializers is_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -123,8 +126,9 @@ def __init__( self.V = bm.Variable(parameter(V_initializer, self.varshape)) self.w = bm.Variable(parameter(w_initializer, self.varshape)) self.y = bm.Variable(parameter(y_initializer, self.varshape)) - self.input = bm.Variable(jnp.zeros(self.varshape)) self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) + if self.input_var: + self.input = bm.Variable(jnp.zeros(self.varshape)) # integral function self.integral = GLShortMemory(self.derivative, @@ -137,13 +141,14 @@ def reset_state(self, batch_size=None): self.V.value = parameter(self._V_initializer, self.varshape) self.w.value = parameter(self._w_initializer, self.varshape) self.y.value = parameter(self._y_initializer, self.varshape) - self.input[:] = 0 self.spike[:] = False + if self.input_var: + self.input[:] = 0 # integral function reset self.integral.reset([self.V, self.w, self.y]) - def dV(self, V, t, w, y): - return V - V ** 3 / 3 - w + y + self.input + def dV(self, V, t, w, y, I): + return V - V ** 3 / 3 - w + y + I def dw(self, w, t, V): return self.delta * (self.a + V - self.b * w) @@ -156,16 +161,24 @@ def derivative(self): return JointEq([self.dV, self.dw, self.dy]) def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] - if x is not None: self.input += x - V, w, y = self.integral(self.V, self.w, self.y, t, dt) + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w, y = self.integral(self.V, self.w, self.y, t, I=x, dt=dt) self.spike.value = jnp.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V self.w.value = w self.y.value = y + return self.spike.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class FractionalIzhikevich(FractionalNeuron): @@ -240,6 +253,7 @@ def __init__( V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.), u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.20 * -65.), keep_size: bool = False, + input_var: bool = True, name: str = None ): # initialization @@ -258,6 +272,7 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.R = parameter(R, self.varshape, allow_none=False) self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.input_var = input_var # initializers is_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -268,8 +283,9 @@ def __init__( # variables self.V = bm.Variable(parameter(V_initializer, self.varshape)) self.u = bm.Variable(parameter(u_initializer, self.varshape)) - self.input = bm.Variable(jnp.zeros(self.varshape)) self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) + if self.input_var: + self.input = bm.Variable(jnp.zeros(self.varshape)) # functions is_integer(num_memory, 'num_step', allow_none=False) @@ -281,8 +297,9 @@ def __init__( def reset_state(self, batch_size=None): self.V.value = parameter(self._V_initializer, self.varshape) self.u.value = parameter(self._u_initializer, self.varshape) - self.input[:] = 0 self.spike[:] = False + if self.input_var: + self.input[:] = 0 # integral function reset self.integral.reset([self.V, self.u]) @@ -296,16 +313,24 @@ def du(self, u, t, V): @property def derivative(self): - return JointEq([self.dV, self.du]) - - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] - if x is not None: self.input += x - V, u = self.integral(self.V, self.u, t=t, I_ext=self.input, dt=dt) + return JointEq(self.dV, self.du) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, u = self.integral(self.V, self.u, t=t, I_ext=x, dt=dt) spikes = V >= self.V_th self.V.value = jnp.where(spikes, self.c, V) self.u.value = jnp.where(spikes, u + self.d, u) self.spike.value = spikes + return spikes def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. diff --git a/brainpy/_src/dyn/neurons/input_groups.py b/brainpy/_src/dyn/neurons/input_groups.py index e0532f208..833d2eb9f 100644 --- a/brainpy/_src/dyn/neurons/input_groups.py +++ b/brainpy/_src/dyn/neurons/input_groups.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from brainpy._src.dyn.context import share import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup, not_pass_sha +from brainpy._src.dyn.base import NeuGroupNS, not_pass_shared from brainpy._src.initialize import Initializer, parameter, variable_ from brainpy.types import Shape, ArrayType @@ -17,7 +17,7 @@ ] -class InputGroup(NeuGroup): +class InputGroup(NeuGroupNS): """Input neuron group for place holder. Parameters @@ -41,7 +41,6 @@ def __init__( mode=mode) self.spike = None - @not_pass_sha def update(self, x): return x @@ -49,7 +48,7 @@ def reset_state(self, batch_size=None): pass -class OutputGroup(NeuGroup): +class OutputGroup(NeuGroupNS): """Output neuron group for place holder. Parameters @@ -73,7 +72,6 @@ def __init__( mode=mode) self.spike = None - @not_pass_sha def update(self, x): return x @@ -81,7 +79,7 @@ def reset_state(self, batch_size=None): pass -class SpikeTimeGroup(NeuGroup): +class SpikeTimeGroup(NeuGroupNS): """The input neuron group characterized by spikes emitting at given times. >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. @@ -162,14 +160,13 @@ def reset_state(self, batch_size=None): self.i = bm.Variable(bm.asarray(0)) self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) - @not_pass_sha def update(self): self.spike.value = bm.zeros_like(self.spike) self._run(share.load('t')) return self.spike.value -class PoissonGroup(NeuGroup): +class PoissonGroup(NeuGroupNS): """Poisson Neuron Group. """ @@ -196,8 +193,7 @@ def __init__( self.rng = bm.random.default_rng(seed) self.reset_state(self.mode) - @not_pass_sha - def update(self, x=None): + def update(self): spikes = self.rng.rand_like(self.spike) <= (self.freqs * share.dt / 1000.) self.spike.value = spikes return spikes @@ -208,3 +204,4 @@ def reset(self, batch_size=None): def reset_state(self, batch_size=None): self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + diff --git a/brainpy/_src/dyn/neurons/noise_groups.py b/brainpy/_src/dyn/neurons/noise_groups.py index c6c9749f8..3c6c14f40 100644 --- a/brainpy/_src/dyn/neurons/noise_groups.py +++ b/brainpy/_src/dyn/neurons/noise_groups.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from brainpy._src.dyn.context import share from brainpy import math as bm, initialize as init -from brainpy._src.dyn.base import NeuGroup, not_pass_sha +from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared from brainpy._src.initialize import Initializer from brainpy._src.integrators.sde.generic import sdeint from brainpy.types import ArrayType, Shape @@ -77,7 +77,6 @@ def df(self, x, t): def dg(self, x, t): return self.sigma - @not_pass_sha def update(self): t = share.load('t') dt = share.load('dt') diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py index 419eb4599..f17a2c30c 100644 --- a/brainpy/_src/dyn/neurons/reduced_models.py +++ b/brainpy/_src/dyn/neurons/reduced_models.py @@ -6,7 +6,7 @@ from jax.lax import stop_gradient import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup, not_pass_sha +from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared from brainpy._src.dyn.context import share from brainpy._src.initialize import (ZeroInit, OneInit, @@ -122,7 +122,7 @@ def reset_state(self, batch_size=None): if self.input_var: self.input = variable_(bm.zeros, self.varshape, batch_size) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -266,7 +266,7 @@ def reset_state(self, batch_size=None): if self.ref_var: self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -491,7 +491,7 @@ def derivative(self, V, t, I_ext): dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau return dvdt - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -678,7 +678,7 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -837,7 +837,7 @@ def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau return dVdt - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1018,7 +1018,7 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1220,7 +1220,7 @@ def dV(self, V, t, I1, I2, I_ext): def derivative(self): return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1371,7 +1371,7 @@ def reset_state(self, batch_size=None): self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1565,7 +1565,7 @@ def du(self, u, t, V): dudt = self.a * (self.b * V - u) return dudt - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1805,7 +1805,7 @@ def dz(self, z, t, V): def derivative(self): return JointEq([self.dV, self.dy, self.dz]) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1978,7 +1978,7 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -2096,7 +2096,7 @@ def reset_state(self, batch_size=None): if self.tau_ref is not None: self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) - @not_pass_sha + @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index fea21c514..148bce097 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -3,7 +3,7 @@ from typing import Union, Callable from brainpy import check, math as bm -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroupNS as NeuGroup from brainpy._src.dyn.neurons.noise_groups import OUProcess from brainpy._src.initialize import Initializer, Uniform, parameter, variable, ZeroInit from brainpy._src.integrators.joint_eq import JointEq diff --git a/brainpy/_src/dyn/synapses_v2/abstract_models.py b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py similarity index 98% rename from brainpy/_src/dyn/synapses_v2/abstract_models.py rename to brainpy/_src/dyn/synapses_v2/abstract_synapses.py index 25f1de478..a28efa1a4 100644 --- a/brainpy/_src/dyn/synapses_v2/abstract_models.py +++ b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py @@ -107,7 +107,7 @@ def reset_state(self, batch_size=None): if self.stp is not None: self.stp.reset_state(batch_size) - def update(self, pre_spike): + def update(self, pre_spike, post_v=None): if self.stp is not None: syn_value = self.stp(pre_spike) * pre_spike else: @@ -150,6 +150,6 @@ def update(self, pre_spike): # outputs if self.out is not None: - return self.out(self.g.value) + return self.out(self.g.value, post_v) else: return self.g.value diff --git a/brainpy/_src/dyn/synapses_v2/base.py b/brainpy/_src/dyn/synapses_v2/base.py index bcced8c0b..cc1a36e8d 100644 --- a/brainpy/_src/dyn/synapses_v2/base.py +++ b/brainpy/_src/dyn/synapses_v2/base.py @@ -119,7 +119,7 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): class SynOut(DynamicalSystemNS): - def update(self, post_g): + def update(self, post_g, post_v): raise NotImplementedError def reset_state(self, batch_size: Optional[int] = None): diff --git a/brainpy/_src/dyn/synapses_v2/syn_outs.py b/brainpy/_src/dyn/synapses_v2/syn_outs.py index 9a783f8a1..435851ede 100644 --- a/brainpy/_src/dyn/synapses_v2/syn_outs.py +++ b/brainpy/_src/dyn/synapses_v2/syn_outs.py @@ -2,14 +2,14 @@ from typing import Union -from brainpy.math import Variable, exp -from brainpy.types import ArrayType from brainpy._src.dyn.synapses_v2.base import SynOut - +from brainpy.math import exp +from brainpy.types import ArrayType __all__ = [ 'COBA', 'CUBA', + 'MgBlock', ] @@ -34,17 +34,12 @@ class COBA(SynOut): CUBA """ - def __init__(self, - post_potential: Variable, - E: Union[float, ArrayType] = 0., - name: str = None, ): + def __init__(self, E: Union[float, ArrayType] = 0., name: str = None, ): super().__init__(name=name) self.E = E - self.post_potential = post_potential - def update(self, g): - I = g * (self.E - self.post_potential) - return I + def update(self, post_g, post_v): + return post_g * (self.E - post_v) class CUBA(SynOut): @@ -70,7 +65,7 @@ class CUBA(SynOut): def __init__(self, name: str = None, ): super().__init__(name=name) - def update(self, g): + def update(self, g, post_V): return g @@ -107,7 +102,6 @@ class MgBlock(SynOut): def __init__( self, - post_potential: Variable, E: Union[float, ArrayType] = 0., cc_Mg: Union[float, ArrayType] = 1.2, alpha: Union[float, ArrayType] = 0.062, @@ -115,14 +109,12 @@ def __init__( name: str = None, ): super().__init__(name=name) - assert isinstance(post_potential, Variable) - self.post_potential = post_potential self.E = E self.cc_Mg = cc_Mg self.alpha = alpha self.beta = beta - def update(self, g): - I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * exp(-self.alpha * self.post_potential)) + def update(self, post_g, post_v): + I = post_g * (self.E - post_v) / (1 + self.cc_Mg / self.beta * exp(-self.alpha * post_v)) return I diff --git a/brainpy/_src/dyn/synouts/conductances.py b/brainpy/_src/dyn/synouts/conductances.py index bf060d291..d498916ab 100644 --- a/brainpy/_src/dyn/synouts/conductances.py +++ b/brainpy/_src/dyn/synouts/conductances.py @@ -106,64 +106,3 @@ def filter(self, g): I = g * (self.E - V) return super(COBA, self).filter(I) - -class eCOBA(SynOut): - r"""Conductance-based synaptic output. - - Given the synaptic conductance, the model output the post-synaptic current with - - .. math:: - - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - - Parameters - ---------- - E: float, ArrayType, ndarray - The reversal potential. - name: str - The model name. - - See Also - -------- - CUBA - """ - - def __init__(self, - post_potential: Variable, - E: Union[float, ArrayType] = 0., - name: str = None, ): - super().__init__(name=name) - self.E = E - self.post_potential = post_potential - - def update(self, g): - I = g * (self.E - self.post_potential) - return I - - -class eCUBA(SynOut): - r"""Current-based synaptic output. - - Given the conductance, this model outputs the post-synaptic current with a identity function: - - .. math:: - - I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - - Parameters - ---------- - name: str - The model name. - - - See Also - -------- - COBA - """ - - def __init__(self, name: str = None, ): - super().__init__(name=name) - - def update(self, g): - return g - diff --git a/brainpy/_src/dyn/synouts/ions.py b/brainpy/_src/dyn/synouts/ions.py index c7b1f7579..020c7cd1e 100644 --- a/brainpy/_src/dyn/synouts/ions.py +++ b/brainpy/_src/dyn/synouts/ions.py @@ -94,55 +94,3 @@ def clone(self): target_var=self._target_var, membrane_var=self._membrane_var) - -class eMgBlock(SynOut): - r"""Synaptic output based on Magnesium blocking. - - Given the synaptic conductance, the model output the post-synaptic current with - - .. math:: - - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) - - where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to - - .. math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - - Parameters - ---------- - E: float, ArrayType - The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType - Binding constant. Default 0.062 - beta: float, ArrayType - Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType - Concentration of Magnesium ion. Default 1.2 [mM]. - name: str - The model name. - """ - - def __init__( - self, - post_potential: bm.Variable, - E: Union[float, ArrayType] = 0., - cc_Mg: Union[float, ArrayType] = 1.2, - alpha: Union[float, ArrayType] = 0.062, - beta: Union[float, ArrayType] = 3.57, - name: str = None, - ): - super().__init__(name=name) - assert isinstance(post_potential, bm.Variable) - self.post_potential = post_potential - self.E = E - self.cc_Mg = cc_Mg - self.alpha = alpha - self.beta = beta - - def update(self, g): - I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post_potential)) - return I diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/dyn/transform.py index 6ca13ff49..0c4f95225 100644 --- a/brainpy/_src/dyn/transform.py +++ b/brainpy/_src/dyn/transform.py @@ -167,8 +167,8 @@ def __init__( assert isinstance(shared_arg, dict) shared_arg['dt'] = dt self.dt = dt - self.t0 = t0 - self.i0 = i0 + self.t0 = None if t0 is None else bm.Variable(bm.as_jax(t0)) + self.i0 = None if i0 is None else bm.Variable(bm.as_jax(i0)) self.jit = jit self.remat = remat @@ -207,9 +207,9 @@ def __call__( if isinstance(duration_or_xs, float): shared = tools.DotDict() if self.t0 is not None: - shared['t'] = jnp.arange(self.t0, duration_or_xs, self.dt) + shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt) if self.i0 is not None: - shared['i'] = jnp.arange(0, shared['t'].shape[0]) + shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0]) xs = None if self.no_state: raise ValueError('Under the `no_state=True` setting, input cannot be a duration.') @@ -269,26 +269,35 @@ def __call__( if self.no_state: share.save(**self.shared_arg) outputs = self._run(self.shared_arg, dict(), xs) - return tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs) + results = tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs) + if self.i0 is not None: + self.i0 += length[0] + if self.t0 is not None: + self.t0 += length[0] * self.dt + return results else: shared = tools.DotDict() - shared['t'] = jnp.arange(self.t0, self.dt * length[0], self.dt) - shared['i'] = jnp.arange(0, length[0]) + shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt) + shared['i'] = jnp.arange(self.i0.value, length[0]) assert not self.no_state - return bm.for_loop(functools.partial(self._run, self.shared_arg), - (shared, xs), - child_objs=(self.target, share), - jit=self.jit, - remat=self.remat) + results = bm.for_loop(functools.partial(self._run, self.shared_arg), + (shared, xs), + child_objs=(self.target, share), + jit=self.jit, + remat=self.remat) + if self.i0 is not None: + self.i0 += length[0] + if self.t0 is not None: + self.t0 += length[0] * self.dt + return results def reset_state(self, batch_size=None): self.target.reset_state(batch_size) def _run(self, static_sh, dyn_sh, x): - share.save(**static_sh) - share.save(**dyn_sh) + share.save(**static_sh, **dyn_sh) outs = self.target(x) if self.out_vars is not None: outs = (outs, tree_map(bm.as_jax, self.out_vars)) diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 5b1ad169f..ca3641dce 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -36,6 +36,14 @@ ] +_register_pytree = True + + +def register_object_as_pytree(mode: bool): + global _register_pytree + _register_pytree = mode + + class BrainPyObject(object): """The BrainPyObject class for whole BrainPy ecosystem. @@ -53,9 +61,11 @@ class BrainPyObject(object): def __init__(self, name=None): super().__init__() - cls = self.__class__ - if cls not in _registry: - register_pytree_node_class(cls) + + if _register_pytree: + cls = self.__class__ + if cls not in _registry: + register_pytree_node_class(cls) # check whether the object has a unique name. self._name = None diff --git a/brainpy/syn.py b/brainpy/experimental.py similarity index 83% rename from brainpy/syn.py rename to brainpy/experimental.py index 8efc9f151..8dab17552 100644 --- a/brainpy/syn.py +++ b/brainpy/experimental.py @@ -12,6 +12,6 @@ CUBA as CUBA, COBA as COBA, ) -from brainpy._src.dyn.synapses_v2.abstract_models import ( +from brainpy._src.dyn.synapses_v2.abstract_synapses import ( Exponential as Exponential, ) diff --git a/brainpy/synouts.py b/brainpy/synouts.py index 3365f1038..5f66035b2 100644 --- a/brainpy/synouts.py +++ b/brainpy/synouts.py @@ -4,14 +4,7 @@ COBA as COBA, CUBA as CUBA, ) -from brainpy._src.dyn.synouts.conductances import ( - eCOBA, - eCUBA, -) from brainpy._src.dyn.synouts.ions import ( MgBlock as MgBlock, ) -from brainpy._src.dyn.synouts.ions import ( - eMgBlock, -) diff --git a/changes.md b/changes.md index d072e7470..891057cd2 100644 --- a/changes.md +++ b/changes.md @@ -1,5 +1,4 @@ -# Change from Version 2.3.1 to Version 2.3.2 - +# Change from Version 2.3.4 to Version 2.3.5 This release (under the branch of ``brainpy=2.3.x``) continues to add supports for brain-inspired computation. @@ -8,23 +7,149 @@ This release (under the branch of ``brainpy=2.3.x``) continues to add supports f ## New Features -### 1. New package structure for stable API release +### 1. ``brainpy.share`` for sharing data across submodules + +In this release, we abstract the shared data as a ``brainpy.share`` object. + +This object together with ``brainpy.Delay`` we will introduce below +constitute the support that enable to define SNN models like ANN ones. + + +### 2. ``brainpy.Delay`` for delay processing + +``Delay`` is abstracted as a dynamical system, which can be updated / retrieved by users. + +```python +import brainpy as bp + +class EINet(bp.DynamicalSystemNS): + def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): + super().__init__() + + self.bg_exc = e_input + self.bg_inh = i_input + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), input_var=False) + self.E = bp.neurons.LIF(num_exc, **pars) + self.I = bp.neurons.LIF(num_inh, **pars) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) + ) + self.E2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) + ) + self.I2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + ) + self.I2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + ) + self.delayE = bp.Delay(self.E.spike, entries={'E': delay}) + self.delayI = bp.Delay(self.I.spike, entries={'I': delay}) + + def update(self): + e_spike = self.delayE.at('E') + i_spike = self.delayI.at('I') + e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc + i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh + self.delayE(self.E(e_inp)) + self.delayI(self.I(i_inp)) + +``` + + + +### 3. ``brainpy.checkpoints.save_pytree`` and ``brainpy.checkpoints.load_pytree`` for saving/loading target from the filename + +Now we can directly use ``brainpy.checkpoints.save_pytree`` to save a +network state into the filepath we specified. + +Similarly, we can use ``brainpy.checkpoints.load_pytree`` to load +states from the given file path. + + +### 4. More ANN layers + + +- brainpy.layers.ConvTranspose1d +- brainpy.layers.ConvTranspose2d +- brainpy.layers.ConvTranspose3d +- brainpy.layers.Conv1dLSTMCell +- brainpy.layers.Conv2dLSTMCell +- brainpy.layers.Conv3dLSTMCell + + +### 5. More compatible dense operators + +PyTorch operators: + +- brainpy.math.Tensor +- brainpy.math.flatten +- brainpy.math.cat +- brainpy.math.abs +- brainpy.math.absolute +- brainpy.math.acos +- brainpy.math.arccos +- brainpy.math.acosh +- brainpy.math.arccosh +- brainpy.math.add +- brainpy.math.addcdiv +- brainpy.math.addcmul +- brainpy.math.angle +- brainpy.math.asin +- brainpy.math.arcsin +- brainpy.math.asinh +- brainpy.math.arcsin +- brainpy.math.atan +- brainpy.math.arctan +- brainpy.math.atan2 +- brainpy.math.atanh -Unstable APIs are all hosted in ``brainpy._src`` module. -Other APIs are stable, and will be maintained in a long time. +TensorFlow operators: -### 2. New schedulers +- brainpy.math.concat +- brainpy.math.reduce_sum +- brainpy.math.reduce_max +- brainpy.math.reduce_min +- brainpy.math.reduce_mean +- brainpy.math.reduce_all +- brainpy.math.reduce_any +- brainpy.math.reduce_logsumexp +- brainpy.math.reduce_prod +- brainpy.math.reduce_std +- brainpy.math.reduce_variance +- brainpy.math.reduce_euclidean_norm +- brainpy.math.unsorted_segment_sqrt_n +- brainpy.math.segment_mean +- brainpy.math.unsorted_segment_sum +- brainpy.math.unsorted_segment_prod +- brainpy.math.unsorted_segment_max +- brainpy.math.unsorted_segment_min +- brainpy.math.unsorted_segment_mean +- brainpy.math.segment_sum +- brainpy.math.segment_prod +- brainpy.math.segment_max +- brainpy.math.segment_min +- brainpy.math.clip_by_value +- brainpy.math.cast -- `brainpy.optim.CosineAnnealingWarmRestarts` -- `brainpy.optim.CosineAnnealingLR` -- `brainpy.optim.ExponentialLR` -- `brainpy.optim.MultiStepLR` -- `brainpy.optim.StepLR` +### Others -### 3. Others +- Remove the hard requirements of ``brainpylib`` and ``numba``. -- support `static_argnums` in `brainpy.math.jit` -- fix bugs of `reset_state()` and `clear_input()` in `brainpy.channels` -- fix jit error checking diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py index 2a0e29b84..181aafb3d 100644 --- a/examples/dynamics_simulation/COBA.py +++ b/examples/dynamics_simulation/COBA.py @@ -1,5 +1,4 @@ import brainpy as bp -import brainpy.connect as C class EINet(bp.DynamicalSystemNS): @@ -22,31 +21,30 @@ def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): # synapses we = 0.6 / scale # excitatory synaptic weight (voltage) wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.E.size, post=self.E.size), - g_max=we, tau=5., out=bp.syn.COBA(self.E.V, E=0.) + self.E2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) ) - self.E2I = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), - g_max=we, tau=5., out=bp.syn.COBA(self.I.V, E=0.) + self.E2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) ) - self.I2E = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.I.size, post=self.E.size), - g_max=wi, tau=10., out=bp.syn.COBA(self.E.V, E=-80.) + self.I2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) ) - self.I2I = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.I.size, post=self.I.size), - g_max=wi, tau=10., out=bp.syn.COBA(self.I.V, E=-80.) + self.I2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) ) self.delayE = bp.Delay(self.E.spike, entries={'E': delay}) self.delayI = bp.Delay(self.I.spike, entries={'I': delay}) - @bp.not_pass_sha def update(self): e_spike = self.delayE.at('E') i_spike = self.delayI.at('I') - e_inp = self.E2E(e_spike) + self.I2E(i_spike) + self.bg_exc - i_inp = self.I2I(i_spike) + self.E2I(e_spike) + self.bg_inh + e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc + i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh self.delayE(self.E(e_inp)) self.delayI(self.I(i_inp)) @@ -71,33 +69,30 @@ def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): # synapses we = 0.6 / scale # excitatory synaptic weight (voltage) wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.E.size, post=self.E.size), - g_max=we, tau=5., out=bp.syn.COBA(self.E.V, E=0.) + self.E2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) ) - self.E2I = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), - g_max=we, tau=5., out=bp.syn.COBA(self.I.V, E=0.) + self.E2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) ) - self.I2E = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.I.size, post=self.E.size), - g_max=wi, tau=10., out=bp.syn.COBA(self.E.V, E=-80.) + self.I2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) ) - self.I2I = bp.syn.Exponential( - C.FixedProb(0.02, pre=self.I.size, post=self.I.size), - g_max=wi, tau=10., out=bp.syn.COBA(self.I.V, E=-80.) + self.I2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) ) - bp.share.save('t', 2.) - bp.share.save('dt', 0.1) bp.share.save('E-spike', bp.Delay(self.E.spike, entries={'E': delay})) bp.share.save('I-spike', bp.Delay(self.I.spike, entries={'I': delay})) - @bp.not_pass_sha def update(self): e_spike = bp.share.load('E-spike').at('E') i_spike = bp.share.load('I-spike').at('I') - e_inp = self.E2E(e_spike) + self.I2E(i_spike) + self.bg_exc - i_inp = self.I2I(i_spike) + self.E2I(e_spike) + self.bg_inh + e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc + i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh self.E(e_inp) self.I(i_inp) diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py index 404d604a7..2eaba0d2b 100644 --- a/examples/dynamics_training/Song_2016_EI_RNN.py +++ b/examples/dynamics_training/Song_2016_EI_RNN.py @@ -75,7 +75,7 @@ def cell(self, x, h): def readout(self, h): return h @ self.w_ro + self.b_ro - @bp.not_pass_sha + @bp.not_pass_shared def update(self, x): self.h.value = self.cell(x, self.h) self.o.value = self.readout(self.h[:, :self.e_size])