diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py
index 2a06928d4..6c0ef36fa 100644
--- a/brainpy/connect/base.py
+++ b/brainpy/connect/base.py
@@ -1,14 +1,13 @@
 # -*- coding: utf-8 -*-
 
 import abc
-from typing import Union, List, Tuple, Any
+from typing import Union, List, Tuple
 
 import jax.numpy as jnp
 import numpy as onp
-from jax import config
+
 from brainpy import tools, math as bm
 from brainpy.errors import ConnectorError
-from brainpy.tools.others import numba_jit, numba_range
 
 __all__ = [
   # the connection types
@@ -25,7 +24,9 @@
   'Connector', 'TwoEndConnector', 'OneEndConnector',
 
   # methods
-  'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr'
+  'mat2coo', 'mat2csc', 'mat2csr',
+  'csr2csc', 'csr2mat', 'csr2coo',
+  'coo2csr', 'coo2csc', 'coo2mat',
 ]
 
 CONN_MAT = 'conn_mat'
@@ -37,12 +38,16 @@
 POST2SYN = 'post2syn'
 PRE_SLICE = 'pre_slice'
 POST_SLICE = 'post_slice'
+COO = 'coo'
+CSR = 'csr'
+CSC = 'csc'
 
 SUPPORTED_SYN_STRUCTURE = [CONN_MAT,
                            PRE_IDS, POST_IDS,
                            PRE2POST, POST2PRE,
                            PRE2SYN, POST2SYN,
-                           PRE_SLICE, POST_SLICE]
+                           PRE_SLICE, POST_SLICE,
+                           COO, CSR, CSC]
 
 MAT_DTYPE = jnp.bool_
 IDX_DTYPE = jnp.uint32
@@ -100,7 +105,7 @@ class TwoEndConnector(Connector):
 
   1. Implementing ``build_conn(self)`` function, which returns one of
      the connection data ``csr`` (CSR sparse data, a tuple of ``(indices, indptr)``),
-     ``ij`` (COO sparse data, a tuple of ``(pre_ids, post_ids)``), or ``mat``
+     ``coo`` (COO sparse data, a tuple of ``(pre_ids, post_ids)``), or ``mat``
      (a binary connection matrix). For instance,
 
   .. code-block:: python
 
@@ -108,7 +113,7 @@
      import brainpy as bp
      class MyConnector(bp.conn.TwoEndConnector):
        def build_conn(self):
-          return dict(csr=..., mat=..., ij=...)
+          return dict(csr=..., mat=..., coo=...)
 
   2. Implementing functions ``build_mat()``, ``build_csr()``, and ``build_coo()``.
      Users can provide all three functions, or one of them.
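+
+     For instance, a connector that only implements ``build_coo()`` (an
+     illustrative sketch: the class name and the chain-like wiring are
+     made up here; only the ``build_coo()`` interface comes from this
+     class):
+
+     .. code-block:: python
+
+        import numpy as np
+        import brainpy as bp
+
+        class ChainConn(bp.conn.TwoEndConnector):
+          def build_coo(self):
+            # connect neuron i to neuron i + 1 (assumes pre_num == post_num)
+            pre_ids = np.arange(self.pre_num - 1)
+            post_ids = np.arange(1, self.pre_num)
+            return pre_ids, post_ids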
@@ -202,13 +207,10 @@ def _return_by_mat(self, structures, mat, all_data: dict): if (CONN_MAT in structures) and (CONN_MAT not in all_data): all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE) - require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0 - if require_other_structs: - np = onp if isinstance(mat, onp.ndarray) else bm - pre_ids, post_ids = np.where(mat > 0) - pre_ids = np.asarray(pre_ids, dtype=IDX_DTYPE) - post_ids = np.asarray(post_ids, dtype=IDX_DTYPE) - self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data) + if len([s for s in structures + if s not in [CONN_MAT]]) > 0: + ij = mat2coo(mat) + self._return_by_coo(structures, coo=ij, all_data=all_data) def _return_by_csr(self, structures, csr: tuple, all_data: dict): indices, indptr = csr @@ -226,15 +228,29 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict): if (POST_IDS in structures) and (POST_IDS not in all_data): all_data[POST_IDS] = bm.asarray(indices, dtype=IDX_DTYPE) + if (COO in structures) and (COO not in all_data): + pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) + all_data[COO] = (bm.asarray(pre_ids, dtype=IDX_DTYPE), + bm.asarray(indices, dtype=IDX_DTYPE)) + if (PRE2POST in structures) and (PRE2POST not in all_data): all_data[PRE2POST] = (bm.asarray(indices, dtype=IDX_DTYPE), bm.asarray(indptr, dtype=IDX_DTYPE)) + if (CSR in structures) and (CSR not in all_data): + all_data[CSR] = (bm.asarray(indices, dtype=IDX_DTYPE), + bm.asarray(indptr, dtype=IDX_DTYPE)) + if (POST2PRE in structures) and (POST2PRE not in all_data): indc, indptrc = csr2csc((indices, indptr), self.post_num) all_data[POST2PRE] = (bm.asarray(indc, dtype=IDX_DTYPE), bm.asarray(indptrc, dtype=IDX_DTYPE)) + if (CSC in structures) and (CSC not in all_data): + indc, indptrc = csr2csc((indices, indptr), self.post_num) + all_data[CSC] = (bm.asarray(indc, dtype=IDX_DTYPE), + bm.asarray(indptrc, dtype=IDX_DTYPE)) + if (PRE2SYN in structures) and (PRE2SYN not in all_data): syn_seq = np.arange(indices.size, dtype=IDX_DTYPE) all_data[PRE2SYN] = (bm.asarray(syn_seq, dtype=IDX_DTYPE), @@ -246,11 +262,11 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict): all_data[POST2SYN] = (bm.asarray(syn_seqc, dtype=IDX_DTYPE), bm.asarray(indptrc, dtype=IDX_DTYPE)) - def _return_by_ij(self, structures, ij: tuple, all_data: dict): - pre_ids, post_ids = ij + def _return_by_coo(self, structures, coo: tuple, all_data: dict): + pre_ids, post_ids = coo if (CONN_MAT in structures) and (CONN_MAT not in all_data): - all_data[CONN_MAT] = bm.asarray(ij2mat(ij, self.pre_num, self.post_num), dtype=MAT_DTYPE) + all_data[CONN_MAT] = bm.asarray(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE) if (PRE_IDS in structures) and (PRE_IDS not in all_data): all_data[PRE_IDS] = bm.asarray(pre_ids, dtype=IDX_DTYPE) @@ -258,10 +274,24 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict): if (POST_IDS in structures) and (POST_IDS not in all_data): all_data[POST_IDS] = bm.asarray(post_ids, dtype=IDX_DTYPE) - require_other_structs = len([s for s in structures - if s not in [CONN_MAT, PRE_IDS, POST_IDS]]) > 0 - if require_other_structs: - csr = ij2csr(pre_ids, post_ids, self.pre_num) + if (COO in structures) and (COO not in all_data): + all_data[COO] = (bm.asarray(pre_ids, dtype=IDX_DTYPE), + bm.asarray(post_ids, dtype=IDX_DTYPE)) + + if CSC in structures and CSC not in all_data: + csc = coo2csc(coo, self.post_num) + all_data[CSC] = (bm.asarray(csc[0], dtype=IDX_DTYPE), + bm.asarray(csc[1], 
                              dtype=IDX_DTYPE))
+
+      if POST2PRE in structures and POST2PRE not in all_data:
+        csc = coo2csc(coo, self.post_num)
+        all_data[POST2PRE] = (bm.asarray(csc[0], dtype=IDX_DTYPE),
+                              bm.asarray(csc[1], dtype=IDX_DTYPE))
+
+      if (len([s for s in structures
+               if s not in [CONN_MAT, PRE_IDS, POST_IDS,
+                            COO, CSC, POST2PRE]]) > 0):
+        csr = coo2csr(coo, self.pre_num)
         self._return_by_csr(structures, csr=csr, all_data=all_data)
 
   def _make_returns(self, structures, conn_data):
@@ -269,30 +299,30 @@
     """
     csr = None
     mat = None
-    ij = None
+    coo = None
     if isinstance(conn_data, dict):
       csr = conn_data.get('csr', None)
       mat = conn_data.get('mat', None)
-      ij = conn_data.get('ij', None)
+      coo = conn_data.get('coo', None)
     elif isinstance(conn_data, tuple):
       if conn_data[0] == 'csr':
         csr = conn_data[1]
       elif conn_data[0] == 'mat':
         mat = conn_data[1]
-      elif conn_data[0] == 'ij':
-        ij = conn_data[1]
+      elif conn_data[0] == 'coo':
+        coo = conn_data[1]
       else:
-        raise ConnectorError(f'Must provide one of "csr", "mat" or "ij". Got "{conn_data[0]}" instead.')
+        raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". Got "{conn_data[0]}" instead.')
     else:
-      raise ConnectorError
+      raise ConnectorError(f'Unknown connection data type: {type(conn_data)}')
 
     # checking
-    all_data = dict()
-    if (csr is None) and (mat is None) and (ij is None):
-      raise ConnectorError('Must provide one of "csr", "mat" or "ij".')
+    if (csr is None) and (mat is None) and (coo is None):
+      raise ConnectorError('Must provide one of "csr", "mat" or "coo".')
     structures = (structures,) if isinstance(structures, str) else structures
     assert isinstance(structures, (tuple, list))
 
+    all_data = dict()
     # "csr" structure
     if csr is not None:
       if (PRE2POST in structures) and (PRE2POST not in all_data):
@@ -307,13 +337,13 @@
         all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)
       self._return_by_mat(structures, mat=mat, all_data=all_data)
 
-    # "ij" structure
-    if ij is not None:
-      if (PRE_IDS in structures) and (PRE_IDS not in structures):
-        all_data[PRE_IDS] = bm.asarray(ij[0], dtype=IDX_DTYPE)
-      if (POST_IDS in structures) and (POST_IDS not in structures):
-        all_data[POST_IDS] = bm.asarray(ij[1], dtype=IDX_DTYPE)
-      self._return_by_ij(structures, ij=ij, all_data=all_data)
+    # "coo" structure
+    if coo is not None:
+      if (PRE_IDS in structures) and (PRE_IDS not in all_data):
+        all_data[PRE_IDS] = bm.asarray(coo[0], dtype=IDX_DTYPE)
+      if (POST_IDS in structures) and (POST_IDS not in all_data):
+        all_data[POST_IDS] = bm.asarray(coo[1], dtype=IDX_DTYPE)
+      self._return_by_coo(structures, coo=coo, all_data=all_data)
 
     # return
     if len(structures) == 1:
@@ -349,36 +379,91 @@ def require(self, *structures):
       else:
         return tuple()
 
-    try:
-      assert self.pre_num is not None and self.post_num is not None
-    except AssertionError:
+    if self.pre_num is None or self.post_num is None:
       raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
                            f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ')
 
+    _has_coo_imp = not hasattr(self.build_coo, 'not_customized')
+    _has_csr_imp = not hasattr(self.build_csr, 'not_customized')
+    _has_mat_imp = not hasattr(self.build_mat, 'not_customized')
+
     self._check(structures)
-    if self.is_version2_style:
+    if (_has_coo_imp or _has_csr_imp or _has_mat_imp):
       if len(structures) == 1:
-        if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
+        if PRE2POST in structures and _has_csr_imp:
           r = self.build_csr()
           return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
-        elif CONN_MAT in structures and not hasattr(self.build_mat, 'not_customized'):
+        elif CSR in structures and _has_csr_imp:
+          r = self.build_csr()
+          return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
+        elif CONN_MAT in structures and _has_mat_imp:
           return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
-        elif PRE_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
+        elif PRE_IDS in structures and _has_coo_imp:
          return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
-        elif POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
+        elif POST_IDS in structures and _has_coo_imp:
           return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
+        elif COO in structures and _has_coo_imp:
+          r = self.build_coo()
+          return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
+
       elif len(structures) == 2:
-        if PRE_IDS in structures and POST_IDS in structures and not hasattr(self.build_coo, 'not_customized'):
+        if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp):
           r = self.build_coo()
-          return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
+          if structures[0] == PRE_IDS:
+            return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
+          else:
+            return bm.asarray(r[1], dtype=IDX_DTYPE), bm.asarray(r[0], dtype=IDX_DTYPE)
+
+        if ((CSR in structures or PRE2POST in structures)
+            and _has_csr_imp and COO in structures and _has_coo_imp):
+          csr = self.build_csr()
+          csr = (bm.asarray(csr[0], dtype=IDX_DTYPE), bm.asarray(csr[1], dtype=IDX_DTYPE))
+          coo = self.build_coo()
+          coo = (bm.asarray(coo[0], dtype=IDX_DTYPE), bm.asarray(coo[1], dtype=IDX_DTYPE))
+          if structures[0] == COO:
+            return coo, csr
+          else:
+            return csr, coo
+
+        if ((CSR in structures or PRE2POST in structures)
+            and _has_csr_imp and CONN_MAT in structures and _has_mat_imp):
+          csr = self.build_csr()
+          csr = (bm.asarray(csr[0], dtype=IDX_DTYPE), bm.asarray(csr[1], dtype=IDX_DTYPE))
+          mat = bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
+          if structures[0] == CONN_MAT:
+            return mat, csr
+          else:
+            return csr, mat
+
+        if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp):
+          coo = self.build_coo()
+          coo = (bm.asarray(coo[0], dtype=IDX_DTYPE), bm.asarray(coo[1], dtype=IDX_DTYPE))
+          mat = bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
+          if structures[0] == COO:
+            return coo, mat
+          else:
+            return mat, coo
 
-      conn_data = dict(csr=None, ij=None, mat=None)
-      if not hasattr(self.build_coo, 'not_customized'):
-        conn_data['ij'] = self.build_coo()
-      elif not hasattr(self.build_csr, 'not_customized'):
+      conn_data = dict(csr=None, coo=None, mat=None)
+      if _has_coo_imp:
+        conn_data['coo'] = self.build_coo()
+        # if (CSR in structures or PRE2POST in structures) and _has_csr_imp:
+        #     conn_data['csr'] = self.build_csr()
+        # if CONN_MAT in structures and _has_mat_imp:
+        #     conn_data['mat'] = self.build_mat()
+      elif _has_csr_imp:
         conn_data['csr'] = self.build_csr()
-      elif not hasattr(self.build_mat, 'not_customized'):
+        # if COO in structures and _has_coo_imp:
+        #     conn_data['coo'] = self.build_coo()
+        # if CONN_MAT in structures and _has_mat_imp:
+        #     conn_data['mat'] = self.build_mat()
+      elif _has_mat_imp:
         conn_data['mat'] = self.build_mat()
+        # if COO in structures and _has_coo_imp:
+        #     conn_data['coo'] = self.build_coo()
+        # if (CSR in structures or PRE2POST in structures) and _has_csr_imp:
+        #     conn_data['csr'] = self.build_csr()
+      else:
+        raise ValueError('No "build_mat()", "build_csr()" or "build_coo()" is implemented.')
     else:
       conn_data = self.build_conn()
 
@@ -405,8 +490,8 @@ def build_conn(self):
     conn: tuple, dict
       A tuple with two elements: connection type (str) and connection data.
         For example: ``return 'csr', (ind, indptr)``
-      Or a dict with three elements: csr, mat and ij. For example:
-        ``return dict(csr=(ind, indptr), mat=None, ij=None)``
+      Or a dict with three elements: csr, mat and coo. For example:
+        ``return dict(csr=(ind, indptr), mat=None, coo=None)``
     """
     pass
 
@@ -488,48 +573,23 @@ def _reset_conn(self, pre_size, post_size=None):
     self.__call__(pre_size, post_size)
 
 
-def csr2csc(csr, post_num, data=None):
-  """Convert csr to csc."""
-  indices, indptr = csr
-  np = onp if isinstance(indices, onp.ndarray) else bm
-  # kind = 'quicksort' if isinstance(indices, jnp.ndarray) else 'stable'
-
-  pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
-
-  sort_ids = np.argsort(indices)  # to maintain the original order of the elements with the same value
-  if isinstance(sort_ids, bm.JaxArray):
-    sort_ids = sort_ids.value
-  pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
-
-  unique_post_ids, count = np.unique(indices, return_counts=True)
-  post_count = np.zeros(post_num, dtype=IDX_DTYPE)
-  post_count[unique_post_ids] = count
-
-  indptr_new = post_count.cumsum()
-  indptr_new = np.insert(indptr_new, 0, 0)
-  indptr_new = np.asarray(indptr_new, dtype=IDX_DTYPE)
-
-  if data is None:
-    return pre_ids_new, indptr_new
-  else:
-    data_new = data[sort_ids]
-    return pre_ids_new, indptr_new, data_new
-
-
 def mat2csr(dense):
   """convert a dense matrix to (indices, indptr)."""
   np = onp if isinstance(dense, onp.ndarray) else bm
+  pre_ids, post_ids = np.where(dense > 0)
+  return coo2csr((pre_ids, post_ids), dense.shape[0])
+
+
+def mat2coo(dense):
+  """convert a dense matrix to (pre_ids, post_ids)."""
+  np = onp if isinstance(dense, onp.ndarray) else bm
   pre_ids, post_ids = np.where(dense > 0)
-  pre_num = dense.shape[0]
+  return np.asarray(pre_ids, dtype=IDX_DTYPE), np.asarray(post_ids, dtype=IDX_DTYPE)
 
-  uni_idx, count = np.unique(pre_ids, return_counts=True)
-  pre_count = np.zeros(pre_num, dtype=IDX_DTYPE)
-  pre_count[uni_idx] = count
-  indptr = count.cumsum()
-  indptr = np.insert(indptr, 0, 0)
-  return np.asarray(post_ids, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
+
+def mat2csc(dense):
+  """convert a dense matrix to CSC-ordered (indices, indptr)."""
+  np = onp if isinstance(dense, onp.ndarray) else bm
+  pre_ids, post_ids = np.where(dense > 0)
+  return coo2csr((post_ids, pre_ids), dense.shape[1])
 
 
 def csr2mat(csr, num_pre, num_post):
@@ -542,7 +602,19 @@
   return d
 
 
-def ij2mat(ij, num_pre, num_post):
+def csr2csc(csr, post_num, data=None):
+  """Convert csr to csc."""
+  return coo2csc(csr2coo(csr), post_num, data)
+
+
+def csr2coo(csr):
+  """convert (indices, indptr) to (pre_ids, post_ids)."""
+  np = onp if isinstance(csr[0], onp.ndarray) else bm
+  indices, indptr = csr
+  pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
+  return pre_ids, indices
+
+
+def coo2mat(ij, num_pre, num_post):
-  """convert (indices, indptr) to a dense matrix."""
+  """convert (pre_ids, post_ids) to a dense matrix."""
   pre_ids, post_ids = ij
   np = onp if isinstance(pre_ids, onp.ndarray) else bm
@@ -551,49 +623,42 @@
   return d
 
 
-def ij2csr(pre_ids, post_ids, num_pre):
+def coo2csr(coo, num_pre):
-  """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
+  """convert a COO connection (pre_ids, post_ids) to CSR (indices, indptr)."""
-  if isinstance(pre_ids, onp.ndarray):
-    return _cpu_ij2csr(pre_ids, post_ids, num_pre)
-  elif isinstance(pre_ids, (jnp.ndarray, bm.ndarray)):
-    if pre_ids.device().platform == 'cpu':
-      return _cpu_ij2csr(pre_ids, post_ids, num_pre)
-    else:
-      return _gpu_ij2csr(pre_ids, post_ids, num_pre)
-  else:
-    raise TypeError
-
+  pre_ids, post_ids = coo
+  np = onp if isinstance(pre_ids, onp.ndarray) else bm
 
-def _gpu_ij2csr(pre_ids, post_ids, num_pre):
-  """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
-  sort_ids = bm.argsort(pre_ids)
+  sort_ids = np.argsort(pre_ids)
+  post_ids = np.asarray(post_ids)
   post_ids = post_ids[sort_ids]
   indices = post_ids
-  unique_pre_ids, pre_count = bm.unique(pre_ids, return_counts=True)
-  final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32)
-  final_pre_count[unique_pre_ids] = pre_count
-  indptr = final_pre_count.cumsum()
-  indptr = bm.insert(indptr, 0, 0)
-  return bm.asarray(indices, dtype=IDX_DTYPE), bm.asarray(indptr, dtype=IDX_DTYPE)
-
-
-def _cpu_ij2csr(pre_ids, post_ids, num_pre):
-  """convert pre_ids, post_ids to (indices, indptr). and use numba for sort function when'jax_platform_name' = 'cpu'"""
-  np = onp if isinstance(pre_ids, onp.ndarray) else bm
   unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
-  final_pre_count = np.zeros(num_pre, dtype=np.uint32)
+  final_pre_count = np.zeros(num_pre, dtype=IDX_DTYPE)
   final_pre_count[unique_pre_ids] = pre_count
   indptr = final_pre_count.cumsum()
   indptr = np.insert(indptr, 0, 0)
+  return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
 
-  @numba_jit(parallel=True, nogil=True)
-  def single_sort(pre_ids, post_ids, indptr):
-    pre_tmp = indptr.copy()
-    indices = onp.zeros((indptr[-1],))
-    for i in numba_range(indptr[-1]):
-      indices[pre_tmp[pre_ids[i]]] = post_ids[i]
-      pre_tmp[pre_ids[i]] += 1
-    return indices
-  indices = single_sort(bm.as_numpy(pre_ids), bm.as_numpy(post_ids), bm.as_numpy(indptr))
-  return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
+
+def coo2csc(coo, post_num, data=None):
+  """Convert coo to csc."""
+  pre_ids, indices = coo
+  np = onp if isinstance(indices, onp.ndarray) else bm
+
+  # to maintain the original order of the elements with the same value
+  sort_ids = np.argsort(indices)
+  pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
+
+  unique_post_ids, count = np.unique(indices, return_counts=True)
+  post_count = np.zeros(post_num, dtype=IDX_DTYPE)
+  post_count[unique_post_ids] = count
+
+  indptr_new = post_count.cumsum()
+  indptr_new = np.insert(indptr_new, 0, 0)
+  indptr_new = np.asarray(indptr_new, dtype=IDX_DTYPE)
+
+  if data is None:
+    return pre_ids_new, indptr_new
+  else:
+    data_new = data[sort_ids]
+    return pre_ids_new, indptr_new, data_new
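A quick round-trip check of the conversion helpers above (a sketch: it
assumes the helpers are re-exported from ``brainpy.connect``, as their
listing in ``__all__`` suggests):

.. code-block:: python

   import numpy as onp
   from brainpy.connect import mat2coo, coo2csr, csr2mat

   dense = onp.array([[0, 1, 0],
                      [1, 0, 1]])
   coo = mat2coo(dense)                # pre_ids=[0, 1, 1], post_ids=[1, 0, 2]
   csr = coo2csr(coo, dense.shape[0])  # indices=[1, 0, 2], indptr=[0, 1, 3]
   assert (csr2mat(csr, 2, 3) == (dense > 0)).all()
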
diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py
index e452061e5..69c3ab879 100644
--- a/brainpy/connect/custom_conn.py
+++ b/brainpy/connect/custom_conn.py
@@ -81,8 +81,8 @@ class CSRConn(TwoEndConnector):
 
   def __init__(self, indices, inptr):
     super(CSRConn, self).__init__()
 
-    self.indices = bm.asarray(indices).astype(IDX_DTYPE)
-    self.inptr = bm.asarray(inptr).astype(IDX_DTYPE)
+    self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
+    self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
 
     self.pre_num = self.inptr.size - 1
     self.max_post = bm.max(self.indices)
@@ -110,3 +110,5 @@ def __init__(self, csr_mat):
     self.csr_mat = csr_mat
     super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
                                         inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
+    self.pre_num = csr_mat.shape[0]
+    self.post_num = csr_mat.shape[1]
diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py
index cb8103162..a05a29634 100644
--- a/brainpy/connect/regular_conn.py
+++ b/brainpy/connect/regular_conn.py
@@ -46,7 +46,7 @@ def build_csr(self):
                            f'same size, but {self.pre_num} != {self.post_num}.')
     ind = np.arange(self.pre_num)
     indptr = np.arange(self.pre_num + 1)
-    return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE),
+    return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))
 
   def build_mat(self, pre_size=None, post_size=None):
     if self.pre_num != self.post_num:
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index afc352173..f77259646 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -988,7 +988,7 @@ def __init__(
       ltp.register_master(master=self)
     self.ltp: SynLTP = ltp
 
-  def init_weights(
+  def _init_weights(
       self,
       weight: Union[float, Array, Initializer, Callable],
       comp_method: str,
@@ -996,7 +996,7 @@ def init_weights(
   ) -> Union[float, Array]:
     if comp_method not in ['sparse', 'dense']:
       raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
-    if sparse_data not in ['csr', 'ij']:
-      raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
+    if sparse_data not in ['csr', 'ij', 'coo']:
+      raise ValueError(f'"sparse_data" must be one of "csr", "ij", "coo", but we got {sparse_data}')
     if self.conn is None:
       raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
@@ -1014,11 +1014,11 @@ def init_weights(
     if comp_method == 'sparse':
       if sparse_data == 'csr':
         conn_mask = self.conn.require('pre2post')
-      elif sparse_data == 'ij':
+      elif sparse_data in ['ij', 'coo']:
         conn_mask = self.conn.require('post_ids', 'pre_ids')
       else:
-        ValueError(f'Unknown sparse data type: {sparse_data}')
+        raise ValueError(f'Unknown sparse data type: {sparse_data}')
-      weight = parameter(weight, conn_mask[1].shape, allow_none=False)
+      weight = parameter(weight, conn_mask[0].shape, allow_none=False)
     elif comp_method == 'dense':
       weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
       conn_mask = self.conn.require('conn_mat')
@@ -1030,7 +1030,7 @@ def init_weights(
       weight = bm.TrainVar(weight)
     return weight, conn_mask
 
-  def syn2post_with_all2all(self, syn_value, syn_weight):
+  def _syn2post_with_all2all(self, syn_value, syn_weight):
     if bm.ndim(syn_weight) == 0:
       if isinstance(self.mode, BatchingMode):
         post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
@@ -1043,10 +1043,10 @@ def syn2post_with_all2all(self, syn_value, syn_weight):
       post_vs = syn_value @ syn_weight
     return post_vs
 
-  def syn2post_with_one2one(self, syn_value, syn_weight):
+  def _syn2post_with_one2one(self, syn_value, syn_weight):
     return syn_value * syn_weight
 
-  def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
+  def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
     if bm.ndim(syn_weight) == 0:
       post_vs = (syn_weight * syn_value) @ conn_mat
     else:
diff --git a/brainpy/dyn/layers/activate.py b/brainpy/dyn/layers/activate.py
index d884541e2..936bb3695 100644
--- a/brainpy/dyn/layers/activate.py
+++ b/brainpy/dyn/layers/activate.py
@@ -1,15 +1,16 @@
-from brainpy.dyn.base import DynamicalSystem
-from typing import Optional
-from brainpy.modes import Mode
 from typing import Callable
+from typing import Optional
+
+from brainpy.dyn.base import DynamicalSystem
+from 
brainpy.modes import Mode, training class Activation(DynamicalSystem): - r"""Applies a activation to the inputs + r"""Applies an activation function to the inputs Parameters: ---------- - activate_fun: Callable + activate_fun: Callable, function The function of Activation name: str, Optional The name of the object @@ -17,12 +18,13 @@ class Activation(DynamicalSystem): Enable training this node or not. (default True). """ - def __init__(self, - activate_fun: Callable, - name: Optional[str] = None, - mode: Optional[Mode] = None, - **kwargs, - ): + def __init__( + self, + activate_fun: Callable, + name: Optional[str] = None, + mode: Mode = training, + **kwargs, + ): super().__init__(name, mode) self.activate_fun = activate_fun self.kwargs = kwargs diff --git a/brainpy/dyn/layers/linear.py b/brainpy/dyn/layers/linear.py index 1fccd718d..4e6af4ee9 100644 --- a/brainpy/dyn/layers/linear.py +++ b/brainpy/dyn/layers/linear.py @@ -9,7 +9,7 @@ from brainpy.dyn.base import DynamicalSystem from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter -from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching +from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching from brainpy.tools.checking import check_initializer from brainpy.types import Array @@ -201,17 +201,19 @@ class Flatten(DynamicalSystem): mode: Mode Enable training this node or not. (default True) """ - def __init__(self, - name: Optional[str] = None, - mode: Optional[Mode] = batching, - ): + + def __init__( + self, + name: Optional[str] = None, + mode: Optional[Mode] = batching, + ): super().__init__(name, mode) - + def update(self, shr, x): if isinstance(self.mode, BatchingMode): return x.reshape((x.shape[0], -1)) else: return x.flatten() - + def reset_state(self, batch_size=None): - pass \ No newline at end of file + pass diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 54785b02c..a74880cb8 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -119,7 +119,7 @@ def __init__( self.comp_method = comp_method # connections and weights - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr') # register delay self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) @@ -143,10 +143,10 @@ def update(self, tdi, pre_spike=None): # synaptic values onto the post if isinstance(self.conn, All2All): syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) @@ -160,7 +160,7 @@ def update(self, tdi, pre_spike=None): # post_vs *= f2(stp_value) else: syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype())) - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) if self.post_ref_key: post_vs = post_vs * (1. 
- getattr(self.post, self.post_ref_key)) @@ -296,7 +296,7 @@ def __init__( raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}') # connections and weights - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr') # variables self.g = variable_(bm.zeros, self.post.num, mode) @@ -328,11 +328,11 @@ def update(self, tdi, pre_spike=None): if isinstance(self.conn, All2All): syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) if self.stp is not None: syn_value = self.stp(syn_value) - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) if self.stp is not None: syn_value = self.stp(syn_value) - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) @@ -343,7 +343,7 @@ def update(self, tdi, pre_spike=None): else: syn_value = bm.asarray(pre_spike, dtype=bm.dftype()) if self.stp is not None: syn_value = self.stp(syn_value) - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates self.g.value = self.integral(self.g.value, t, dt) + post_vs @@ -487,7 +487,7 @@ def __init__( f'But we got {self.tau_decay}') # connections - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables self.h = variable_(bm.zeros, self.pre.num, mode) @@ -531,16 +531,16 @@ def update(self, tdi, pre_spike=None): syn_value = self.g.value if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # output return self.output(post_vs) @@ -829,7 +829,7 @@ def __init__( self.stop_spike_gradient = stop_spike_gradient # connections and weights - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables self.g = variable_(bm.zeros, self.pre.num, mode) @@ -872,16 +872,16 @@ def update(self, tdi, pre_spike=None): syn_value = self.g.value if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 
'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # output return self.output(post_vs) diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index a6db1fb7a..fa9c7c1e7 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -181,7 +181,7 @@ def __init__( raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}') # connection - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables self.g = variable(bm.zeros, mode, self.pre.num) @@ -226,16 +226,16 @@ def update(self, tdi, pre_spike=None): syn_value = self.g.value if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # output return self.output(post_vs) @@ -526,7 +526,7 @@ def __init__( self.stop_spike_gradient = stop_spike_gradient # connections and weights - self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables self.g = variable(bm.zeros, mode, self.pre.num) @@ -575,16 +575,16 @@ def update(self, tdi, pre_spike=None): syn_value = self.g.value if self.stp is not None: syn_value = self.stp(syn_value) if isinstance(self.conn, All2All): - post_vs = self.syn2post_with_all2all(syn_value, self.g_max) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - post_vs = self.syn2post_with_one2one(syn_value, self.g_max) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: - post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # output return self.output(post_vs) diff --git a/brainpy/math/setting.py b/brainpy/math/setting.py index 069c09822..93b84d0c6 100644 --- a/brainpy/math/setting.py +++ b/brainpy/math/setting.py @@ -3,13 +3,14 @@ import os import re -from jax import dtypes, config, numpy as jnp +from jax import dtypes, config, numpy as jnp, devices from jax.lib import xla_bridge __all__ = [ 'enable_x64', 'disable_x64', 'set_platform', + 'get_platform', 'set_host_device_count', # device memory @@ -92,7 +93,7 @@ def disable_x64(): config.update("jax_enable_x64", False) -def set_platform(platform): +def 
set_platform(platform: str):
   """
   Changes platform to CPU, GPU, or TPU. This utility only takes
   effect at the beginning of your program.
@@ -101,6 +102,17 @@
   config.update("jax_platform_name", platform)
 
 
+def get_platform() -> str:
+  """Get the computing platform.
+
+  Returns
+  -------
+  platform: str
+    Either 'cpu', 'gpu' or 'tpu'.
+  """
+  return devices()[0].platform
+
+
 def set_host_device_count(n):
   """
   By default, XLA considers all CPU cores as one device. This utility tells XLA
diff --git a/docs/conf.py b/docs/conf.py
index 44ceb52b7..c15ea2e0c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -69,7 +69,7 @@
   'sphinx.ext.autosummary',
   'sphinx.ext.intersphinx',
   'sphinx.ext.mathjax',
-  'sphinx-mathjax-offline',
+  # 'sphinx-mathjax-offline',
   'sphinx.ext.napoleon',
   'sphinx.ext.viewcode',
   'sphinx_autodoc_typehints',
diff --git a/docs/index.rst b/docs/index.rst
index cea550d25..73ed383fd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -86,8 +86,14 @@ The code of BrainPy is open-sourced at GitHub:
    tutorial_advanced/adavanced_lowdim_analysis.ipynb
    tutorial_advanced/operator_customization.ipynb
    tutorial_advanced/interoperation.ipynb
-   tutorial_advanced/compile_brainpylib
-   tutorial_advanced/citing_and_publication
+
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Frequently Asked Questions
+
+   tutorial_FAQs/citing_and_publication
 
 
 .. toctree::
diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst
index fb220b731..27de4b254 100644
--- a/docs/quickstart/installation.rst
+++ b/docs/quickstart/installation.rst
@@ -13,7 +13,7 @@ GNU/Linux, and OSX. It only relies on Python libraries.
 Installation with pip
 ---------------------
 
-You can install ``BrainPy`` from the `pypi <https://pypi.org/project/brainpy/>`_.
+You can install ``BrainPy`` from the `pypi <https://pypi.org/project/brainpy/>`_. To do so, use:
 
 .. code-block:: bash
 
@@ -35,18 +35,11 @@ of BrainPy, you can use:
 
     pip install --pre brainpy
 
-To install ``brainpylib`` (needed in dedicated operators), you can use:
-
-.. code-block:: bash
-
-    pip install brainpylib
-
 
 Installation from source
 ------------------------
 
-If you decide not to use ``conda`` or ``pip``, you can install ``BrainPy`` from
+If you decide not to use ``pip``, you can install ``BrainPy`` from
 `GitHub <https://github.com/PKU-NIP-Lab/BrainPy>`_,
 or `OpenI <https://git.openi.org.cn/OpenI/BrainPy>`_.
 
@@ -174,7 +167,7 @@ Many customized operators in BrainPy are implemented in
 ``brainpylib``.
 
 For GPU operators, you should compile ``brainpylib`` from source.
 The details please see
-`Compile GPU operators in brainpylib <../tutorial_advanced/compile_brainpylib.html>`_.
+`Compile GPU operators in brainpylib `_.
 
 
 Other Dependency
@@ -214,3 +207,4 @@ packages:
 .. _Numba: https://numba.pydata.org/
 .. _CUDA: https://developer.nvidia.com/cuda-downloads
 .. _CuDNN: https://developer.nvidia.com/CUDNN
+
diff --git a/docs/tutorial_advanced/citing_and_publication.rst b/docs/tutorial_FAQs/citing_and_publication.rst
similarity index 100%
rename from docs/tutorial_advanced/citing_and_publication.rst
rename to docs/tutorial_FAQs/citing_and_publication.rst
diff --git a/docs/tutorial_advanced/compile_brainpylib.rst b/docs/tutorial_advanced/compile_brainpylib.rst
deleted file mode 100644
index 0795b8901..000000000
--- a/docs/tutorial_advanced/compile_brainpylib.rst
+++ /dev/null
@@ -1,47 +0,0 @@
-Compile GPU operators in brainpylib
-===================================
-
-``brainpylib`` is designed to provide dedicated operators for sparse
-and event-based synaptic computation.
-We have already published CPU version of ``brainpylib`` on Pypi and users can install the CPU version by following instructions: - -.. code-block:: bash - - pip install brainpylib - -This section aims to introduce how to build up and install the GPU version. We currently did not provide GPU wheel on Pypi -and users need to build ``brainpylib`` from source. There are some prerequisites first: - -- Linux platform. -- Nvidia GPU series required. -- CUDA and cuDNN have installed. - -We have tested whole building process on Nvidia RTX A6000 GPU with CUDA 11.6 version. - -Building ``brainpylib`` GPU version ------------------------- - -First, obtain the BrainPy source code: - -.. code-block:: bash - - git clone https://github.com/PKU-NIP-Lab/BrainPy.git - cd BrainPy/extensions - -In ``extensions`` directory, users can compile GPU wheel: - -.. code-block:: bash - - python setup_cuda.py bdist_wheel - -After compilation, it's convenient for users to install the package through following instructions: - -.. code-block:: bash - - pip install dist/brainpylib-*.whl - -``brainpylib-*.whl`` is the generated file from compilation, which is located in ``dist`` folder. - -Now users have successfully install GPU version of ``brainpylib``, and we recommend users to check if ``brainpylib`` can -be imported in the Python script. - diff --git a/requirements-dev.txt b/requirements-dev.txt index 43aaf37cf..8f4eaecfe 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,11 +1,11 @@ numpy>=1.15 -jax>=0.3.0 tqdm numba matplotlib>=3.4 +jax>=0.3.0 jaxlib>=0.3.0 scipy>=1.1.0 -brainpylib>=0.0.5 +brainpylib h5py requests pillow diff --git a/requirements-doc.txt b/requirements-doc.txt index 68f8f318e..1afa63395 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -3,7 +3,6 @@ matplotlib>=3.4 jaxlib>=0.3.0 scipy>=1.1.0 -brainpylib>=0.0.5 numba requests pillow @@ -16,4 +15,4 @@ myst-nb sphinx_book_theme sphinx-autodoc-typehints sphinx_thebe -sphinx-mathjax-offline +# sphinx-mathjax-offline \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 74075aaf2..a2555ecf7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.15 jax>=0.3.0 -tqdm \ No newline at end of file +tqdm +brainpylib \ No newline at end of file diff --git a/setup.py b/setup.py index 42d7be58f..2433098bd 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.7', - install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm'], + install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm', 'brainpylib'], url='https://github.com/PKU-NIP-Lab/BrainPy', project_urls={ "Bug Tracker": "https://github.com/PKU-NIP-Lab/BrainPy/issues",
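
A usage sketch of the reworked ``require()`` interface (illustrative: the
connector choice and the sizes are arbitrary; ``'csr'`` and ``'coo'`` are the
structure names added in this patch):

.. code-block:: python

   import brainpy as bp

   conn = bp.conn.FixedProb(prob=0.1)(pre_size=100, post_size=80)
   indices, indptr = conn.require('csr')    # CSR: (indices, indptr)
   pre_ids, post_ids = conn.require('coo')  # COO: (pre_ids, post_ids)
   mat = conn.require('conn_mat')           # dense boolean matrix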