Version 2.4.6
This release contains more than 130 commits and provides several new features.
New Features
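1. Surrogate gradient functions are more transparent.
Surrogate gradient functions are now provided as instances, which can be used to compute the forward spike output, the surrogate gradient, and the smooth surrogate function separately. For example: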
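```python
import brainpy.math as bm

fun = bm.surrogate.Sigmoid()
# example input: any array of membrane potentials
membrane_potential = bm.linspace(-1., 1., 11)
# forward function: the spike output (step function)
spk = fun(membrane_potential)
# backward function: the surrogate gradient, given an upstream gradient of 1.
dV = fun.surrogate_grad(1., membrane_potential)
# surrogate forward function: the smooth function whose derivative is the surrogate gradient
surro_spk = fun.surrogate_fun(membrane_potential)
```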
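Conceptually, a surrogate gradient keeps the non-differentiable step function in the forward pass and substitutes the derivative of a smooth function in the backward pass. The following NumPy sketch shows that split for a sigmoid surrogate. It assumes a steepness `alpha = 4.0` (BrainPy's default may differ), and the function names are hypothetical; they only mirror the three calls above and are not BrainPy's implementation.

```python
import numpy as np

ALPHA = 4.0  # assumed steepness of the sigmoid

def spike_forward(v):
  """Forward pass: Heaviside step, 1.0 where v >= 0, else 0.0."""
  return (v >= 0).astype(np.float64)

def sigmoid_surrogate_fun(v, alpha=ALPHA):
  """Smooth stand-in for the step function: sigmoid(alpha * v)."""
  return 1.0 / (1.0 + np.exp(-alpha * v))

def sigmoid_surrogate_grad(dz, v, alpha=ALPHA):
  """Backward pass: upstream gradient dz times the derivative of sigmoid(alpha * v)."""
  sg = sigmoid_surrogate_fun(v, alpha)
  return dz * alpha * sg * (1.0 - sg)

v = np.array([-1.0, 0.0, 1.0])
print(spike_forward(v))                # [0. 1. 1.]
print(sigmoid_surrogate_grad(1.0, v))  # peaks at v = 0, where it equals alpha / 4
```

The gradient is largest near the threshold and decays symmetrically away from it, so errors can flow through the spiking nonlinearity during training.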
2. Add `brainpy.math.eval_shape` for evaluating all dynamical variables used in the target function.
This function is similar to `jax.eval_shape`, which performs no FLOPs, but it can additionally extract all variables used in the target function. For example:
```python
import brainpy.math as bm

net = ...  # any dynamical system
inputs = ...  # inputs to the dynamical system
variables, outputs = bm.eval_shape(net, inputs)
# "variables" are all variables used in the target "net"
```
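For reference, this is how plain `jax.eval_shape` evaluates a function abstractly: it takes `jax.ShapeDtypeStruct` placeholders and returns the shape and dtype of the output without performing any arithmetic. The sketch uses JAX only; the `layer` function is a hypothetical example, and the variable extraction that `bm.eval_shape` adds on top is not shown.

```python
import jax
import jax.numpy as jnp

def layer(w, x):
  # a matrix product followed by a nonlinearity
  return jnp.tanh(x @ w)

# placeholders carrying only shape and dtype, no data
w_spec = jax.ShapeDtypeStruct((4, 3), jnp.float32)
x_spec = jax.ShapeDtypeStruct((8, 4), jnp.float32)

# abstract evaluation: `layer` is traced with shapes/dtypes only, no FLOPs
out = jax.eval_shape(layer, w_spec, x_spec)
print(out.shape, out.dtype)  # (8, 3) float32
```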
In the future, this function will be used throughout BrainPy to turn JAX transformations into BrainPy's object-oriented (OO) transformations.
3. Generalize tools and interfaces for state management.
For a single object:
- `.reset_state()` defines how all local variables in this node are reset.
- `.load_state()` defines how states are loaded from external disks (typically, a dict is passed into the `.load_state()` function).
- `.save_state()` defines how states are saved to external disks (typically, the `.save_state()` function generates a dict containing all variable values).
Here is an example of defining a complete `brainpy.DynamicalSystem` class:
```python
import brainpy as bp
import brainpy.math as bm

class YourDynSys(bp.DynamicalSystem):
  def __init__(self):  # define parameters
    super().__init__()
    self.par1 = ...
    self.num = ...

  def reset_state(self, batch_or_mode=None):  # define variables
    self.a = bp.init.variable_(bm.zeros, (self.num,), batch_or_mode)

  def load_state(self, state_dict):  # load states from an external dict
    self.a.value = bm.as_jax(state_dict['a'])

  def save_state(self):  # save states as an external dict
    return {'a': self.a.value}
```
For a complex network model, BrainPy provides unified state management interfaces for initializing, saving, and loading states:
- `brainpy.reset_state()` resets all variables in this node and its child nodes.
- `brainpy.load_state()` loads all variables in the node and its children from external disks.
- `brainpy.save_state()` saves all variables in the node and its children to external disks.
- `brainpy.clear_input()` clears all input variables in the node and its children.
4. Unified brain simulation and brain-inspired computing interface through automatic membrane scaling.
The same model used for brain simulation can easily be transformed into one that is trained for brain-inspired computing. For example:
```python
import brainpy as bp
import brainpy.math as bm

class EINet(bp.DynSysGroup):
  def __init__(self):
    super().__init__()
    self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                              V_initializer=bp.init.Normal(-55., 2.))
    self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
    self.E = bp.dyn.ProjAlignPost1(
      comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=3200, post=4000), weight=bp.init.Normal(0.6, 0.01)),
      syn=bp.dyn.Expon(size=4000, tau=5.),
      out=bp.dyn.COBA(E=0.),
      post=self.N
    )
    self.I = bp.dyn.ProjAlignPost1(
      comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=800, post=4000), weight=bp.init.Normal(6.7, 0.01)),
      syn=bp.dyn.Expon(size=4000, tau=10.),
      out=bp.dyn.COBA(E=-80.),
      post=self.N
    )

  def update(self, input):
    spk = self.delay.at('I')
    self.E(spk[:3200])  # spikes of the 3200 excitatory neurons
    self.I(spk[3200:])  # spikes of the 800 inhibitory neurons
    self.delay(self.N(input))
    return self.N.spike.value

# used for brain simulation
with bm.environment(mode=bm.nonbatching_mode):
  net = EINet()

# used for brain-inspired computing
# define the `membrane_scaling` parameter
with bm.environment(mode=bm.TrainingMode(128), membrane_scaling=bm.Scaling.transform([-60., -50.])):
  net = EINet()
```
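The idea behind `membrane_scaling` is a linear change of units for the membrane potential. The sketch below assumes that `bm.Scaling.transform([-60., -50.])` maps the potential range [-60, -50] mV onto [0, 1] through `(V - bias) / scale`; the helper names are hypothetical, and BrainPy's own `Scaling` class may parameterize the map differently.

```python
def make_scaling(v_range, scaled_range=(0.0, 1.0)):
  """Return (bias, scale) so that (v - bias) / scale maps v_range onto scaled_range."""
  v_min, v_max = v_range
  s_min, s_max = scaled_range
  scale = (v_max - v_min) / (s_max - s_min)
  bias = v_min - s_min * scale
  return bias, scale

def to_scaled(v, bias, scale):
  # membrane potential in mV -> dimensionless scaled units
  return (v - bias) / scale

def to_original(x, bias, scale):
  # inverse map: scaled units -> mV
  return x * scale + bias

bias, scale = make_scaling((-60.0, -50.0))
print(bias, scale)                    # -60.0 10.0
print(to_scaled(-55.0, bias, scale))  # 0.5
```

Under this assumption, the resting/reset potential (-60 mV) becomes 0 and the threshold (-50 mV) becomes 1, so the same network definition works in either unit system.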
5. New APIs for operator customization on CPU and GPU devices through `brainpy.math.XLACustomOp`.
Starting from this release, BrainPy supports Taichi for operator customization. Users can now write CPU operators with either Numba or Taichi syntax, and GPU operators with Taichi syntax. For example, an operator can be defined as follows:
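```python
import numba as nb
import taichi as ti
import numpy as np
import jax
import brainpy.math as bm

# CPU kernel written with Numba: copy the inputs into the preallocated outputs
@nb.njit
def numba_cpu_fun(a, b, out_a, out_b):
  out_a[:] = a
  out_b[:] = b

# GPU kernel written with Taichi: kernel arguments must be type-annotated
@ti.kernel
def taichi_gpu_fun(a: ti.types.ndarray(ndim=1),
                   b: ti.types.ndarray(ndim=1),
                   out_a: ti.types.ndarray(ndim=1),
                   out_b: ti.types.ndarray(ndim=1)):
  for i in range(a.shape[0]):
    out_a[i] = a[i]
  for i in range(b.shape[0]):
    out_b[i] = b[i]

prim = bm.XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun)
# `outs` declares the shape (a tuple) and dtype of every output
a2, b2 = prim(np.random.random(1000), np.random.random(1000),
              outs=[jax.ShapeDtypeStruct((1000,), dtype=np.float32),
                    jax.ShapeDtypeStruct((1000,), dtype=np.float32)])
```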
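Both kernels follow the same calling convention: they receive the inputs followed by output buffers that are allocated in advance, which is presumably why every output has to be described through `outs`. The NumPy sketch below mimics that destination-passing style; `call_with_outs` is a hypothetical helper standing in for the machinery inside `XLACustomOp`, whose real internals are not shown here.

```python
import numpy as np

def copy_kernel(a, b, out_a, out_b):
  # same body as the Numba kernel: fill the preallocated outputs in place
  out_a[:] = a
  out_b[:] = b

def call_with_outs(kernel, *ins, outs):
  """Hypothetical helper: allocate one buffer per (shape, dtype) spec, run the kernel, return the buffers."""
  buffers = [np.empty(shape, dtype=dtype) for shape, dtype in outs]
  kernel(*ins, *buffers)
  return buffers

a = np.random.random(1000)
b = np.random.random(1000)
a2, b2 = call_with_outs(copy_kernel, a, b,
                        outs=[((1000,), np.float32), ((1000,), np.float32)])
print(a2.dtype, a2.shape)  # float32 (1000,)
```

Because the output specs fix the dtype, the float64 inputs are cast to float32 when they are written into the buffers.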
6. Generalized STDP models that are compatible with diverse synapse models.
See https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/dyn/projections/tests/test_STDP.py for examples.
What's Changed
- [bug] fix compatible bug by @chaoming0625 in #508
- [docs] add low-level op customization by @ztqakita in #507
- Compatible with `jax==0.4.16` by @chaoming0625 in #511
- updates for parallelization support by @chaoming0625 in #514
- Upgrade surrogate gradient functions by @chaoming0625 in #516
- [doc] update operator customization by @chaoming0625 in #517
- Updates for OO transforma and surrogate functions by @chaoming0625 in #519
- [dyn] add neuron scaling by @ztqakita in #520
- State saving, loading, and resetting by @chaoming0625 in #521
- [delay] rewrite previous delay APIs so that they are compatible with new brainpy version by @chaoming0625 in #522
- [projection] upgrade projections so that APIs are reused across different models by @chaoming0625 in #523
- [math] the interface for operator registration by @chaoming0625 in #524
- FIx bug in Delay by @ztqakita in #525
- Fix bugs in membrane scaling by @ztqakita in #526
- [math] Implement taichi op register by @Routhleck in #527
- Link libtaichi_c_api.so when import brainpylib by @Routhleck in #528
- update taichi op customization by @chaoming0625 in #529
- Fix error message by @HoshinoKoji in #530
- [math] remove the hard requirement of `taichi` by @chaoming0625 in #531
- [math] Resolve encoding of source kernel when ti.func is nested in ti… by @Routhleck in #532
- [math] new abstract function for XLACustomOp, fix its bugs by @chaoming0625 in #534
- [math] fix numpy array priority by @chaoming0625 in #533
- [brainpy.share] add category shared info by @chaoming0625 in #535
- [doc] update documentations by @chaoming0625 in #536
- [doc] update doc by @chaoming0625 in #537
- [dyn] add `brainpy.reset_state()` and `brainpy.clear_input()` for more consistent and flexible state managements by @chaoming0625 in #538
- [math] simplify the taichi AOT operator customization interface by @chaoming0625 in #540
- [dyn] add `save_state`, `load_state`, `reset_state`, and `clear_input` helpers by @chaoming0625 in #542
- [dyn] update STDP APIs on CPUs and fix bugs by @chaoming0625 in #543
New Contributors
- @HoshinoKoji made their first contribution in #530
Full Changelog: V2.4.5...V2.4.6