Skip to content

Commit

Permalink
Merge pull request #7 from mrocklin/nogil
Browse files Browse the repository at this point in the history
Add nogil declaration to _run_paradiso function
  • Loading branch information
jcapriot authored Nov 14, 2023
2 parents 8a92a13 + 35d4d5d commit 94ca681
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#cython: linetrace=True
cimport numpy as np
from cython cimport numeric
from cpython.pythread cimport (
PyThread_type_lock,
PyThread_allocate_lock,
PyThread_acquire_lock,
PyThread_release_lock,
PyThread_free_lock
)

import warnings
import numpy as np
Expand Down Expand Up @@ -42,12 +49,12 @@ cdef extern from 'mkl.h':
void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
const int *, const int *, const void *, const int *,
const int *, int *, const int *, int *,
const int *, void *, void *, int *)
const int *, void *, void *, int *) nogil

void pardiso_64(_MKL_DSS_HANDLE_t, const long_t *, const long_t *, const long_t *,
const long_t *, const long_t *, const void *, const long_t *,
const long_t *, long_t *, const long_t *, long_t *,
const long_t *, void *, void *, long_t *)
const long_t *, void *, void *, long_t *) nogil


#call pardiso (pt, maxfct, mnum, mtype, phase, n, a, ia, ja, perm, nrhs, iparm, msglvl, b, x, error)
Expand Down Expand Up @@ -184,7 +191,7 @@ cdef class MKLPardisoSolver:
cdef int_t _factored
cdef size_t shape[2]
cdef int_t _initialized

cdef PyThread_type_lock lock
cdef void * a

cdef object _data_type
Expand Down Expand Up @@ -253,6 +260,9 @@ cdef class MKLPardisoSolver:
raise ValueError("Matrix is not square")
self.shape = n_row, n_col

# allocate the lock
self.lock = PyThread_allocate_lock()

self._data_type = A.dtype
if matrix_type is None:
if np.issubdtype(self._data_type, np.complexfloating):
Expand Down Expand Up @@ -496,6 +506,7 @@ cdef class MKLPardisoSolver:
cdef long_t phase64=-1, nrhs64=0, error64=0

if self._initialized:
PyThread_acquire_lock(self.lock, 1)
if self._is_32:
pardiso(
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
Expand All @@ -508,9 +519,12 @@ cdef class MKLPardisoSolver:
&phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
)
PyThread_release_lock(self.lock)
err = error or error64
if err!=0:
raise PardisoError("Memmory release error "+_err_messages[err])
#dealloc lock
PyThread_free_lock(self.lock)

cdef _analyze(self):
#phase = 11
Expand All @@ -536,17 +550,20 @@ cdef class MKLPardisoSolver:
if err!=0:
raise PardisoError("Solve step error, "+_err_messages[err])

cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0):
cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0) nogil:
cdef int_t error=0
cdef long_t error64=0, phase64=phase, nrhs64=nrhs

PyThread_acquire_lock(self.lock, 1)
if self._is_32:
pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
&phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
&self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
PyThread_release_lock(self.lock)
return error
else:
pardiso_64(self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
PyThread_release_lock(self.lock)
return error64
23 changes: 23 additions & 0 deletions tests/test_pydiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
set_mkl_threads,
set_mkl_pardiso_threads,
)
from concurrent.futures import ThreadPoolExecutor
import pytest
import sys

Expand Down Expand Up @@ -147,3 +148,25 @@ def test_rhs_size_error():
solver.solve(b_bad)
with pytest.raises(ValueError):
solver.solve(b, x_bad)

def test_threading():
"""
Here we test that calling the solver is safe from multiple threads.
There isn't actually any speedup because it acquires a lock on each call
to pardiso internally (because those calls are not thread safe).
"""
n = 200
n_rhs = 75
A = sp.diags([-1, 2, -1], (-1, 0, 1), shape=(n, n), format='csr')
Ainv = Solver(A)

x_true = np.random.rand(n, n_rhs)
rhs = A @ x_true

with ThreadPoolExecutor() as pool:
x_sol = np.stack(
list(pool.map(lambda i: Ainv.solve(rhs[:, i]), range(n_rhs))),
axis=1
)

np.testing.assert_allclose(x_true, x_sol)

6 comments on commit 94ca681

@devjit-kobold
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jcapriot @mrocklin

Have you tested that you can still build pydiso?

I can't compile the Cython file anymore with this commit:


#48 463.1       pydiso/mkl_solver.pyx:313:29: Cannot assign type 'int (int_t *, int_t *, char *, int_t) nogil' to 'ProgressEntry'
#48 463.1       Traceback (most recent call last):
#48 463.1         File "<string>", line 2, in <module>
#48 463.1         File "<pip-setuptools-caller>", line 34, in <module>
#48 463.1         File "/tmp/pip-install-_blub8vr/pydiso_bb78c6ab35cf4fa593ea46a34bfd632e/setup.py", line 70, in <module>
#48 463.1           setup(**metadata)
#48 463.1         File "/python/lib/python3.10/site-packages/numpy/distutils/core.py", line 135, in setup
#48 463.1           config = configuration()
#48 463.1         File "/tmp/pip-install-_blub8vr/pydiso_bb78c6ab35cf4fa593ea46a34bfd632e/setup.py", line 17, in configuration
#48 463.1           config.add_subpackage("pydiso")
#48 463.1         File "/python/lib/python3.10/site-packages/numpy/distutils/misc_util.py", line 1050, in add_subpackage
#48 463.1           config_list = self.get_subpackage(subpackage_name, subpackage_path,
#48 463.1         File "/python/lib/python3.10/site-packages/numpy/distutils/misc_util.py", line 1016, in get_subpackage
#48 463.1           config = self._get_configuration_from_setup_py(
#48 463.1         File "/python/lib/python3.10/site-packages/numpy/distutils/misc_util.py", line 958, in _get_configuration_from_setup_py
#48 463.1           config = setup_module.configuration(*args)
#48 463.1         File "/tmp/pip-install-_blub8vr/pydiso_bb78c6ab35cf4fa593ea46a34bfd632e/pydiso/setup.py", line 20, in configuration
#48 463.1           cythonize(join(base_path, "mkl_solver.pyx"))
#48 463.1         File "/python/lib/python3.10/site-packages/Cython/Build/Dependencies.py", line 1115, in cythonize
#48 463.1           cythonize_one(*args)
#48 463.1         File "/python/lib/python3.10/site-packages/Cython/Build/Dependencies.py", line 1238, in cythonize_one
#48 463.1           raise CompileError(None, pyx_file)
#48 463.1       Cython.Compiler.Errors.CompileError: /tmp/pip-install-_blub8vr/pydiso_bb78c6ab35cf4fa593ea46a34bfd632e/pydiso/mkl_solver.pyx
#48 463.1       [end of output]

@jcapriot
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It still builds in the CI environment, what version of Cython do you have?

@devjit-kobold
Copy link

@devjit-kobold devjit-kobold commented on 94ca681 Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.29.34 - which exceeds the 0.29.31 requirement in setup.py?

trying with ~3.0 now...

@devjit-kobold
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep - that works - perhaps we should update setup.py requirements and are you planning to do a release again at some point?

@jcapriot
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update the minimum requirements for this. Wasn't quite aware of anything else different here that would've caused issues with the cython compiler though. Do you know if it works with 0.29.36? I'll push a small release here as a patch release for this version soon.

@devjit-kobold
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if 0.29.36 works - jumped straight to ~3 - doesn't seem like an onerous requirement generally though - perhaps worth just bumping up?

Please sign in to comment.