Skip to content

Commit

Permalink
fix #149 (dozens of random samplings in NumPy) and fix JaxArray op er…
Browse files Browse the repository at this point in the history
…rors (#216)

fix #149 (dozens of random samplings in NumPy) and fix JaxArray op errors
  • Loading branch information
chaoming0625 authored May 17, 2022
2 parents e2f5170 + 184d720 commit 23ae81c
Show file tree
Hide file tree
Showing 7 changed files with 1,218 additions and 338 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Windows_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy==1.21.0
python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install "jax[cpu]==0.3.2" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install -r requirements-win.txt
python -m pip install tqdm brainpylib
python setup.py install
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
publishment.md
#experimental/
.vscode

io_test_tmp*

brainpy/base/tests/io_test_tmp*

Expand Down
34 changes: 17 additions & 17 deletions brainpy/math/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __sub__(self, oc):
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))

def __rsub__(self, oc):
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) - self._value)

def __isub__(self, oc):
# a -= b
Expand All @@ -249,7 +249,7 @@ def __mul__(self, oc):
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))

def __rmul__(self, oc):
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) * self._value)

def __imul__(self, oc):
# a *= b
Expand All @@ -258,17 +258,17 @@ def __imul__(self, oc):
self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc)
return self

def __div__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
# def __div__(self, oc):
# return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))

def __rdiv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

def __truediv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))

def __rtruediv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

def __itruediv__(self, oc):
# a /= b
Expand All @@ -281,7 +281,7 @@ def __floordiv__(self, oc):
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))

def __rfloordiv__(self, oc):
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) // self._value)

def __ifloordiv__(self, oc):
# a //= b
Expand All @@ -291,16 +291,16 @@ def __ifloordiv__(self, oc):
return self

def __divmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc))

def __rdivmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc))

def __mod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))

def __rmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) % self._value)

def __imod__(self, oc):
# a %= b
Expand All @@ -313,7 +313,7 @@ def __pow__(self, oc):
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))

def __rpow__(self, oc):
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ** self._value)

def __ipow__(self, oc):
# a **= b
Expand All @@ -326,7 +326,7 @@ def __matmul__(self, oc):
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))

def __rmatmul__(self, oc):
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) @ self._value)

def __imatmul__(self, oc):
# a @= b
Expand All @@ -339,7 +339,7 @@ def __and__(self, oc):
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))

def __rand__(self, oc):
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) & self._value)

def __iand__(self, oc):
# a &= b
Expand All @@ -352,7 +352,7 @@ def __or__(self, oc):
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))

def __ror__(self, oc):
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) | self._value)

def __ior__(self, oc):
# a |= b
Expand All @@ -365,7 +365,7 @@ def __xor__(self, oc):
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))

def __rxor__(self, oc):
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ^ self._value)

def __ixor__(self, oc):
# a ^= b
Expand All @@ -378,7 +378,7 @@ def __lshift__(self, oc):
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))

def __rlshift__(self, oc):
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) << self._value)

def __ilshift__(self, oc):
# a <<= b
Expand All @@ -391,7 +391,7 @@ def __rshift__(self, oc):
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))

def __rrshift__(self, oc):
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) >> self._value)

def __irshift__(self, oc):
# a >>= b
Expand Down
Loading

0 comments on commit 23ae81c

Please sign in to comment.