Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Mar 6, 2024
1 parent ded369a commit ced44b2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 31 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"qseek.ext.array_tools",
sources=["src/qseek/ext/array_tools.c"],
include_dirs=[numpy.get_include()],
extra_compile_args=["-fopenmp"],
extra_link_args=["-lgomp"],
)
]
)
71 changes: 42 additions & 29 deletions src/qseek/ext/array_tools.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_SSIZE_T_CLEAN /* Make "s#" use Py_ssize_t rather than int. */
#include "numpy/arrayobject.h"
#include <Python.h>
Expand All @@ -23,25 +24,28 @@ static PyObject *fill_zero_bytes(PyObject *module, PyObject *args,
}

static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) {
PyObject *array, *cache, *mask;
PyObject *obj, *cache, *mask;
PyArrayObject *array, *mask_array, *cached_row;
npy_intp *array_shape;
npy_intp n_nodes, n_samples;
int n_threads = 1;
uint sum_mask = 0;

npy_int *cumsum_mask, mask_value;
npy_int idx_sum = 0;
npy_bool *mask_data;
PyArrayObject *cached_array;

static char *kwlist[] = {"array", "cache", "mask", NULL};
static char *kwlist[] = {"array", "cache", "mask", "nthreads", NULL};

if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO", kwlist, &array, &cache,
&mask))
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|i", kwlist, &obj, &cache,
&mask, &n_threads))
return NULL;

if (!PyArray_Check(array)) {
if (!PyArray_Check(obj)) {
PyErr_SetString(PyExc_ValueError, "array is not a NumPy array");
return NULL;
}
array = (PyArrayObject *)obj;
if (PyArray_NDIM(array) != 2) {
PyErr_SetString(PyExc_ValueError, "array is not a 2D NumPy array");
return NULL;
Expand All @@ -63,67 +67,76 @@ static PyObject *apply_cache(PyObject *module, PyObject *args, PyObject *kwds) {
PyErr_SetString(PyExc_ValueError, "mask is not a NumPy array");
return NULL;
}
if (PyArray_NDIM(mask) != 1) {
mask_array = (PyArrayObject *)mask;
if (PyArray_NDIM(mask_array) != 1) {
PyErr_SetString(PyExc_ValueError, "mask is not a 2D NumPy array");
return NULL;
}
if (PyArray_SIZE(mask) != n_nodes) {
if (PyArray_SIZE(mask_array) != n_nodes) {
PyErr_SetString(PyExc_ValueError, "mask size does not match array");
return NULL;
}

cumsum_mask = (npy_int *)malloc(n_nodes * sizeof(npy_int));
mask_data = PyArray_DATA(mask_array);
for (int i_node = 0; i_node < n_nodes; i_node++) {
mask_value = mask_data[i_node];
if (!mask_value) {
cumsum_mask[i_node] = -1;
} else {
cumsum_mask[i_node] = idx_sum;
idx_sum += 1;
sum_mask += 1;
}
}
if (!PyList_Check(cache)) {
PyErr_SetString(PyExc_ValueError, "cache is not a list");
return NULL;
}
if (PyList_Size(cache) != sum_mask) {
PyErr_SetString(PyExc_ValueError, "cache elements does not match mask");
return NULL;
}

for (int i_node = 0; i_node < PyList_Size(cache); i_node++) {
PyObject *item = PyList_GetItem(cache, i_node);
if (!PyArray_Check(item)) {
PyErr_SetString(PyExc_ValueError, "cache item is not a NumPy array");
return NULL;
}
if (PyArray_TYPE(item) != NPY_FLOAT) {
cached_row = (PyArrayObject *)item;
if (PyArray_TYPE(cached_row) != NPY_FLOAT) {
PyErr_SetString(PyExc_ValueError, "cache item is not of type np.float32");
return NULL;
}
if (PyArray_NDIM(item) != 1) {
if (PyArray_NDIM(cached_row) != 1) {
PyErr_SetString(PyExc_ValueError, "cache item is not a 1D NumPy array");
return NULL;
}
if (!PyArray_IS_C_CONTIGUOUS(item)) {
if (!PyArray_IS_C_CONTIGUOUS(cached_row)) {
PyErr_SetString(PyExc_ValueError, "cache item is not C contiguous");
return NULL;
}
if (PyArray_SIZE(item) != n_samples) {
PyErr_SetString(PyExc_ValueError, "cache item size does not match array");
if (PyArray_SIZE(cached_row) != n_samples) {
PyErr_SetString(PyExc_ValueError,
"cache item size does not match array nsamples");
return NULL;
}
}

// cumsum mask

cumsum_mask = (npy_int *)malloc(n_nodes * sizeof(npy_int));
mask_data = PyArray_DATA((PyArrayObject *)mask);
for (int i_node = 0; i_node < n_nodes; i_node++) {
mask_value = mask_data[i_node];
if (!mask_value) {
cumsum_mask[i_node] = -1;
} else {
cumsum_mask[i_node] = idx_sum;
idx_sum += 1;
}
}
Py_BEGIN_ALLOW_THREADS;
#pragma omp parallel for num_threads(n_threads) \
schedule(dynamic) private(cached_row)
for (int i_node = 0; i_node < n_nodes; i_node++) {
if (cumsum_mask[i_node] == -1) {
continue;
}
cached_array = PyList_GET_ITEM(cache, (Py_ssize_t)cumsum_mask[i_node]);
cached_row = (PyArrayObject *)PyList_GET_ITEM(
cache, (Py_ssize_t)cumsum_mask[i_node]);
memcpy(
PyArray_GETPTR2((PyArrayObject *)array, (npy_intp)i_node, (npy_intp)0),
PyArray_DATA((PyArrayObject *)cached_array),
n_samples * sizeof(npy_float32));
PyArray_DATA((PyArrayObject *)cached_row),
(size_t)n_samples * sizeof(npy_float32));
}
Py_END_ALLOW_THREADS;

Expand Down
8 changes: 7 additions & 1 deletion src/qseek/ext/array_tools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ import numpy as np
def fill_zero_bytes(array: np.ndarray) -> None:
"""Fill the zero bytes of the array with zeros."""

def apply_cache(data: np.ndarray, cache: list[np.ndarray], mask: np.ndarray) -> None:
def apply_cache(
data: np.ndarray,
cache: list[np.ndarray],
mask: np.ndarray,
nthreads: int = 1,
) -> None:
"""Apply the cache to the data array.
Args:
data: The data array, ndim=2 with NxM shape.
cache: List of arrays with ndim=1 and M shape.
mask: The mask array, ndim=1 with N shape of np.bool type.
nthreads: The number of threads to use.
"""
8 changes: 7 additions & 1 deletion src/qseek/models/semblance.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,13 @@ async def apply_cache(self, cache: dict[bytes, np.ndarray]) -> None:
# for idx, copy in enumerate(mask):
# if copy:
# memoryview(self.semblance_unpadded[idx])[:] = memoryview(data.pop(0))
await asyncio.to_thread(apply_cache, self.semblance_unpadded, data, mask)
await asyncio.to_thread(
apply_cache,
self.semblance_unpadded,
data,
mask,
nthreads=4,
)

def maximum_node_semblance(self) -> np.ndarray:
semblance = self.semblance.max(axis=1)
Expand Down

0 comments on commit ced44b2

Please sign in to comment.