Skip to content

Commit

Permalink
Fixes to support NumPy 2.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed Aug 20, 2024
1 parent acdbb3c commit a88ad9a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/shmem4py/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,8 +997,9 @@ def full(
Valid hints are defined as enumerations in `MALLOC` and can be
combined using the bitwise OR operator. Keyword argument only.
"""
fill_value = np.array(fill_value)
if dtype is None:
dtype = np.array(fill_value).dtype
dtype = fill_value.dtype
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
np.copyto(a, fill_value, casting='unsafe')
lib.shmem_barrier_all()
Expand Down
4 changes: 2 additions & 2 deletions test/test_rma.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def testGet(self):
nxpe = (mype + 1) % npes
for t in types:
src = shmem.full(1, mype, dtype=t)
dst = np.full(1, -1, dtype=t)
dst = np.full(1, np.array(-1), dtype=t)
shmem.barrier_all()
shmem.get(dst, src, nxpe)
self.assertEqual(dst[0], nxpe)
Expand Down Expand Up @@ -148,7 +148,7 @@ def testGetNBI(self):
nxpe = (mype + 1) % npes
for t in types:
src = shmem.full(1, mype, dtype=t)
dst = np.full(1, -1, dtype=t)
dst = np.full(1, np.array(-1), dtype=t)
shmem.barrier_all()
shmem.get_nbi(dst, src, nxpe)
shmem.fence()
Expand Down

0 comments on commit a88ad9a

Please sign in to comment.