Skip to content

Commit

Permalink
be a bit more accurate about the argument types
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Feb 29, 2024
1 parent 6d11442 commit 4ed000d
Showing 1 changed file with 52 additions and 57 deletions.
109 changes: 52 additions & 57 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ import numpy as np
import scipy.sparse as sp
import os

ctypedef long long MKL_INT64
ctypedef unsigned long long MKL_UINT64
ctypedef int MKL_INT

ctypedef MKL_INT int_t
ctypedef MKL_INT64 long_t

cdef extern from 'mkl.h':

ctypedef long long MKL_INT64
ctypedef unsigned long long MKL_UINT64
ctypedef int MKL_INT
int MKL_DOMAIN_PARDISO

ctypedef struct MKLVersion:
Expand All @@ -42,28 +39,27 @@ cdef extern from 'mkl.h':
int mkl_get_max_threads()
int mkl_domain_get_max_threads(int domain)

ctypedef int (*ProgressEntry)(int_t* thread, int_t* step, char* stage, int_t stage_len) except? -1;
ctypedef int (*ProgressEntry)(int* thread, int* step, char* stage, int stage_len) except? -1;
ProgressEntry mkl_set_progress(ProgressEntry progress);

ctypedef void * _MKL_DSS_HANDLE_t

void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
const int *, const int *, const void *, const int *,
const int *, int *, const int *, int *,
const int *, void *, void *, int *) nogil

void pardiso_64(_MKL_DSS_HANDLE_t, const long_t *, const long_t *, const long_t *,
const long_t *, const long_t *, const void *, const long_t *,
const long_t *, long_t *, const long_t *, long_t *,
const long_t *, void *, void *, long_t *) nogil
void pardiso(_MKL_DSS_HANDLE_t, const MKL_INT*, const MKL_INT*, const MKL_INT*,
const MKL_INT *, const MKL_INT *, const void *, const MKL_INT *,
const MKL_INT *, MKL_INT *, const MKL_INT *, MKL_INT *,
const MKL_INT *, void *, void *, MKL_INT *) nogil

void pardiso_64(_MKL_DSS_HANDLE_t, const long long int *, const long long int *, const long long int *,
const long long int *, const long long int *, const void *, const long long int *,
const long long int *, long long int *, const long long int *, long long int *,
const long long int *, void *, void *, long long int *) nogil

#call pardiso (pt, maxfct, mnum, mtype, phase, n, a, ia, ja, perm, nrhs, iparm, msglvl, b, x, error)
cdef int mkl_progress(int_t *thread, int_t* step, char* stage, int_t stage_len):
cdef int mkl_progress(int *thread, int* step, char* stage, int stage_len):
print(thread[0], step[0], stage, stage_len)
return 0

cdef int mkl_no_progress(int_t *thread, int_t* step, char* stage, int_t stage_len) nogil:
cdef int mkl_no_progress(int *thread, int* step, char* stage, int stage_len) nogil:
return 0

MATRIX_TYPES ={
Expand Down Expand Up @@ -170,14 +166,14 @@ def get_mkl_version():
return vers

cdef class _PardisoParams:
cdef int_t iparm[64]
cdef int_t n, mtype, maxfct, mnum, msglvl
cdef int_t[:] ia, ja, perm
cdef MKL_INT iparm[64]
cdef MKL_INT n, mtype, maxfct, mnum, msglvl
cdef MKL_INT[:] ia, ja, perm

cdef class _PardisoParams64:
cdef long_t iparm[64]
cdef long_t n, mtype, maxfct, mnum, msglvl
cdef long_t[:] ia, ja, perm
cdef MKL_INT64 iparm[64]
cdef MKL_INT64 n, mtype, maxfct, mnum, msglvl
cdef MKL_INT64[:] ia, ja, perm

ctypedef fused _par_params:
_PardisoParams
Expand All @@ -187,11 +183,11 @@ cdef class MKLPardisoSolver:
cdef _MKL_DSS_HANDLE_t handle[64]
cdef _PardisoParams _par
cdef _PardisoParams64 _par64
cdef int_t _is_32
cdef int_t mat_type
cdef int_t _factored
cdef int _call32
cdef int mat_type
cdef int _factored
cdef size_t shape[2]
cdef int_t _initialized
cdef int _initialized
cdef PyThread_type_lock lock
cdef void * a

Expand Down Expand Up @@ -293,8 +289,10 @@ cdef class MKLPardisoSolver:

#set integer length
integer_len = A.indices.itemsize
self._is_32 = integer_len == sizeof(int_t)
if self._is_32:
# we only need to call the 64 bit version if
# sizeof(MKL_INT) == 4 and A.indices.itemsize == 8
self._call32 = not (sizeof(MKL_INT) == 4 and integer_len == 8)
if self._call32:
self._par = _PardisoParams()
self._initialize(self._par, A, matrix_type, verbose)
elif integer_len == 8:
Expand All @@ -307,7 +305,7 @@ cdef class MKLPardisoSolver:
# allocate the lock
self.lock = PyThread_allocate_lock()

if(verbose):
if verbose:
#for reporting factorization progress via python's `print`
mkl_set_progress(mkl_progress)
else:
Expand Down Expand Up @@ -399,7 +397,7 @@ cdef class MKLPardisoSolver:
if bp == xp:
raise PardisoError("b and x must be different arrays")

cdef int_t nrhs = b.shape[1] if b.ndim == 2 else 1
cdef MKL_INT nrhs = b.shape[1] if b.ndim == 2 else 1

if transpose:
self.set_iparm(11, 2)
Expand All @@ -412,7 +410,7 @@ cdef class MKLPardisoSolver:
def perm(self):
""" Fill-reducing permutation vector used inside pardiso.
"""
if self._is_32:
if self._call32:
return np.array(self._par.perm)
else:
return np.array(self._par64.perm)
Expand All @@ -421,20 +419,20 @@ cdef class MKLPardisoSolver:
def iparm(self):
""" Parameter options for the pardiso solver.
"""
if self._is_32:
if self._call32:
return np.array(self._par.iparm)
else:
return np.array(self._par64.iparm)

def set_iparm(self, int_t i, int_t val):
def set_iparm(self, MKL_INT i, MKL_INT val):
if i > 63 or i < 0:
raise IndexError(f"index {i} is out of bounds for size 64 array")
if i not in [
1, 3, 4, 5, 7, 9, 10, 11, 12, 17, 18, 20, 23,
24, 26, 30, 33, 34, 35, 36, 38, 42, 55, 59
]:
raise PardisoError(f"cannot set parameter {i} of the iparm array")
if self._is_32:
if self._call32:
self._par.iparm[i] = val
else:
self._par64.iparm[i] = val
Expand All @@ -444,8 +442,12 @@ cdef class MKLPardisoSolver:
return self.iparm[17]

cdef _initialize(self, _par_params par, A, matrix_type, verbose):
if sizeof(MKL_INT) == 4:
np_int_dtype = np.int32
else:
np_int_dtype = np.int64
par.n = A.shape[0]
par.perm = np.empty(par.n, dtype=np.int32)
par.perm = np.empty(par.n, dtype=np_int_dtype)

par.maxfct = 1
par.mnum = 1
Expand Down Expand Up @@ -487,28 +489,21 @@ cdef class MKLPardisoSolver:
par.iparm[55] = 0 # Internal function used to work with pivot and calculation of diagonal arrays turned off.
par.iparm[59] = 0 # operate in-core mode

if _par_params is _PardisoParams:
indices = np.require(A.indices, dtype=np.int32)
indptr = np.require(A.indptr, dtype=np.int32)
else:
indices = np.require(A.indices, dtype=np.int64)
indptr = np.require(A.indptr, dtype=np.int64)

par.ia = indptr
par.ja = indices
par.ia = np.require(A.indices, dtype=np_int_dtype)
par.ja = np.require(A.indptr, dtype=np_int_dtype)

cdef _set_A(self, data):
self._Adata = data
self._Adata = np.ascontiguousarray(data)
self.a = np.PyArray_DATA(data)

def __dealloc__(self):
# Need to call pardiso with phase=-1 to release memory
cdef int_t phase=-1, nrhs=0, error=0
cdef long_t phase64=-1, nrhs64=0, error64=0
cdef MKL_INT phase=-1, nrhs=0, error=0
cdef MKL_INT64 phase64=-1, nrhs64=0, error64=0

if self._initialized:
PyThread_acquire_lock(self.lock, 1)
if self._is_32:
if self._call32:
pardiso(
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
&phase, &self._par.n, self.a, NULL, NULL, NULL, &nrhs, self._par.iparm,
Expand Down Expand Up @@ -544,7 +539,7 @@ cdef class MKLPardisoSolver:

self._factored = True

cdef _solve(self, void* b, void* x, int_t nrhs_in):
cdef _solve(self, void* b, void* x, MKL_INT nrhs_in):
#phase = 33
if(not self._factored):
raise PardisoError("Cannot solve without a previous factorization.")
Expand All @@ -554,12 +549,12 @@ cdef class MKLPardisoSolver:
raise PardisoError("Solve step error, "+_err_messages[err])

@cython.boundscheck(False)
cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0) nogil:
cdef int_t error=0
cdef long_t error64=0, phase64=phase, nrhs64=nrhs
cdef MKL_INT _run_pardiso(self, MKL_INT phase, void* b=NULL, void* x=NULL, MKL_INT nrhs=0) nogil:
cdef MKL_INT error=0
cdef MKL_INT64 error64=0, phase64=phase, nrhs64=nrhs

PyThread_acquire_lock(self.lock, 1)
if self._is_32:
if self._call32:
pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
&phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
&self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
Expand All @@ -568,5 +563,5 @@ cdef class MKLPardisoSolver:
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
PyThread_release_lock(self.lock)
error = error or error64
error = error or <MKL_INT> error64
return error

0 comments on commit 4ed000d

Please sign in to comment.