Skip to content

Commit

Permalink
Merge pull request #326 from chaoming0625/master
Browse files Browse the repository at this point in the history
[compatibility] more operators in pytorch and tensorflow
  • Loading branch information
chaoming0625 authored Jan 29, 2023
2 parents f46e78b + 6927195 commit 7f96a8b
Show file tree
Hide file tree
Showing 21 changed files with 583 additions and 262 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
from . import activations

# high-level numpy operations
from .arraycreation import *
from .arrayinterporate import *
from .arraycompatible import *
from .compat_numpy import *
from .compat_tensorflow import *
from .others import *
from . import random, linalg, fft

Expand Down
21 changes: 20 additions & 1 deletion brainpy/_src/math/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,29 @@ def _compatible_with_brainpy_array(fun: Callable):
@functools.wraps(fun)
def new_fun(*args, **kwargs):
args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf)
out = None
if len(kwargs):
# compatible with PyTorch syntax
if 'dim' in kwargs:
kwargs['axis'] = kwargs.pop('dim')
# compatible with PyTorch syntax
if 'keepdim' in kwargs:
kwargs['keep_dims'] = kwargs.pop('keepdim')
# compatible with TensorFlow syntax
if 'keepdims' in kwargs:
kwargs['keep_dims'] = kwargs.pop('keepdims')
# compatible with NumPy/PyTorch syntax
if 'out' in kwargs:
out = kwargs.get('out')
if not isinstance(out, Array):
raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}')
# format
kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
r = fun(*args, **kwargs)
return tree_map(_return, r)
if out is None:
return tree_map(_return, r)
else:
out.value = r

new_fun.__doc__ = getattr(fun, "__doc__", None)

Expand Down
101 changes: 0 additions & 101 deletions brainpy/_src/math/arraycreation.py

This file was deleted.

108 changes: 0 additions & 108 deletions brainpy/_src/math/arrayoperation.py

This file was deleted.

Loading

0 comments on commit 7f96a8b

Please sign in to comment.