Skip to content

Commit

Permalink
Use cusparseXcsrsort instead of the deprecated cusparseXcsru2csr
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-sim-dev committed Sep 29, 2024
1 parent 8ed11a5 commit d4d149a
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions linalg/sparsemat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,42 +469,34 @@ void SparseMatrix::SortColumnIndices()
#if defined(MFEM_USE_CUDA)
size_t pBufferSizeInBytes = 0;
void *pBuffer = NULL;
int *P = NULL;

const int n = Height();
const int m = Width();
const int nnzA = J.Capacity();
real_t * d_a_sorted = ReadWriteData();
const int * d_ia = ReadI();
int * d_ja_sorted = ReadWriteJ();
csru2csrInfo_t sortInfoA;

cusparseMatDescr_t matA_descr;
cusparseCreateMatDescr( &matA_descr );
cusparseSetMatIndexBase( matA_descr, CUSPARSE_INDEX_BASE_ZERO );
cusparseSetMatType( matA_descr, CUSPARSE_MATRIX_TYPE_GENERAL );

cusparseCreateCsru2csrInfo( &sortInfoA );

#ifdef MFEM_USE_SINGLE
cusparseScsru2csr_bufferSizeExt( handle, n, m, nnzA, d_a_sorted, d_ia,
d_ja_sorted, sortInfoA,
&pBufferSizeInBytes);
#elif defined MFEM_USE_DOUBLE
cusparseDcsru2csr_bufferSizeExt( handle, n, m, nnzA, d_a_sorted, d_ia,
d_ja_sorted, sortInfoA,
&pBufferSizeInBytes);
#else
MFEM_ABORT("Floating point type undefined");
#endif
cusparseXcsrsort_bufferSizeExt(handle, n, m, nnzA, d_ia, d_ja_sorted,
&pBufferSizeInBytes);

CuMemAlloc( &pBuffer, pBufferSizeInBytes );

cusparseCreateIdentityPermutation(handle, nnzA, P);
cusparseXcsrsort(handle, n, m, nnzA, descrA, d_ia, d_ja_sorted, P, pBuffer);

#ifdef MFEM_USE_SINGLE
cusparseScsru2csr( handle, n, m, nnzA, matA_descr, d_a_sorted, d_ia,
d_ja_sorted, sortInfoA, pBuffer);
cusparseSgthr(handle, nnzA, d_a_sorted, d_a_sorted, P,
CUSPARSE_INDEX_BASE_ZERO);
#elif defined MFEM_USE_DOUBLE
cusparseDcsru2csr( handle, n, m, nnzA, matA_descr, d_a_sorted, d_ia,
d_ja_sorted, sortInfoA, pBuffer);
cusparseDgthr(handle, nnzA, d_a_sorted, d_a_sorted, P,
CUSPARSE_INDEX_BASE_ZERO);
#else
MFEM_ABORT("Floating point type undefined");
#endif
Expand All @@ -513,10 +505,10 @@ void SparseMatrix::SortColumnIndices()
// wait for it to finish before we can free device temporaries.
MFEM_STREAM_SYNC;

cusparseDestroyCsru2csrInfo( sortInfoA );
cusparseDestroyMatDescr( matA_descr );

CuMemFree( pBuffer );
CuMemFree( P );
#endif
}
else if ( Device::Allows( Backend::HIP_MASK ))
Expand Down

0 comments on commit d4d149a

Please sign in to comment.