diff --git a/pydiso/mkl_solver.pyx b/pydiso/mkl_solver.pyx
index 120abec..e1d8c7f 100644
--- a/pydiso/mkl_solver.pyx
+++ b/pydiso/mkl_solver.pyx
@@ -2,6 +2,13 @@
 #cython: linetrace=True
 cimport numpy as np
 from cython cimport numeric
+from cpython.pythread cimport (
+    PyThread_type_lock,
+    PyThread_allocate_lock,
+    PyThread_acquire_lock,
+    PyThread_release_lock,
+    PyThread_free_lock
+)
 import warnings
 
 import numpy as np
@@ -42,12 +49,12 @@ cdef extern from 'mkl.h':
     void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
                  const int *, const int *, const void *, const int *,
                  const int *, int *, const int *, int *,
-                 const int *, void *, void *, int *)
+                 const int *, void *, void *, int *) nogil
 
     void pardiso_64(_MKL_DSS_HANDLE_t, const long_t *, const long_t *,
                     const long_t *, const long_t *, const long_t *,
                     const void *, const long_t *, const long_t *, long_t *,
                     const long_t *, long_t *,
-                    const long_t *, void *, void *, long_t *)
+                    const long_t *, void *, void *, long_t *) nogil
 
 #call pardiso (pt, maxfct, mnum, mtype, phase, n, a, ia, ja, perm, nrhs, iparm, msglvl, b, x, error)
@@ -184,7 +191,7 @@ cdef class MKLPardisoSolver:
     cdef int_t _factored
     cdef size_t shape[2]
     cdef int_t _initialized
-
+    cdef PyThread_type_lock lock
     cdef void * a
     cdef object _data_type
 
@@ -253,6 +260,9 @@ cdef class MKLPardisoSolver:
             raise ValueError("Matrix is not square")
         self.shape = n_row, n_col
 
+        # allocate the lock
+        self.lock = PyThread_allocate_lock()
+
         self._data_type = A.dtype
         if matrix_type is None:
             if np.issubdtype(self._data_type, np.complexfloating):
@@ -496,6 +506,7 @@ cdef class MKLPardisoSolver:
         cdef long_t phase64=-1, nrhs64=0, error64=0
 
         if self._initialized:
+            PyThread_acquire_lock(self.lock, 1)
             if self._is_32:
                 pardiso(
                     self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
@@ -508,9 +519,12 @@ cdef class MKLPardisoSolver:
                 &phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
                 self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
             )
+            PyThread_release_lock(self.lock)
         err = error or error64
         if err!=0:
             raise PardisoError("Memmory release error "+_err_messages[err])
+        #dealloc lock
+        PyThread_free_lock(self.lock)
 
     cdef _analyze(self):
         #phase = 11
@@ -536,17 +550,20 @@ cdef class MKLPardisoSolver:
         if err!=0:
             raise PardisoError("Solve step error, "+_err_messages[err])
 
-    cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0):
+    cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0) nogil:
         cdef int_t error=0
         cdef long_t error64=0, phase64=phase, nrhs64=nrhs
 
+        PyThread_acquire_lock(self.lock, 1)
         if self._is_32:
             pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
                     &phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
                     &self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
+            PyThread_release_lock(self.lock)
             return error
         else:
             pardiso_64(self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
                        &phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
                        &self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
+            PyThread_release_lock(self.lock)
             return error64
diff --git a/tests/test_pydiso.py b/tests/test_pydiso.py
index 4aa6d0c..d5c715a 100644
--- a/tests/test_pydiso.py
+++ b/tests/test_pydiso.py
@@ -8,6 +8,7 @@
     set_mkl_threads,
     set_mkl_pardiso_threads,
 )
+from concurrent.futures import ThreadPoolExecutor
 import pytest
 import sys
 
@@ -147,3 +148,25 @@ def test_rhs_size_error():
         solver.solve(b_bad)
     with pytest.raises(ValueError):
         solver.solve(b, x_bad)
+
+def test_threading():
+    """
+    Here we test that calling the solver is safe from multiple threads.
+    There isn't actually any speedup because it acquires a lock on each call
+    to pardiso internally (because those calls are not thread safe).
+    """
+    n = 200
+    n_rhs = 75
+    A = sp.diags([-1, 2, -1], (-1, 0, 1), shape=(n, n), format='csr')
+    Ainv = Solver(A)
+
+    x_true = np.random.rand(n, n_rhs)
+    rhs = A @ x_true
+
+    with ThreadPoolExecutor() as pool:
+        x_sol = np.stack(
+            list(pool.map(lambda i: Ainv.solve(rhs[:, i]), range(n_rhs))),
+            axis=1
+        )
+
+    np.testing.assert_allclose(x_true, x_sol)