Skip to content

Commit

Permalink
update installation setup (#269)
Browse files Browse the repository at this point in the history
update installation setup
  • Loading branch information
chaoming0625 authored Oct 5, 2022
2 parents 174c81a + 9e5fe03 commit 62a1220
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 198 deletions.
31 changes: 26 additions & 5 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
# -*- coding: utf-8 -*-

__version__ = "2.2.3"
__version__ = "2.2.3.1"


try:
import jaxlib
del jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Please install jaxlib. See '
'https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax '
'for installation instructions.'
) from None
'''
BrainPy needs jaxlib, please install jaxlib.
1. If you are using Windows system, install jaxlib through
>>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
2. If you are using macOS platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
3. If you are using Linux platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
4. If you are using Linux + CUDA platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note that the versions of "jax" and "jaxlib" should be consistent, like "jax==0.3.14", "jaxlib==0.3.14".
For more detailed installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
''') from None


# fundamental modules
Expand Down
2 changes: 1 addition & 1 deletion brainpy/dyn/layers/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import XavierNormal, ZeroInit, parameter
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.modes import Mode, TrainingMode, training

__all__ = [
'GeneralConv',
Expand Down
2 changes: 1 addition & 1 deletion brainpy/dyn/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import brainpy.math as bm
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check_mode

__all__ = [
'BatchNorm',
Expand Down
4 changes: 2 additions & 2 deletions brainpy/dyn/layers/nvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check_mode
from brainpy.tools.checking import (check_integer, check_sequence)

__all__ = [
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
name: str = None,
):
super(NVAR, self).__init__(mode=mode, name=name)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
order = tuple() if order is None else order
Expand Down
2 changes: 1 addition & 1 deletion brainpy/dyn/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.lax
import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.modes import Mode, training

__all__ = [
'Pool',
Expand Down
10 changes: 5 additions & 5 deletions brainpy/dyn/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check
from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check_mode
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Array

Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
Expand Down Expand Up @@ -427,7 +427,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__)

# params
self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
Expand Down Expand Up @@ -685,7 +685,7 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (NormalMode, BatchingMode), self.__class__)
check_mode(self.mode, (NormalMode, BatchingMode), self.__class__)

# conductance parameters
self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
Expand Down Expand Up @@ -994,7 +994,7 @@ def __init__(
):
# initialization
super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
Expand Down
2 changes: 2 additions & 0 deletions brainpy/dyn/neurons/input_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.types import Shape, Array


__all__ = [
'InputGroup',
'OutputGroup',
Expand Down Expand Up @@ -205,3 +206,4 @@ def reset(self, batch_size=None):

def reset_state(self, batch_size=None):
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

Loading

0 comments on commit 62a1220

Please sign in to comment.