Skip to content

Commit

Permalink
Merge pull request #343 from chaoming0625/master
Browse files Browse the repository at this point in the history
Fix bug and more surrogate grad function supports
  • Loading branch information
chaoming0625 authored Mar 4, 2023
2 parents d8d23db + 1480220 commit 0a519c0
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 64 deletions.
6 changes: 2 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.5"
__version__ = "2.3.6"


# fundamental supporting modules
Expand Down Expand Up @@ -75,8 +75,7 @@
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share, Delay

Expand Down Expand Up @@ -207,7 +206,6 @@
dyn.__dict__['TwoEndConn'] = TwoEndConn
dyn.__dict__['CondNeuGroup'] = CondNeuGroup
dyn.__dict__['Channel'] = Channel
dyn.__dict__['NoSharedArg'] = NoSharedArg
dyn.__dict__['LoopOverTime'] = LoopOverTime
dyn.__dict__['DSRunner'] = DSRunner

Expand Down
61 changes: 9 additions & 52 deletions brainpy/_src/dyn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

__all__ = [
'LoopOverTime',
'NoSharedArg',
]


Expand Down Expand Up @@ -207,12 +206,13 @@ def __call__(
if isinstance(duration_or_xs, float):
shared = tools.DotDict()
if self.t0 is not None:
shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt)
shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value
if self.i0 is not None:
shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0])
shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value
xs = None
if self.no_state:
raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
length = shared['t'].shape

else:
inp_err_msg = ('\n'
Expand Down Expand Up @@ -278,8 +278,8 @@ def __call__(

else:
shared = tools.DotDict()
shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt)
shared['i'] = jnp.arange(self.i0.value, length[0])
shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
shared['i'] = jnp.arange(0, length[0]) + self.i0.value

assert not self.no_state
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
Expand All @@ -295,6 +295,10 @@ def __call__(

def reset_state(self, batch_size=None):
self.target.reset_state(batch_size)
if self.i0 is not None:
self.i0.value = jnp.asarray(0)
if self.t0 is not None:
self.t0.value = jnp.asarray(0.)

def _run(self, static_sh, dyn_sh, x):
share.save(**static_sh, **dyn_sh)
Expand All @@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
self.target.clear_input()
return outs


class NoSharedArg(DynSysToBPObj):
"""Transform an instance of :py:class:`~.DynamicalSystem` into a callable
:py:class:`~.BrainPyObject` :math:`y=f(x)`.
.. note::
This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
If some children nodes need shared arguments, like :py:class:`~.Dropout` or
:py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.
Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
>>> bm.relu,
>>> bp.layers.Dense(10, 2))
>>> l = bp.NoSharedArg(l)
>>> l(bm.random.random(256, 100))
Parameters
----------
target: DynamicalSystem
The target to transform.
name: str
The transformed object name.
"""

def __init__(self, target: DynamicalSystem, name: str = None):
super().__init__(target=target, name=name)
if isinstance(target, Sequential) and target.no_shared_arg:
raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, '
f'which has already able to be called with `f(x)`. ')

def __call__(self, *args, **kwargs):
return self.target(tools.DotDict(), *args, **kwargs)

def reset(self, batch_size=None):
"""Reset function which reset the whole variables in the model.
"""
self.target.reset(batch_size)

def reset_state(self, batch_size=None):
self.target.reset_state(batch_size)
5 changes: 4 additions & 1 deletion brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def clone(self):
return self.__class__()


def set_environment(
def set(
mode: modes.Mode = None,
dt: float = None,
x64: bool = None,
Expand Down Expand Up @@ -381,6 +381,9 @@ def set_environment(
set_complex(complex_)


set_environment = set


class environment(_DecoratorContextManager):
r"""Context-manager that sets a computing environment for brain dynamics computation.
Expand Down
Loading

0 comments on commit 0a519c0

Please sign in to comment.