From 49c32d72e1ef4e0234d7227da2e830cd9d2981e7 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 10 Nov 2023 13:30:11 +0800 Subject: [PATCH 1/3] [doc] update doc of state loading and saving --- .../state_saving_and_loading.ipynb | 95 ++++++++++--------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/docs/tutorial_toolbox/state_saving_and_loading.ipynb b/docs/tutorial_toolbox/state_saving_and_loading.ipynb index ef0922c4a..6e851bbe0 100644 --- a/docs/tutorial_toolbox/state_saving_and_loading.ipynb +++ b/docs/tutorial_toolbox/state_saving_and_loading.ipynb @@ -33,8 +33,8 @@ "name": "#%%\n" }, "ExecuteTime": { - "end_time": "2023-10-18T11:31:33.724617200Z", - "start_time": "2023-10-18T11:31:32.625523200Z" + "end_time": "2023-11-10T05:28:22.558070Z", + "start_time": "2023-11-10T05:28:20.063466800Z" } }, "outputs": [], @@ -64,9 +64,9 @@ "source": [ "State saving and loading in BrainPy are managed by a **local** function and a **global** function. \n", "\n", - "The **local function** is to save or load states in the current node. Particularly, ``__save_state__()`` and ``__load_state__()`` are local functions for saving and loading states. \n", + "The **local function** is to save or load states in the current node. Particularly, ``save_state()`` and ``load_state()`` are local functions for saving and loading states. \n", "\n", - "The **global function** is to save or load all states in the current and children nodes. Particularly, ``state_dict()`` and ``load_state_dict()`` are global functions for saving and loading states. " + "The **global function** is to save or load all states in the current and children nodes. Particularly, ``brainpy.save_state()`` and ``brainpy.load_state()`` are global functions for saving and loading states. " ], "metadata": { "collapsed": false @@ -94,8 +94,8 @@ "name": "#%%\n" }, "ExecuteTime": { - "end_time": "2023-10-18T11:31:33.730412Z", - "start_time": "2023-10-18T11:31:33.727125300Z" + "end_time": "2023-11-10T05:28:22.558070Z", + "start_time": "2023-11-10T05:28:22.555605500Z" } }, "outputs": [], @@ -121,8 +121,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.080422100Z", - "start_time": "2023-10-18T11:31:33.730412Z" + "end_time": "2023-11-10T05:28:23.436487700Z", + "start_time": "2023-11-10T05:28:22.558070Z" } }, "id": "59a6abf6a8eabaa9" @@ -153,13 +153,13 @@ } ], "source": [ - "net.__save_state__()" + "net.save_state()" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.093722800Z", - "start_time": "2023-10-18T11:31:34.080422100Z" + "end_time": "2023-11-10T05:28:23.460151400Z", + "start_time": "2023-11-10T05:28:23.438987500Z" } }, "id": "5eb9d839e47cf417" @@ -180,7 +180,7 @@ "outputs": [ { "data": { - "text/plain": "{'SNN0': {'SNN0.var': Array([0.], dtype=float32)},\n 'Dense0': {},\n 'Lif0': {'Lif0.V': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n 'Lif0.spike': Array([False, False, False, False, False, False, False, False, False,\n False], dtype=bool)},\n 'ExponentialEuler0': {}}" + "text/plain": "{'SNN0': {'SNN0.var': Array([0.], dtype=float32)},\n 'Dense0': {},\n 'Lif0': {'Lif0.V': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n 'Lif0.spike': Array([False, False, False, False, False, False, False, False, False,\n False], dtype=bool)}}" }, "execution_count": 5, "metadata": {}, @@ -188,13 +188,13 @@ } ], "source": [ - "net.state_dict()" + "bp.save_state(net)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": 
"2023-10-18T11:31:34.096851300Z", - "start_time": "2023-10-18T11:31:34.093722800Z" + "end_time": "2023-11-10T05:28:23.460151400Z", + "start_time": "2023-11-10T05:28:23.448336300Z" } }, "id": "a5e0fc0f7f424718" @@ -227,8 +227,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.106804200Z", - "start_time": "2023-10-18T11:31:34.096851300Z" + "end_time": "2023-11-10T05:28:23.460151400Z", + "start_time": "2023-11-10T05:28:23.457628200Z" } }, "id": "1b3cf2ec8272839f" @@ -263,8 +263,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.171028800Z", - "start_time": "2023-10-18T11:31:34.106804200Z" + "end_time": "2023-11-10T05:28:23.548940300Z", + "start_time": "2023-11-10T05:28:23.460151400Z" } }, "id": "2cdc6d82d53317e7" @@ -288,8 +288,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.171028800Z", - "start_time": "2023-10-18T11:31:34.124099300Z" + "end_time": "2023-11-10T05:28:23.564629600Z", + "start_time": "2023-11-10T05:28:23.485183100Z" } }, "id": "4d18c9fba2983e69" @@ -318,13 +318,13 @@ } ], "source": [ - "net.load_state_dict(states)" + "bp.load_state(net, states)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.171028800Z", - "start_time": "2023-10-18T11:31:34.129292800Z" + "end_time": "2023-11-10T05:28:23.564629600Z", + "start_time": "2023-11-10T05:28:23.492251800Z" } }, "id": "a585a32ef51654b" @@ -338,15 +338,15 @@ } }, "source": [ - "- ``bp.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True)`` function requires you to provide a `filename` which is the path where checkpoint files will be stored. You also need to supply a `target`, which is a state dict object. An optional `overwrite` argument allows you to decide whether to overwrite existing checkpoint files \n", + "- ``brainpy.checkpoints.save_pytree(filename: str, target: PyTree, overwrite: bool = True, async_manager: Optional[AsyncManager] = None, verbose: bool = True)`` function requires you to provide a `filename` which is the path where checkpoint files will be stored. You also need to supply a `target`, which is a state dict object. An optional `overwrite` argument allows you to decide whether to overwrite existing checkpoint files \n", "if a checkpoint for the current step or a later one already exists. If you provide an `async_manager`, the save operation will be non-blocking on the main thread, but note that this is only suitable for a single host. However, any ongoing save will still prevent \n", "new saves to ensure overwrite logic remains correct. Finally, you can set the `verbose` argument to specify if you want to receive printed information about the operation.\n", "\n", - "- ``bp.checkpoints.load_pytree(filename: str, parallel: bool = True)`` function allows you to restore data from a given checkpoint file or a directory containing multiple checkpoints, which you specify with the `filename` argument. If you set the `parallel` argument to true, the function will attempt to load seekable checkpoints simultaneously for quicker results. When executed, the function returns the restored target from the checkpoint file. If no step is specified and there are no checkpoint files available, the function simply returns the input `target` without changes. If you specify a file path that doesn't exist, the function will also return the original `target`. 
This behavior mirrors the scenario where a directory path is given, but the directory hasn't been created yet.\n", + "- ``brainpy.checkpoints.load_pytree(filename: str, parallel: bool = True)`` function allows you to restore data from a given checkpoint file or a directory containing multiple checkpoints, which you specify with the `filename` argument. If you set the `parallel` argument to true, the function will attempt to load seekable checkpoints simultaneously for quicker results. When executed, the function returns the restored target from the checkpoint file. If no step is specified and there are no checkpoint files available, the function simply returns the input `target` without changes. If you specify a file path that doesn't exist, the function will also return the original `target`. This behavior mirrors the scenario where a directory path is given, but the directory hasn't been created yet.\n", "\n", - "- ``.state_dict()`` function retrieves the entire state of the module and returns it as a dictionary. \n", + "- ``brainpy.save_state(target)`` function retrieves the entire state of the ``target`` module and returns it as a dictionary. \n", "\n", - "- ``.load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compatible: str = 'v2')`` function is used to import parameters and buffers from a provided `state_dict` into the current module and all its child modules. You need to provide the function with a `state_dict`, which is a dictionary containing the desired parameters and persistent buffers to be loaded. Optionally, you can also provide a `warn` parameter (defaulting to True) that will generate warnings if there are keys in the provided `state_dict` that either don't match the current module's structure (unexpected keys) or are missing from the `state_dict` but exist in the module (missing keys). When executed, the function returns a `StateLoadResult`, a named tuple with two fields:\n", + "- ``brainpy.load_state(target, state_dict)`` function is used to import parameters and buffers from a provided `state_dict` into the current module and all its child modules. You need to provide the function with a `state_dict`, which is a dictionary containing the desired parameters and persistent buffers to be loaded. When executed, the function returns a `StateLoadResult`, a named tuple with two fields:\n", " - **missing_keys**: A list of keys that are present in the module but missing in the provided `state_dict`.\n", " - **unexpected_keys**: A list of keys found in the `state_dict` that don't correspond to any part of the current module."
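+ "\n",
+ "A minimal sketch of how these functions fit together (assuming the ``net`` object built above; the checkpoint file name ``./states.bp`` is only an illustration):\n",
+ "\n",
+ "```python\n",
+ "# collect the states of `net` and all of its children into a dict (a pytree)\n",
+ "states = {'net': bp.save_state(net)}\n",
+ "\n",
+ "# write the pytree to disk as a checkpoint file\n",
+ "bp.checkpoints.save_pytree('./states.bp', states)\n",
+ "\n",
+ "# later: read the checkpoint back and push the states into the network\n",
+ "restored = bp.checkpoints.load_pytree('./states.bp')\n",
+ "bp.load_state(net, restored['net'])\n",
+ "```\n",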
] @@ -409,8 +409,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:34.171028800Z", - "start_time": "2023-10-18T11:31:34.170239100Z" + "end_time": "2023-11-10T05:28:23.564629600Z", + "start_time": "2023-11-10T05:28:23.507605600Z" } }, "id": "8c70c70c785f620c" @@ -423,12 +423,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0, loss 1.1491968631744385\n", - "Epoch 1, loss 1.035304069519043\n", - "Epoch 2, loss 0.8735314607620239\n", - "Epoch 3, loss 0.745592474937439\n", - "Epoch 4, loss 0.6913021802902222\n", - "Epoch 5, loss 0.676512598991394\n" + "Epoch 0, loss 1.0733333826065063\n", + "Epoch 1, loss 0.9526105523109436\n", + "Epoch 2, loss 0.8582525253295898\n", + "Epoch 3, loss 0.7843770384788513\n", + "Epoch 4, loss 0.7399720549583435\n", + "Epoch 5, loss 0.7254235744476318\n", + "Epoch 9, loss 0.7122021913528442\n" ] } ], @@ -479,7 +480,7 @@ " l = trainer.f_train()\n", " if l < loss:\n", " loss = l\n", - " states = {'net': net.state_dict(), # save the state dict of the network in the checkpoint\n", + " states = {'net': bp.save_state(net), # save the state dict of the network in the checkpoint\n", " 'epoch_i': i,\n", " 'train_loss': loss}\n", " bp.checkpoints.save_pytree('snn.bp', states, verbose=False) # save the checkpoint\n", @@ -488,8 +489,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:36.268190500Z", - "start_time": "2023-10-18T11:31:34.171028800Z" + "end_time": "2023-11-10T05:28:26.375228100Z", + "start_time": "2023-11-10T05:28:23.507605600Z" } }, "id": "edbfcc58" @@ -517,13 +518,13 @@ "source": [ "# model loading\n", "state_dict = bp.checkpoints.load_pytree('snn.bp') # load the state dict\n", - "net.load_state_dict(state_dict['net']) # unpack the state dict and load it into the network" + "bp.load_state(net, state_dict['net']) # unpack the state dict and load it into the network" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-10-18T11:31:36.354738600Z", - "start_time": "2023-10-18T11:31:36.268190500Z" + "end_time": "2023-11-10T05:28:26.390898700Z", + "start_time": "2023-11-10T05:28:26.356488500Z" } }, "id": "621ac319" @@ -565,7 +566,7 @@ "source": [ "You can make your own saving and loading functions easily.\n", "\n", - "For customizing the saving and loading, users can overwrite ``__save_state__`` and ``__load_state__`` functions.\n", + "For customizing the saving and loading, users can overwrite ``save_state`` and ``load_state`` functions.\n", "\n", "Here is an example to customize:\n", "```python\n", @@ -577,7 +578,7 @@ " self.d = bm.var_list([bm.Variable(bm.random.rand(3)),\n", " bm.Variable(bm.random.rand(3))])\n", "\n", - " def __save_state__(self) -> dict:\n", + " def save_state(self) -> dict:\n", " state_dict = {'a': self.a,\n", " 'b': self.b,\n", " 'c': self.c}\n", @@ -586,7 +587,7 @@ "\n", " return state_dict\n", "\n", - " def __load_state__(self, state_dict):\n", + " def load_state(self, state_dict):\n", " self.a = state_dict['a']\n", " self.b = bm.asarray(state_dict['b'])\n", " self.c = bm.asarray(state_dict['c'])\n", @@ -596,9 +597,9 @@ "```\n", "\n", "\n", - "- ``__save_state__(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n", + "- ``save_state(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n", 
"\n", - "- ``__load_state__(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n", + "- ``load_state(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n", "At firstly it gets the current variables of the object.\n", "Then, it determines the intersection of keys from the provided state_dict and the object's variables.\n", "For each intersecting key, it updates the value of the object's variable with the value from state_dict.\n", From e1fa7c677dc3abc039651fdd5d92d99973043002 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 10 Nov 2023 23:02:26 +0800 Subject: [PATCH 2/3] [dyn] update STDP APIs and fix bugs --- brainpy/_src/delay.py | 5 +- brainpy/_src/dnn/linear.py | 387 +++++++++++------- brainpy/_src/dyn/projections/plasticity.py | 6 +- .../_src/dyn/projections/tests/test_STDP.py | 85 +++- .../math/jitconn/tests/test_event_matvec.py | 19 +- brainpy/_src/mixin.py | 12 +- 6 files changed, 321 insertions(+), 193 deletions(-) diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index d0450162b..ee0be5763 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -249,7 +249,10 @@ def register_entry( Return the self. """ if entry in self._registered_entries: - raise KeyError(f'Entry {entry} has been registered. You can use another key, or reuse the existing key. ') + raise KeyError(f'Entry {entry} has been registered. ' + f'The existing delay for the key {entry} is {self._registered_entries[entry]}. ' + f'The new delay for the key {entry} is {delay_time}. ' + f'You can use another key. ') if isinstance(delay_time, (np.ndarray, jax.Array)): assert delay_time.size == 1 and delay_time.ndim == 0 diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 314ffb19c..09bf2958d 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1,22 +1,24 @@ # -*- coding: utf-8 -*- +import numbers from typing import Dict, Optional, Union, Callable -import numba import jax -import numpy as np import jax.numpy as jnp +import numba +import numpy as np from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share +from brainpy._src.dnn.base import Layer +from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.check import is_initializer +from brainpy.connect import csr2csc from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -from brainpy._src.dnn.base import Layer -from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP __all__ = [ 'Dense', 'Linear', @@ -30,7 +32,7 @@ ] -class Dense(Layer, SupportOnline, SupportOffline, SupportSTDP): +class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline): r"""A linear transformation applied over the last dimension of the input. 
Mathematically, this node can be defined as: @@ -199,19 +201,25 @@ def offline_fit(self, self.W.value = Wff self.b.value = bias[0] - def update_STDP(self, dW, constraints=None): + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): if isinstance(self.W, float): raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - if self.W.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.W.shape}.') if not isinstance(self.W, bm.Variable): self.tracing_variable('W', self.W, self.W.shape) - self.W += dW - if constraints is not None: - self.W.value = constraints(self.W) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) Linear = Dense @@ -228,43 +236,44 @@ def update(self, x): return x -def event_mm(pre_spike, post_inc, weight, w_min, w_max): - return weight - +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + out_w[:] = weight + for i in numba.prange(spike.shape[0]): + if spike[i]: + out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) -@numba.njit -def event_mm_imp(outs, ins): - pre_spike, post_inc, weight, w_min, w_max = ins - w_min = w_min[()] - w_max = w_max[()] - outs = outs - outs.fill(weight) - for i in range(pre_spike.shape[0]): - if pre_spike[i]: - outs[i] = np.clip(outs[i] + post_inc, w_min, w_max) +dense_on_pre_prim = bm.XLACustomOp(_cpu_dense_on_pre) -event_left_mm = bm.CustomOpByNumba(event_mm, event_mm_imp, multiple_results=False) +def dense_on_pre(weight, spike, trace, w_min, w_max): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + return dense_on_pre_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] -def event_mm2(post_spike, pre_inc, weight, w_min, w_max): - return weight +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + out_w[:] = weight + for i in numba.prange(spike.shape[0]): + if spike[i]: + out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) -@numba.njit -def event_mm_imp2(outs, ins): - post_spike, pre_inc, weight, w_min, w_max = ins - w_min = w_min[()] - w_max = w_max[()] - outs = outs - outs.fill(weight) - for j in range(post_spike.shape[0]): - if post_spike[j]: - outs[:, j] = np.clip(outs[:, j] + pre_inc, w_min, w_max) +dense_on_post_prim = bm.XLACustomOp(_cpu_dense_on_post) -event_right_mm = bm.CustomOpByNumba(event_mm2, event_mm_imp2, multiple_results=False) +def dense_on_post(weight, spike, trace, w_min, w_max): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + return dense_on_post_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] class AllToAll(Layer, SupportSTDP): @@ -329,15 +338,25 @@ def update(self, pre_val): post_val = pre_val @ self.weight return post_val - def stdp_update_on_pre(self, pre_spike, trace, w_min=None, w_max=None): - if not isinstance(self.weight, bm.Variable): - 
self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight.value = event_left_mm(pre_spike, trace, self.weight, w_min, w_max) - - def stdp_update_on_post(self, post_spike, trace, w_min=None, w_max=None): + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight.value = event_right_mm(post_spike, trace, self.weight, w_min, w_max) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) class OneToOne(Layer, SupportSTDP): @@ -373,6 +392,26 @@ def __init__( def update(self, pre_val): return pre_val * self.weight + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value += spike * trace + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value += spike * trace + class MaskedLinear(Layer, SupportSTDP): r"""Synaptic matrix multiplication with masked dense computation. @@ -427,23 +466,84 @@ def __init__( def update(self, x): return x @ self.mask_fun(self.weight * self.mask) - def update_STDP(self, dW, constraints=None): + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): if isinstance(self.weight, float): raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - if self.weight.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.weight.shape}.') if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) - self.weight += dW - if constraints is not None: - self.weight.value = constraints(self.weight) +class _CSRLayer(Layer, SupportSTDP): + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode) -class CSRLinear(Layer, SupportSTDP): + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' 
+ self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') + if self.weight.shape != self.indices.shape: + raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: # update on presynaptic spike + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) + if on_post is not None: # update on postsynaptic spike + if not hasattr(self, '_pre_ids'): + with jax.ensure_compile_time_eval(): + self._pre_ids, self._post_indptr, self.w_indices = csr2csc( + [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) + ) + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr, + self.w_indices, spike, trace, w_min, w_max) + + +class CSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with CSR sparse computation. It performs the computation of: @@ -473,23 +573,8 @@ def __init__( method: str = 'cusparse', transpose: bool = True, ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - assert sharding is None, 'Currently this model does not support sharding.' - self.conn = conn - self.sharding = sharding + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) self.method = method - self.transpose = transpose - - # connection - self.indices, self.indptr = self.conn.require('csr') - - # weight - weight = init.parameter(weight, (self.indices.size,)) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight def update(self, x): if x.ndim == 1: @@ -511,25 +596,9 @@ def _batch_csrmv(self, x): transpose=self.transpose, method=self.method) - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr) - sparse_dW = dW[pre_ids, post_ids] - if self.weight.shape != sparse_dW.shape: - raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} ' - f'should be the same as the shape of sparse weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += sparse_dW - if constraints is not None: - self.weight.value = constraints(self.weight) - -class CSCLinear(Layer): - r"""Synaptic matrix multiplication with CSC sparse computation. +class EventCSRLinear(_CSRLayer): + r"""Synaptic matrix multiplication with event CSR sparse computation. 
It performs the computation of: @@ -537,13 +606,13 @@ class CSCLinear(Layer): y = x @ M - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSC sparse matrix. + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. Args: conn: TwoEndConnector. The connection. weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. + sharding: The sharding strategy. mode: The synaptic computing mode. name: The synapse model name. """ @@ -555,16 +624,81 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, + transpose: bool = True, ): - super().__init__(name=name, mode=mode) + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding + def update(self, x): + if x.ndim == 1: + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + def _batch_csrmv(self, x): + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + + +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): + out_w[:] = w + w_min = w_min[()] + w_max = w_max[()] + for i in numba.prange(spike.shape[0]): # pre id + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): # synapse id + j = indices[k] # post id + # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) + out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) + + +csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) + + +def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + + +@numba.njit(nogil=True, fastmath=True, parallel=False) +def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + out_w[:] = w + w_min = w_min[()] + w_max = w_max[()] + for i in numba.prange(spike.shape[0]): # post id + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = post_ids[k] # pre id + l = w_ids[k] # syn id + out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + + +csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) + + +def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -class EventCSRLinear(Layer, SupportSTDP): - r"""Synaptic matrix multiplication with event CSR sparse computation. + +class CSCLinear(Layer): + r"""Synaptic matrix multiplication with CSC sparse computation. 
It performs the computation of: @@ -572,8 +706,8 @@ class EventCSRLinear(Layer, SupportSTDP): y = x @ M - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weight using a CSR sparse matrix. + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSC sparse matrix. Args: conn: TwoEndConnector. The connection. @@ -590,59 +724,12 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, - transpose: bool = True, ): super().__init__(name=name, mode=mode) assert isinstance(conn, connect.TwoEndConnector) - assert sharding is None, 'Currently this model does not support sharding.' self.conn = conn self.sharding = sharding - self.transpose = transpose - - # connection - self.indices, self.indptr = self.conn.require('csr') - - # weight - weight = init.parameter(weight, (self.indices.size,)) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - with jax.ensure_compile_time_eval(): - pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr) - sparse_dW = dW[pre_ids, post_ids] - if self.weight.shape != sparse_dW.shape: - raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} ' - f'should be the same as the shape of sparse weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += sparse_dW - if constraints is not None: - self.weight.value = constraints(self.weight) class BcsrMM(Layer): diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index c51332e44..3ee6f4fef 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -163,7 +163,7 @@ def __init__( syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) else: - syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') out_cls = out() add_inp_fun(out_label, self.name, out_cls, post) @@ -206,9 +206,9 @@ def update(self): # weight updates Apost = self.refs['post_trace'].g - self.comm.stdp_update_on_pre(pre_spike, -Apost * self.A2, self.W_min, self.W_max) + self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max) Apre = self.refs['pre_trace'].g - self.comm.stdp_update_on_post(post_spike, Apre * self.A1, self.W_min, self.W_max) + self.comm.stdp_update(on_post={"spike": post_spike, "trace": 
Apre * self.A1}, w_min=self.W_min, w_max=self.W_max) # synaptic currents current = self.comm(x) diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 001afc02e..a4173c7ba 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import matplotlib.pyplot as plt + import numpy as np from absl.testing import parameterized @@ -9,22 +9,67 @@ class Test_STDP(parameterized.TestCase): - def test_STDP(self): + + @parameterized.product( + comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], + delay=[None, 0., 2.], + syn_model=['exp', 'dual_exp', 'ampa'], + out_model=['cuba', 'coba', 'mg'] + ) + def test_STDP(self, comm_method, delay, syn_model, out_model): bm.random.seed() class STDPNet(bp.DynamicalSystem): def __init__(self, num_pre, num_post): super().__init__() - self.pre = bp.dyn.LifRef(num_pre, name='neu1') - self.post = bp.dyn.LifRef(num_post, name='neu2') + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + + if comm_method == 'all2all': + comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'csr': + if syn_model == 'exp': + comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + else: + comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'masked_linear': + comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'dense': + comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'one2one': + comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) + else: + raise ValueError + + if syn_model == 'exp': + syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) + elif syn_model == 'dual_exp': + syn = bp.dyn.DualExpon.desc(self.post.varshape) + elif syn_model == 'dual_exp_v2': + syn = bp.dyn.DualExponV2.desc(self.post.varshape) + elif syn_model == 'ampa': + syn = bp.dyn.AMPA.desc(self.post.varshape) + else: + raise ValueError + + if out_model == 'cuba': + out = bp.dyn.CUBA.desc() + elif out_model == 'coba': + out = bp.dyn.COBA.desc(E=0.) + elif out_model == 'mg': + out = bp.dyn.MgBlock.desc(E=0.) + else: + raise ValueError + self.syn = bp.dyn.STDP_Song2000( pre=self.pre, - delay=1., - # comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - # weight=bp.init.Uniform(0., 0.1)), - comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)), - syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), + delay=delay, + comm=comm, + syn=syn, + out=out, post=self.post, tau_s=16.8, tau_t=33.7, @@ -42,7 +87,11 @@ def update(self, I_pre, I_post): Apre = self.syn.refs['pre_trace'].g Apost = self.syn.refs['post_trace'].g current = self.post.sum_inputs(self.post.V) - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight.flatten() + if comm_method == 'dense': + w = self.syn.comm.W.flatten() + else: + w = self.syn.comm.weight.flatten() + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w duration = 300. 
I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], @@ -59,11 +108,13 @@ def run(i, I_pre, I_post): indices = np.arange(int(duration / bm.dt)) pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) - bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) - bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) - bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) - plt.show() + # import matplotlib.pyplot as plt + # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + # plt.show() bm.clear_buffer_memory() + diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index f442cbada..016f9b0dd 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -13,7 +13,6 @@ if platform.system() == 'Windows' and not is_manual_test: pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), @@ -33,9 +32,9 @@ def __init__(self, *args, platform='cpu', **kwargs): outdim_parallel=[True, False], shape=shapes, prob=[0.01, 0.1, 0.5], - homo_data= [-1., ], + homo_data=[-1., ], bool_event=[True, False], - seed = [1234], + seed=[1234], ) def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False): print(f'_test_homo: ' @@ -96,14 +95,12 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve @parameterized.product( transpose=[True, False], - - x64= [True, False], - outdim_parallel= [True, False], - shape= shapes, - prob= [0.01, 0.1, 0.5], - bool_event= [True, False], - - seed = [1234], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.01, 0.1, 0.5], + bool_event=[True, False], + seed=[1234], ) def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False): print(f'_test_homo_vmap: ' diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index f356f44b3..8ea8a5216 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -483,17 +483,7 @@ class SupportSTDP(MixIn): """Support synaptic plasticity by modifying the weights. 
""" - def update_STDP( - self, - dW: Union[bm.Array, jax.Array], - constraints: Optional[Callable] = None, - ): - raise NotImplementedError - - def stdp_update_on_pre(self, pre_spike, trace, *args, **kwargs): - raise NotImplementedError - - def stdp_update_on_post(self, post_spike, trace, *args, **kwargs): + def stdp_update(self, *args, on_pre=None, onn_post=None, **kwargs): raise NotImplementedError From ebea6b67745ba27137db53ac678b0c4f4aba42da Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 10 Nov 2023 23:32:26 +0800 Subject: [PATCH 3/3] [doc] update state resetting APIs --- docs/tutorial_toolbox/state_resetting.ipynb | 8 ++++---- setup.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/tutorial_toolbox/state_resetting.ipynb b/docs/tutorial_toolbox/state_resetting.ipynb index 04cd4e9f4..19b81308f 100644 --- a/docs/tutorial_toolbox/state_resetting.ipynb +++ b/docs/tutorial_toolbox/state_resetting.ipynb @@ -18,7 +18,7 @@ "Similar to [state saving and loading](./saving_and_loading.ipynb) , state resetting is implemented with two functions:\n", "\n", "- a local function ``.reset_state()`` which resets all local variables in the current node.\n", - "- a global function ``.reset()`` which resets all variables in parent and children nodes." + "- a global function ``brainpy.reset_state()`` which resets all variables in parent and children nodes." ], "metadata": { "collapsed": false @@ -93,7 +93,7 @@ { "cell_type": "markdown", "source": [ - "By calling ``net.reset()``, we can reset all states in this network, including variables in the neurons, synapses, and networks. By using ``net.reset_state()``, we can reset the local variables which are defined in the current network. " + "By calling ``brainpy.reset_state(net)``, we can reset all states in this network, including variables in the neurons, synapses, and networks. By using ``net.reset_state()``, we can reset the local variables which are defined in the current network. " ], "metadata": { "collapsed": false @@ -115,7 +115,7 @@ ], "source": [ "print('Before reset:', net.N.V.value)\n", - "net.reset()\n", + "bp.reset_state(net)\n", "print('After reset:', net.N.V.value)" ], "metadata": { @@ -157,7 +157,7 @@ { "cell_type": "markdown", "source": [ - "There is no change for the ``V`` variable, meaning that the network's ``reset_state()`` can not reset states in the children node. Instead, to reset the whole states of the network, users should use ``reset()`` function. " + "There is no change for the ``V`` variable, meaning that the network's ``reset_state()`` can not reset states in the children node. Instead, to reset the whole states of the network, users should use ``brainpy.reset_state()`` function. " ], "metadata": { "collapsed": false diff --git a/setup.py b/setup.py index ef051aa0c..69c33cdfe 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11',