Skip to content

Commit

Permalink
Merge branch 'develop' into refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ErjieWu authored Jan 4, 2025
2 parents af61224 + 9ab9150 commit 59ebcb1
Show file tree
Hide file tree
Showing 17 changed files with 349 additions and 135 deletions.
207 changes: 183 additions & 24 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,101 @@
#include "module_base/global_variable.h"
#endif

#ifdef __CUDA
#include <base/macros/macros.h>
#include <cuda_runtime.h>
#include <thrust/complex.h>
#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include "module_base/tool_quit.h"

#include "cublas_v2.h"

namespace BlasUtils{

static cublasHandle_t cublas_handle = nullptr;

void createGpuBlasHandle(){
if (cublas_handle == nullptr) {
cublasErrcheck(cublasCreate(&cublas_handle));
}
}

void destoryBLAShandle(){
if (cublas_handle != nullptr) {
cublasErrcheck(cublasDestroy(cublas_handle));
cublas_handle = nullptr;
}
}


cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
if (trans == 'N')
{
return CUBLAS_OP_N;
}
else if(trans == 'T')
{
return CUBLAS_OP_T;
}
else if(is_complex && trans == 'C')
{
return CUBLAS_OP_C;
}
return CUBLAS_OP_N;
}

} // namespace BlasUtils

#endif

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
saxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
daxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
caxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
#endif
}
}

void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
#endif
}
}


Expand All @@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
#endif
}
}

void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
#endif
}
}


Expand All @@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return sdot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return ddot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Col-Major part
Expand All @@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Symm and Hemm part. Only col-major is supported.
Expand Down
Loading

0 comments on commit 59ebcb1

Please sign in to comment.