Skip to content

Commit

Permalink
update installation setup
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Oct 5, 2022
1 parent 4417069 commit 9e5fe03
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 42 deletions.
24 changes: 20 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,27 @@
raise ModuleNotFoundError(
'''
Please install jaxlib. See
https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
BrainPy needs jaxlib, please install jaxlib.
for installation instructions.
1. If you are using Windows system, install jaxlib through
>>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
2. If you are using macOS platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
3. If you are using Linux platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
4. If you are using Linux + CUDA platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14", "jaxlib=0.3.14".
More detail installation instruction, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
''') from None

Expand Down
50 changes: 32 additions & 18 deletions docs/quickstart/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ Linux & MacOS
^^^^^^^^^^^^^

Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
later) platforms. The provided binary releases of JAX for Linux and macOS
later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS
systems are available at

- for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html
- for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


To install a CPU-only version of JAX, you can run
If you want to install a CPU-only version of `jax` and `jaxlib`, you can run

.. code-block:: bash
pip install --upgrade "jax[cpu]"
pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
If you want to install JAX with both CPU and NVidia GPU support, you must first install
`CUDA`_ and `CuDNN`_, if they have not already been installed. Next, run
Expand All @@ -109,42 +109,57 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Alternatively, you can download the preferred release ".whl" file for jaxlib, and install it via ``pip``:
Alternatively, you can download the preferred release ".whl" file for jaxlib
from the above release links, and install it via ``pip``:

.. code-block:: bash
pip install xxx-0.3.14-xxx.whl
pip install jax==0.3.14
Note that the versions of `jaxlib` and `jax` should be consistent.
.. note::

Note that the versions of `jaxlib` and `jax` should be consistent.

For example, if you are using `jax==0.3.14`, you would better install `jax==0.3.14`.



Windows
^^^^^^^

For **Windows** users, JAX can be installed by the following methods:
For **Windows** users, `jax` and `jaxlib` can be installed from the community supports.
Specifically, you can install `jax` and `jaxlib` through:

.. code-block:: bash
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html
If you are using GPU, you can install GPU-versioned wheels through:

.. code-block:: bash
- **Method 1**: There are several communities support JAX for Windows, please refer
to the github link for more details: https://github.com/cloudhan/jax-windows-builder .
Simply speaking, the provided binary releases of JAX for Windows
are available at https://whls.blob.core.windows.net/unstable/index.html .
pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html
You can download the preferred release ".whl" file, and install it via ``pip``:
Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by
downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html .
Then install it via ``pip``:

.. code-block:: bash
pip install xxx-0.3.14-xxx.whl
pip install jax==0.3.14
- **Method 2**: For Windows 10+ system, you can use `Windows Subsystem for Linux (WSL)`_.
The installation guide can be found in `WSL Installation Guide for Windows 10`_.
Then, you can install JAX in WSL just like the installation step in Linux/MacOs.


- **Method 3**: You can also `build JAX from source`_.
WSL
^^^

Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_.
The installation guide can be found in
`WSL Installation Guide for Windows 10/11 <https://docs.microsoft.com/en-us/windows/wsl/install-win10>`_.
Then, you can install JAX in WSL just like the installation step in Linux/MacOs.


Dependency 3: brainpylib
Expand Down Expand Up @@ -194,7 +209,6 @@ packages:
.. _Matplotlib: https://matplotlib.org/
.. _JAX: https://github.com/google/jax
.. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about
.. _WSL Installation Guide for Windows 10: https://docs.microsoft.com/en-us/windows/wsl/install-win10
.. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html
.. _SymPy: https://github.com/sympy/sympy
.. _Numba: https://numba.pydata.org/
Expand Down
21 changes: 1 addition & 20 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,6 @@
with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f:
README = f.read()

# require users to install jaxlib before installing brainpy on Windows platform
requirements = ['numpy>=1.15', 'jax>=0.3.0', 'tqdm']
if sys.platform.startswith('win32') or sys.platform.startswith('cygwin'):
try:
import jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError('''
----------------------------------------------------------------------
We detect that your are using Windows platform.
Please manually install "jaxlib" before installing "brainpy".
See https://whls.blob.core.windows.net/unstable/index.html
for jaxlib's Windows wheels.
----------------------------------------------------------------------
''') from None
else:
requirements.append('jaxlib>=0.3.0')

# installation packages
packages = find_packages()
if 'docs' in packages:
Expand All @@ -74,7 +55,7 @@
author_email='chao.brain@qq.com',
packages=packages,
python_requires='>=3.7',
install_requires=requirements,
install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm'],
url='https://github.com/PKU-NIP-Lab/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/PKU-NIP-Lab/BrainPy/issues",
Expand Down

0 comments on commit 9e5fe03

Please sign in to comment.