diff --git a/src/shmem4py/shmem.py b/src/shmem4py/shmem.py index d56e37a..914f93b 100644 --- a/src/shmem4py/shmem.py +++ b/src/shmem4py/shmem.py @@ -997,8 +997,9 @@ def full( Valid hints are defined as enumerations in `MALLOC` and can be combined using the bitwise OR operator. Keyword argument only. """ + fill_value = np.array(fill_value) if dtype is None: - dtype = np.array(fill_value).dtype + dtype = fill_value.dtype a = new_array(shape, dtype, order, align=align, hints=hints, clear=False) np.copyto(a, fill_value, casting='unsafe') lib.shmem_barrier_all() diff --git a/test/test_rma.py b/test/test_rma.py index 7d1fa01..ee31cd4 100644 --- a/test/test_rma.py +++ b/test/test_rma.py @@ -63,7 +63,7 @@ def testGet(self): nxpe = (mype + 1) % npes for t in types: src = shmem.full(1, mype, dtype=t) - dst = np.full(1, -1, dtype=t) + dst = np.full(1, np.array(-1), dtype=t) shmem.barrier_all() shmem.get(dst, src, nxpe) self.assertEqual(dst[0], nxpe) @@ -148,7 +148,7 @@ def testGetNBI(self): nxpe = (mype + 1) % npes for t in types: src = shmem.full(1, mype, dtype=t) - dst = np.full(1, -1, dtype=t) + dst = np.full(1, np.array(-1), dtype=t) shmem.barrier_all() shmem.get_nbi(dst, src, nxpe) shmem.fence()