Fix several bugs, update APIs and docs (#304)
chaoming0625 authored Nov 29, 2022
2 parents 4c3433d + 5281ef7 commit af185c5
Showing 18 changed files with 280 additions and 244 deletions.
311 changes: 188 additions & 123 deletions brainpy/connect/base.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions brainpy/connect/custom_conn.py
@@ -81,8 +81,8 @@ class CSRConn(TwoEndConnector):
def __init__(self, indices, inptr):
super(CSRConn, self).__init__()

self.indices = bm.asarray(indices).astype(IDX_DTYPE)
self.inptr = bm.asarray(inptr).astype(IDX_DTYPE)
self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

@@ -110,3 +110,5 @@ def __init__(self, csr_mat):
self.csr_mat = csr_mat
super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
self.pre_num = csr_mat.shape[0]
self.post_num = csr_mat.shape[1]
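
For context, a minimal usage sketch of the updated SparseMatConn (not part of this commit; it assumes scipy is installed and that the class is exported from brainpy.connect, which may differ between versions):

# Sketch only: build a SparseMatConn from a scipy CSR matrix; pre_num/post_num
# are now taken from the matrix shape, per the two lines added above.
import brainpy.connect as C
from scipy.sparse import csr_matrix

mat = csr_matrix(([1., 1., 1.], ([0, 0, 1], [1, 2, 0])), shape=(2, 3))
conn = C.SparseMatConn(mat)
assert conn.pre_num == 2 and conn.post_num == 3
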
2 changes: 1 addition & 1 deletion brainpy/connect/regular_conn.py
@@ -46,7 +46,7 @@ def build_csr(self):
f'same size, but {self.pre_num} != {self.post_num}.')
ind = np.arange(self.pre_num)
indptr = np.arange(self.pre_num + 1)
return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE),
return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))

def build_mat(self, pre_size=None, post_size=None):
if self.pre_num != self.post_num:
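
A plain-NumPy sketch (not part of this commit) of why this one-line fix matters: np.arange expects scalar bounds, so passing the indptr array to it was incorrect, while np.asarray simply casts the existing array:

# Sketch only, using plain NumPy.
import numpy as np

pre_num = 5
indptr = np.arange(pre_num + 1)
indptr_cast = np.asarray(indptr, dtype=np.uint32)   # the fix: cast the existing array
# np.arange(indptr, dtype=np.uint32)                # the old call: arange wants scalar bounds, not an array
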
14 changes: 7 additions & 7 deletions brainpy/dyn/base.py
@@ -988,15 +988,15 @@ def __init__(
ltp.register_master(master=self)
self.ltp: SynLTP = ltp

def init_weights(
def _init_weights(
self,
weight: Union[float, Array, Initializer, Callable],
comp_method: str,
sparse_data: str = 'csr'
) -> Union[float, Array]:
if comp_method not in ['sparse', 'dense']:
raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
if sparse_data not in ['csr', 'ij']:
if sparse_data not in ['csr', 'ij', 'coo']:
raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
@@ -1014,11 +1014,11 @@ def init_weights(
if comp_method == 'sparse':
if sparse_data == 'csr':
conn_mask = self.conn.require('pre2post')
elif sparse_data == 'ij':
elif sparse_data in ['ij', 'coo']:
conn_mask = self.conn.require('post_ids', 'pre_ids')
else:
ValueError(f'Unknown sparse data type: {sparse_data}')
weight = parameter(weight, conn_mask[1].shape, allow_none=False)
weight = parameter(weight, conn_mask[0].shape, allow_none=False)
elif comp_method == 'dense':
weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
conn_mask = self.conn.require('conn_mat')
@@ -1030,7 +1030,7 @@ def syn2post_with_all2all(self, syn_value, syn_weight):
weight = bm.TrainVar(weight)
return weight, conn_mask

def syn2post_with_all2all(self, syn_value, syn_weight):
def _syn2post_with_all2all(self, syn_value, syn_weight):
if bm.ndim(syn_weight) == 0:
if isinstance(self.mode, BatchingMode):
post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
@@ -1043,10 +1043,10 @@ def syn2post_with_all2all(self, syn_value, syn_weight):
post_vs = syn_value @ syn_weight
return post_vs

def syn2post_with_one2one(self, syn_value, syn_weight):
def _syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:
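
A plain-NumPy sketch (not part of this commit) of what the 'ij'/'coo' branch implies for weight shapes: both coordinate lists have one entry per synapse, so the per-connection weights can be sized from either one (hence the change to conn_mask[0].shape). The id arrays below are hypothetical:

# Sketch only, using plain NumPy.
import numpy as np

post_ids = np.array([1, 2, 0])    # hypothetical coordinate lists, one entry per synapse
pre_ids = np.array([0, 0, 1])
conn_mask = (post_ids, pre_ids)
weight = np.full(conn_mask[0].shape, 0.5)   # one weight per connection
assert weight.shape == pre_ids.shape == post_ids.shape
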
24 changes: 13 additions & 11 deletions brainpy/dyn/layers/activate.py
@@ -1,28 +1,30 @@
from brainpy.dyn.base import DynamicalSystem
from typing import Optional
from brainpy.modes import Mode
from typing import Callable
from typing import Optional

from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, training


class Activation(DynamicalSystem):
r"""Applies a activation to the inputs
r"""Applies an activation function to the inputs
Parameters:
----------
activate_fun: Callable
activate_fun: Callable, function
The function of Activation
name: str, Optional
The name of the object
mode: Mode
Enable training this node or not. (default True).
"""

def __init__(self,
activate_fun: Callable,
name: Optional[str] = None,
mode: Optional[Mode] = None,
**kwargs,
):
def __init__(
self,
activate_fun: Callable,
name: Optional[str] = None,
mode: Mode = training,
**kwargs,
):
super().__init__(name, mode)
self.activate_fun = activate_fun
self.kwargs = kwargs
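
A minimal usage sketch of the updated layer (not part of this commit; it assumes the class is exported as brainpy.dyn.layers.Activation and that brainpy.math provides a NumPy-like maximum — exact paths may differ):

# Sketch only: the `mode` argument now defaults to the training mode instead of None.
import brainpy as bp
import brainpy.math as bm

act = bp.dyn.layers.Activation(activate_fun=lambda x: bm.maximum(x, 0.))
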
18 changes: 10 additions & 8 deletions brainpy/dyn/layers/linear.py
@@ -9,7 +9,7 @@
from brainpy.dyn.base import DynamicalSystem
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
from brainpy.tools.checking import check_initializer
from brainpy.types import Array

@@ -201,17 +201,19 @@ class Flatten(DynamicalSystem):
mode: Mode
Enable training this node or not. (default True)
"""
def __init__(self,
name: Optional[str] = None,
mode: Optional[Mode] = batching,
):

def __init__(
self,
name: Optional[str] = None,
mode: Optional[Mode] = batching,
):
super().__init__(name, mode)

def update(self, shr, x):
if isinstance(self.mode, BatchingMode):
return x.reshape((x.shape[0], -1))
else:
return x.flatten()

def reset_state(self, batch_size=None):
pass
pass
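
A plain-NumPy sketch (not part of this commit) of what Flatten.update computes in the two modes:

# Sketch only, using plain NumPy.
import numpy as np

x = np.zeros((4, 3, 5))                   # (batch, ...) input
batched = x.reshape((x.shape[0], -1))     # BatchingMode: keep the batch axis -> (4, 15)
unbatched = x.flatten()                   # otherwise: collapse everything -> (60,)
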
32 changes: 16 additions & 16 deletions brainpy/dyn/synapses/abstract_models.py
@@ -119,7 +119,7 @@ def __init__(
self.comp_method = comp_method

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
@@ -143,10 +143,10 @@ def update(self, tdi, pre_spike=None):
# synaptic values onto the post
if isinstance(self.conn, All2All):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -160,7 +160,7 @@ def update(self, tdi, pre_spike=None):
# post_vs *= f2(stp_value)
else:
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
if self.post_ref_key:
post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

@@ -296,7 +296,7 @@ def __init__(
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')

# variables
self.g = variable_(bm.zeros, self.post.num, mode)
@@ -328,11 +328,11 @@ def update(self, tdi, pre_spike=None):
if isinstance(self.conn, All2All):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None):
else:
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
# updates
self.g.value = self.integral(self.g.value, t, dt) + post_vs

@@ -487,7 +487,7 @@ def __init__(
f'But we got {self.tau_decay}')

# connections
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.h = variable_(bm.zeros, self.pre.num, mode)
@@ -531,16 +531,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
@@ -829,7 +829,7 @@ def __init__(
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable_(bm.zeros, self.pre.num, mode)
@@ -872,16 +872,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
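
For context, a plain-NumPy sketch (not part of this commit) of the dense scalar-weight projection that the renamed _syn2post_with_dense helper performs, matching the line `post_vs = (syn_weight * syn_value) @ conn_mat` shown above:

# Sketch only, using plain NumPy.
import numpy as np

rng = np.random.default_rng(0)
syn_value = rng.random(3)                              # one value per presynaptic neuron
conn_mat = (rng.random((3, 4)) < 0.5).astype(float)    # dense connectivity, 3 pre x 4 post
post_vs = (0.2 * syn_value) @ conn_mat                 # scalar g_max = 0.2
assert post_vs.shape == (4,)
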
16 changes: 8 additions & 8 deletions brainpy/dyn/synapses/biological_models.py
@@ -181,7 +181,7 @@ def __init__(
raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')

# connection
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
@@ -226,16 +226,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
@@ -526,7 +526,7 @@ def __init__(
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
@@ -575,16 +575,16 @@ def update(self, tdi, pre_spike=None):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
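
A rough plain-NumPy equivalent (not part of this commit) of the sparse pre2post summation used in these update methods, assuming the usual gather-then-scatter-add semantics; the id arrays below are hypothetical:

# Sketch only, using plain NumPy.
import numpy as np

syn_value = np.array([0.3, 0.7, 0.1])    # one value per presynaptic neuron
post_ids = np.array([1, 2, 1])           # hypothetical coordinate lists
pre_ids = np.array([0, 0, 2])
post_num = 4

post_vs = np.zeros(post_num)
np.add.at(post_vs, post_ids, syn_value[pre_ids])   # scatter-add onto the postsynaptic targets
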
16 changes: 14 additions & 2 deletions brainpy/math/setting.py
@@ -3,13 +3,14 @@
import os
import re

from jax import dtypes, config, numpy as jnp
from jax import dtypes, config, numpy as jnp, devices
from jax.lib import xla_bridge

__all__ = [
'enable_x64',
'disable_x64',
'set_platform',
'get_platform',
'set_host_device_count',

# device memory
@@ -92,7 +93,7 @@ def disable_x64():
config.update("jax_enable_x64", False)


def set_platform(platform):
def set_platform(platform: str):
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
@@ -101,6 +102,17 @@ def set_platform(platform):
config.update("jax_platform_name", platform)


def get_platform() -> str:
"""Get the computing platform.
Returns
-------
platform: str
Either 'cpu', 'gpu' or 'tpu'.
"""
return devices()[0].platform


def set_host_device_count(n):
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA
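
A minimal usage sketch of the new helper (not part of this commit; it assumes both functions are re-exported through brainpy.math):

# Sketch only: set_platform must run before any JAX computation is dispatched.
import brainpy.math as bm

bm.set_platform('cpu')
print(bm.get_platform())   # -> 'cpu'; reads jax.devices()[0].platform
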
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -69,7 +69,7 @@
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx-mathjax-offline',
# 'sphinx-mathjax-offline',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_autodoc_typehints',
(The remaining changed files are not rendered here.)
