From 414446ac18aa79385a143f9861f10e49575937ec Mon Sep 17 00:00:00 2001 From: liiutao <74701833+A-006@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:33:07 +0800 Subject: [PATCH 1/2] Refactor:remove cal_tau from ElecStateLCAO (#5802) * modify the cal_tau in lcao * add template for cal_tau * updatea func for cal_tau --- source/module_elecstate/elecstate_lcao.cpp | 5 ++- source/module_elecstate/elecstate_lcao.h | 1 - .../elecstate_lcao_cal_tau.cpp | 39 +++++++++++++------ .../module_elecstate/elecstate_lcao_cal_tau.h | 21 ++++++++++ source/module_esolver/esolver_ks_lcao.cpp | 6 ++- source/module_rdmft/update_state_rdmft.cpp | 6 +-- 6 files changed, 58 insertions(+), 20 deletions(-) create mode 100644 source/module_elecstate/elecstate_lcao_cal_tau.h diff --git a/source/module_elecstate/elecstate_lcao.cpp b/source/module_elecstate/elecstate_lcao.cpp index 9b2c945fd9..748ef7a9b8 100644 --- a/source/module_elecstate/elecstate_lcao.cpp +++ b/source/module_elecstate/elecstate_lcao.cpp @@ -8,6 +8,7 @@ #include "module_hamilt_lcao/module_gint/grid_technique.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_parameter/parameter.h" +#include "elecstate_lcao_cal_tau.h" #include @@ -64,7 +65,7 @@ void ElecStateLCAO>::psiToRho(const psi::Psical_tau(psi); + elecstate::lcao_cal_tau_k(gint_k, this->charge); } this->charge->renormalize_rho(); @@ -99,7 +100,7 @@ void ElecStateLCAO::psiToRho(const psi::Psi& psi) if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) { - this->cal_tau(psi); + elecstate::lcao_cal_tau_gamma(gint_gamma, this->charge); } this->charge->renormalize_rho(); diff --git a/source/module_elecstate/elecstate_lcao.h b/source/module_elecstate/elecstate_lcao.h index 4beeb017f0..c85f6e27e5 100644 --- a/source/module_elecstate/elecstate_lcao.h +++ b/source/module_elecstate/elecstate_lcao.h @@ -46,7 +46,6 @@ class ElecStateLCAO : public ElecState // virtual void psiToRho(const psi::Psi& psi) override; // return current 
electronic density rho, as a input for constructing Hamiltonian // const double* getRho(int spin) const override; - virtual void cal_tau(const psi::Psi& psi) override; // update charge density for next scf step // void getNewRho() override; diff --git a/source/module_elecstate/elecstate_lcao_cal_tau.cpp b/source/module_elecstate/elecstate_lcao_cal_tau.cpp index e6bd6561a0..c7d83bd1e9 100644 --- a/source/module_elecstate/elecstate_lcao_cal_tau.cpp +++ b/source/module_elecstate/elecstate_lcao_cal_tau.cpp @@ -1,41 +1,56 @@ #include "elecstate_lcao.h" - +#include "elecstate_lcao_cal_tau.h" #include "module_base/timer.h" namespace elecstate { // calculate the kinetic energy density tau, multi-k case -template <> -void ElecStateLCAO>::cal_tau(const psi::Psi>& psi) +void lcao_cal_tau_k(Gint_k* gint_k, + Charge* charge) { ModuleBase::timer::tick("ElecStateLCAO", "cal_tau"); for (int is = 0; is < PARAM.inp.nspin; is++) { - ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); + ModuleBase::GlobalFunc::ZEROS(charge->kin_r[is], charge->nrxx); } - Gint_inout inout1(this->charge->kin_r, Gint_Tools::job_type::tau, PARAM.inp.nspin); - this->gint_k->cal_gint(&inout1); + Gint_inout inout1(charge->kin_r, Gint_Tools::job_type::tau, PARAM.inp.nspin); + gint_k->cal_gint(&inout1); ModuleBase::timer::tick("ElecStateLCAO", "cal_tau"); return; } // calculate the kinetic energy density tau, gamma-only case -template <> -void ElecStateLCAO::cal_tau(const psi::Psi& psi) +void lcao_cal_tau_gamma(Gint_Gamma* gint_gamma, + Charge* charge) { ModuleBase::timer::tick("ElecStateLCAO", "cal_tau"); for (int is = 0; is < PARAM.inp.nspin; is++) { - ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); + ModuleBase::GlobalFunc::ZEROS(charge->kin_r[is], charge->nrxx); } - Gint_inout inout1(this->charge->kin_r, Gint_Tools::job_type::tau, PARAM.inp.nspin); - this->gint_gamma->cal_gint(&inout1); + Gint_inout inout1(charge->kin_r, Gint_Tools::job_type::tau, 
PARAM.inp.nspin); + gint_gamma->cal_gint(&inout1); ModuleBase::timer::tick("ElecStateLCAO", "cal_tau"); return; } -} \ No newline at end of file +template <> +void lcao_cal_tau(Gint_Gamma* gint_gamma, + Gint_k* gint_k, + Charge* charge) +{ + lcao_cal_tau_gamma(gint_gamma, charge); +} +template <> +void lcao_cal_tau>(Gint_Gamma* gint_gamma, + Gint_k* gint_k, + Charge* charge) +{ + lcao_cal_tau_k(gint_k, charge); +} + +} // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/elecstate_lcao_cal_tau.h b/source/module_elecstate/elecstate_lcao_cal_tau.h new file mode 100644 index 0000000000..c0cfbc078a --- /dev/null +++ b/source/module_elecstate/elecstate_lcao_cal_tau.h @@ -0,0 +1,21 @@ +#ifndef ELECSTATE_LCAO_CAL_TAU_H +#define ELECSTATE_LCAO_CAL_TAU_H +#include "module_elecstate/module_charge/charge.h" +#include "module_hamilt_lcao/module_gint/gint_gamma.h" +#include "module_hamilt_lcao/module_gint/gint_k.h" +namespace elecstate +{ + + void lcao_cal_tau_k(Gint_k* gint_k, + Charge* charge); + + void lcao_cal_tau_gamma(Gint_Gamma* gint_gamma, + Charge* charge); + + template + void lcao_cal_tau(Gint_Gamma* gint_gamma, + Gint_k* gint_k, + Charge* charge); + +} +#endif \ No newline at end of file diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 3ee810abc0..b9fb62e853 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -34,6 +34,7 @@ #include "module_base/global_function.h" #include "module_cell/module_neighbor/sltk_grid_driver.h" #include "module_elecstate/cal_ux.h" +#include "module_elecstate/elecstate_lcao_cal_tau.h" #include "module_elecstate/module_charge/symmetry_rho.h" #include "module_elecstate/occupy.h" #include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag @@ -927,8 +928,9 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep) // 1) calculate the kinetic energy density tau, 
sunliang 2024-09-18 if (PARAM.inp.out_elf[0] > 0) { - assert(this->psi != nullptr); - this->pelec->cal_tau(*(this->psi)); + elecstate::lcao_cal_tau(&(this->GG), + &(this->GK), + this->pelec->charge); } //! 2) call after_scf() of ESolver_KS diff --git a/source/module_rdmft/update_state_rdmft.cpp b/source/module_rdmft/update_state_rdmft.cpp index abe56d71c3..dc0398e8c9 100644 --- a/source/module_rdmft/update_state_rdmft.cpp +++ b/source/module_rdmft/update_state_rdmft.cpp @@ -8,7 +8,7 @@ #include "module_elecstate/module_dm/cal_dm_psi.h" #include "module_elecstate/module_dm/density_matrix.h" #include "module_elecstate/module_charge/symmetry_rho.h" - +#include "module_elecstate/elecstate_lcao_cal_tau.h" namespace rdmft { @@ -118,7 +118,7 @@ void RDMFT::update_charge(UnitCell& ucell) // } // Gint_inout inout1(charge->kin_r, Gint_Tools::job_type::tau); // GG->cal_gint(&inout1); - this->pelec->cal_tau(wfc); + elecstate::lcao_cal_tau_gamma(GG, charge); } charge->renormalize_rho(); @@ -148,7 +148,7 @@ void RDMFT::update_charge(UnitCell& ucell) // } // Gint_inout inout1(charge->kin_r, Gint_Tools::job_type::tau); // GK->cal_gint(&inout1); - this->pelec->cal_tau(wfc); + elecstate::lcao_cal_tau_k(GK, charge); } charge->renormalize_rho(); From 9ab9150e278f9c97e57778c5a4854c0287841a3b Mon Sep 17 00:00:00 2001 From: Critsium Date: Sat, 4 Jan 2025 04:35:14 -0500 Subject: [PATCH 2/2] [Feature] Add some GPU kernels to blas_connector (#5799) * Initial commit * Modify CMakeLists * Complete CMakeLists in module_base * Add blas_connector.cpp definition * Fix module_base tests * Fix tests failure * fix opt_test * OPTFIX2 * Return all changes * Fix global_func_text * Fix MPI Bug * return base_math_chebyshev * Fix MPI bug * Finish * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- source/module_base/blas_connector.cpp | 207 ++++++++++++++++-- source/module_base/test/CMakeLists.txt | 
76 +++---- .../test/clebsch_gordan_coeff_test.cpp | 12 - .../module_base/test/complexmatrix_test.cpp | 6 - .../module_base/test/global_function_test.cpp | 10 +- .../module_base/test/inverse_matrix_test.cpp | 6 - .../module_base/test/math_chebyshev_test.cpp | 4 + source/module_base/test/math_ylmreal_test.cpp | 10 - source/module_base/test/opt_CG_test.cpp | 46 +++- source/module_base/test/opt_TN_test.cpp | 26 ++- source/module_base/test/opt_test_tools.cpp | 3 + 11 files changed, 291 insertions(+), 115 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 85ea4584e9..3bb91e2f01 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -5,32 +5,101 @@ #include "module_base/global_variable.h" #endif +#ifdef __CUDA +#include +#include +#include +#include +#include +#include "module_base/tool_quit.h" + +#include "cublas_v2.h" + +namespace BlasUtils{ + + static cublasHandle_t cublas_handle = nullptr; + + void createGpuBlasHandle(){ + if (cublas_handle == nullptr) { + cublasErrcheck(cublasCreate(&cublas_handle)); + } + } + + void destoryBLAShandle(){ + if (cublas_handle != nullptr) { + cublasErrcheck(cublasDestroy(cublas_handle)); + cublas_handle = nullptr; + } + } + + /* Translate a BLAS transpose flag ('N', 'T' or 'C', either case) into the matching cuBLAS operation. */ + cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name) + { + if (trans == 'N' || trans == 'n') + { + return CUBLAS_OP_N; + } + else if(trans == 'T' || trans == 't' || (!is_complex && (trans == 'C' || trans == 'c'))) /* real data: conjugate-transpose is a plain transpose, as in reference BLAS */ + { + return CUBLAS_OP_T; + } + else if(is_complex && (trans == 'C' || trans == 'c')) + { + return CUBLAS_OP_C; + } + return CUBLAS_OP_N; /* NOTE(review): an unrecognised flag still falls back to CUBLAS_OP_N silently and `name` is unused - consider reporting it */ + } + +} // namespace BlasUtils + +#endif + void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { saxpy_(&n, &alpha, X, &incX, Y, &incY); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA +
cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); +#endif + } } void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { daxpy_(&n, &alpha, X, &incX, Y, &incY); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); +#endif + } } void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { caxpy_(&n, &alpha, X, &incX, Y, &incY); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY)); +#endif + } } void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { zaxpy_(&n, &alpha, X, &incX, Y, &incY); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY)); +#endif + } } @@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i { if (device_type == base_device::AbacusDevice_t::CpuDevice) { sscal_(&n, &alpha, X, &incX); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { +#ifdef __CUDA + cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); +#endif + } } void BlasConnector::scal( const int n, const double alpha, double *X, const 
int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { dscal_(&n, &alpha, X, &incX); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { +#ifdef __CUDA + cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); +#endif + } } void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { cscal_(&n, &alpha, X, &incX); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { +#ifdef __CUDA + cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX)); +#endif + } } void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { zscal_(&n, &alpha, X, &incX); -} + } + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { +#ifdef __CUDA + cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX)); +#endif + } } @@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo if (device_type == base_device::AbacusDevice_t::CpuDevice) { return sdot_(&n, X, &incX, Y, &incY); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + float result = 0.0; + cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); + return result; +#endif + } return sdot_(&n, X, &incX, Y, &incY); } @@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d if (device_type == base_device::AbacusDevice_t::CpuDevice) { return ddot_(&n, X, &incX, Y, &incY); } + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + double result = 0.0; + 
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); + return result; +#endif + } return ddot_(&n, X, &incX, Y, &incY); } @@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons &alpha, b, &ldb, a, &lda, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); /* row-major product is evaluated as the column-major product of the swapped operands, so the op flags swap with them (same order as the CPU/DSP calls) */ +#endif + } } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons &alpha, b, &ldb, a, &lda, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); /* op flags follow the swapped operands: b takes transb, a takes transa */ +#endif + } } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons &alpha, b, &ldb, a, &lda, &beta, c, &ldc); } -
#ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); /* complex data: 'C' must reach cuBLAS as CUBLAS_OP_C */ + cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc)); /* op flags follow the swapped operands: b takes transb, a takes transa */ +#endif + } } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons &alpha, b, &ldb, a, &lda, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); /* complex data: 'C' must reach cuBLAS as CUBLAS_OP_C */ + cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransB, cutransA, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc)); /* op flags follow the swapped operands: b takes transb, a takes transa */ +#endif + } } // Col-Major part @@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ sgemm_mth_(&transb, &transa, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef
__CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } } void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, @@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } } void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, @@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); /* complex data: 'C' must reach cuBLAS as CUBLAS_OP_C */ + cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda,
(float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); +#endif + } } void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, @@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } - #ifdef __DSP +#ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } - #endif +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(true, transa, "gemm_op"); /* complex data: 'C' must reach cuBLAS as CUBLAS_OP_C */ + cublasOperation_t cutransB = BlasUtils::judge_trans(true, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); +#endif + } } // Symm and Hemm part. Only col-major is supported.
diff --git a/source/module_base/test/CMakeLists.txt b/source/module_base/test/CMakeLists.txt index 09b77c7404..0c8fd53461 100644 --- a/source/module_base/test/CMakeLists.txt +++ b/source/module_base/test/CMakeLists.txt @@ -2,8 +2,8 @@ remove_definitions(-D__MPI) install(DIRECTORY data DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) AddTest( TARGET base_blas_connector - LIBS parameter ${math_libs} - SOURCES blas_connector_test.cpp ../blas_connector.cpp + LIBS parameter ${math_libs} base device + SOURCES blas_connector_test.cpp ) AddTest( TARGET base_atom_in @@ -31,8 +31,8 @@ AddTest( ) ADDTest( TARGET base_global_function - LIBS parameter ${math_libs} - SOURCES global_function_test.cpp ../blas_connector.cpp ../global_function.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp + LIBS parameter ${math_libs} + SOURCES global_function_test.cpp ../global_function.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../memory.cpp ../timer.cpp ) AddTest( TARGET base_vector3 @@ -41,8 +41,8 @@ AddTest( ) AddTest( TARGET base_matrix3 - LIBS parameter ${math_libs} - SOURCES matrix3_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp + LIBS parameter ${math_libs} base device + SOURCES matrix3_test.cpp ) AddTest( TARGET base_intarray @@ -56,8 +56,8 @@ AddTest( ) AddTest( TARGET base_matrix - LIBS parameter ${math_libs} - SOURCES matrix_test.cpp ../blas_connector.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp + LIBS parameter ${math_libs} base device + SOURCES matrix_test.cpp ) AddTest( TARGET base_complexarray @@ -66,8 +66,8 @@ AddTest( ) AddTest( TARGET base_complexmatrix - LIBS parameter ${math_libs} - SOURCES complexmatrix_test.cpp ../blas_connector.cpp ../complexmatrix.cpp ../matrix.cpp + LIBS parameter ${math_libs} base device + SOURCES 
complexmatrix_test.cpp ) AddTest( TARGET base_integral @@ -81,10 +81,8 @@ AddTest( ) AddTest( TARGET base_ylmreal - LIBS parameter ${math_libs} device - SOURCES math_ylmreal_test.cpp ../blas_connector.cpp ../math_ylmreal.cpp ../complexmatrix.cpp ../global_variable.cpp ../ylm.cpp ../realarray.cpp ../timer.cpp ../matrix.cpp ../vector3.h - ../parallel_reduce.cpp ../parallel_global.cpp ../parallel_comm.cpp ../parallel_common.cpp - ../memory.cpp ../libm/branred.cpp ../libm/sincos.cpp + LIBS parameter ${math_libs} base device + SOURCES math_ylmreal_test.cpp ../libm/branred.cpp ../libm/sincos.cpp ) AddTest( TARGET base_math_sphbes @@ -93,13 +91,13 @@ AddTest( ) AddTest( TARGET base_mathzone - LIBS parameter ${math_libs} - SOURCES mathzone_test.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp ../blas_connector.cpp + LIBS parameter ${math_libs} base device + SOURCES mathzone_test.cpp ) AddTest( TARGET base_mathzone_add1 - LIBS parameter ${math_libs} - SOURCES mathzone_add1_test.cpp ../blas_connector.cpp ../mathzone_add1.cpp ../math_sphbes.cpp ../matrix3.cpp ../matrix.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp + LIBS parameter ${math_libs} base device + SOURCES mathzone_add1_test.cpp ) AddTest( TARGET base_math_polyint @@ -108,8 +106,8 @@ AddTest( ) AddTest( TARGET base_gram_schmidt_orth - LIBS parameter ${math_libs} - SOURCES gram_schmidt_orth_test.cpp ../blas_connector.cpp ../gram_schmidt_orth.h ../gram_schmidt_orth-inl.h ../global_function.h ../math_integral.cpp + LIBS parameter ${math_libs} base device + SOURCES gram_schmidt_orth_test.cpp ) AddTest( TARGET base_math_bspline @@ -118,8 +116,8 @@ AddTest( ) AddTest( TARGET base_inverse_matrix - LIBS parameter ${math_libs} - SOURCES inverse_matrix_test.cpp ../blas_connector.cpp ../inverse_matrix.cpp ../complexmatrix.cpp ../matrix.cpp ../timer.cpp + LIBS parameter 
${math_libs} base device + SOURCES inverse_matrix_test.cpp ) AddTest( TARGET base_mymath @@ -134,26 +132,26 @@ AddTest( AddTest( TARGET base_math_chebyshev - LIBS parameter ${math_libs} device container - SOURCES math_chebyshev_test.cpp ../blas_connector.cpp ../math_chebyshev.cpp ../tool_quit.cpp ../global_variable.cpp ../timer.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../parallel_reduce.cpp + LIBS parameter ${math_libs} base device container + SOURCES math_chebyshev_test.cpp ) AddTest( TARGET base_lapack_connector - LIBS parameter ${math_libs} - SOURCES lapack_connector_test.cpp ../blas_connector.cpp ../lapack_connector.h + LIBS parameter ${math_libs} base device + SOURCES lapack_connector_test.cpp ) AddTest( TARGET base_opt_CG - LIBS parameter ${math_libs} - SOURCES opt_CG_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp + LIBS parameter ${math_libs} base device + SOURCES opt_CG_test.cpp opt_test_tools.cpp ) AddTest( TARGET base_opt_TN - LIBS parameter ${math_libs} - SOURCES opt_TN_test.cpp opt_test_tools.cpp ../blas_connector.cpp ../opt_CG.cpp ../opt_DCsrch.cpp ../global_variable.cpp ../parallel_reduce.cpp + LIBS parameter ${math_libs} base device + SOURCES opt_TN_test.cpp opt_test_tools.cpp ) AddTest( @@ -194,28 +192,26 @@ AddTest( AddTest( TARGET spherical_bessel_transformer - SOURCES spherical_bessel_transformer_test.cpp ../blas_connector.cpp ../spherical_bessel_transformer.cpp ../math_sphbes.cpp ../math_integral.cpp ../timer.cpp - LIBS parameter ${math_libs} + SOURCES spherical_bessel_transformer_test.cpp + LIBS parameter ${math_libs} base device ) AddTest( TARGET cubic_spline - SOURCES cubic_spline_test.cpp ../blas_connector.cpp ../cubic_spline.cpp - LIBS parameter ${math_libs} + SOURCES cubic_spline_test.cpp + LIBS parameter ${math_libs} base device ) AddTest( TARGET clebsch_gordan_coeff_test - SOURCES clebsch_gordan_coeff_test.cpp ../blas_connector.cpp 
../clebsch_gordan_coeff.cpp ../intarray.cpp ../realarray.cpp ../complexmatrix.cpp ../matrix.cpp ../timer.cpp - ../math_ylmreal.cpp ../global_variable.cpp ../ylm.cpp ../timer.cpp ../vector3.h ../parallel_reduce.cpp ../parallel_global.cpp ../parallel_comm.cpp ../parallel_common.cpp - ../memory.cpp ../libm/branred.cpp ../libm/sincos.cpp ../inverse_matrix.cpp ../lapack_connector.h - LIBS parameter ${math_libs} device + SOURCES clebsch_gordan_coeff_test.cpp + LIBS parameter ${math_libs} base device ) AddTest( TARGET assoc_laguerre_test - SOURCES assoc_laguerre_test.cpp ../blas_connector.cpp ../assoc_laguerre.cpp ../tool_quit.cpp ../global_variable.cpp ../global_file.cpp ../global_function.cpp ../memory.cpp ../timer.cpp - LIBS parameter ${math_libs} + SOURCES assoc_laguerre_test.cpp + LIBS parameter ${math_libs} base device ) AddTest( diff --git a/source/module_base/test/clebsch_gordan_coeff_test.cpp b/source/module_base/test/clebsch_gordan_coeff_test.cpp index 16efa091b5..888249765f 100644 --- a/source/module_base/test/clebsch_gordan_coeff_test.cpp +++ b/source/module_base/test/clebsch_gordan_coeff_test.cpp @@ -16,18 +16,6 @@ * - functions: gen_rndm_r and compute_ap */ -namespace ModuleBase -{ -void WARNING_QUIT(const std::string& file, const std::string& description) -{ - return; -} -void WARNING(const std::string& file, const std::string& description) -{ - return; -} -} // namespace ModuleBase - TEST(ClebschGordanTest, ClebschGordanExit) { int lmaxkb = -2; diff --git a/source/module_base/test/complexmatrix_test.cpp b/source/module_base/test/complexmatrix_test.cpp index 0adc52363a..da11fafcfd 100644 --- a/source/module_base/test/complexmatrix_test.cpp +++ b/source/module_base/test/complexmatrix_test.cpp @@ -38,12 +38,6 @@ * */ -//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37. 
-namespace ModuleBase -{ - void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);} -} - inline void EXPECT_COMPLEX_EQ(const std::complex& a,const std::complex& b) { EXPECT_DOUBLE_EQ(a.real(),b.real()); diff --git a/source/module_base/test/global_function_test.cpp b/source/module_base/test/global_function_test.cpp index 013396d6b1..05d4d70877 100644 --- a/source/module_base/test/global_function_test.cpp +++ b/source/module_base/test/global_function_test.cpp @@ -4,7 +4,6 @@ #include "module_parameter/parameter.h" #undef private #include "../vector3.h" -#include "../blas_connector.h" #include "../tool_quit.h" #include #include @@ -692,6 +691,9 @@ TEST_F(GlobalFunctionTest,MemAvailable) TEST_F(GlobalFunctionTest,BlockHere) { +#ifdef __MPI +#undef __MPI +#endif std::string output2; std::string block_in="111"; GlobalV::MY_RANK=1; @@ -706,6 +708,9 @@ TEST_F(GlobalFunctionTest,BlockHere) TEST_F(GlobalFunctionTest,BlockHere2) { +#ifdef __MPI +#undef __MPI +#endif std::string output2; std::string block_in="111"; GlobalV::MY_RANK=0; @@ -724,6 +729,9 @@ TEST_F(GlobalFunctionTest,BlockHere2) TEST_F(GlobalFunctionTest,BlockHere3) { +#ifdef __MPI +#undef __MPI +#endif std::string output2; std::string block_in="111"; GlobalV::MY_RANK=0; diff --git a/source/module_base/test/inverse_matrix_test.cpp b/source/module_base/test/inverse_matrix_test.cpp index a871f906cd..b88e556af1 100644 --- a/source/module_base/test/inverse_matrix_test.cpp +++ b/source/module_base/test/inverse_matrix_test.cpp @@ -16,12 +16,6 @@ * - computes the inverse of a dim*dim real matrix */ -//a mock function of WARNING_QUIT, to avoid the uncorrected call by matrix.cpp at line 37. 
-namespace ModuleBase -{ - void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);} -} - TEST(InverseMatrixComplexTest, InverseMatrixComplex) { int dim = 10; diff --git a/source/module_base/test/math_chebyshev_test.cpp b/source/module_base/test/math_chebyshev_test.cpp index 125dbdaeaa..a7ea215266 100644 --- a/source/module_base/test/math_chebyshev_test.cpp +++ b/source/module_base/test/math_chebyshev_test.cpp @@ -336,6 +336,8 @@ TEST_F(MathChebyshevTest, tracepolyA) TEST_F(MathChebyshevTest, checkconverge) { +#ifdef __MPI +#undef __MPI const int norder = 100; p_chetest = new ModuleBase::Chebyshev(norder); auto fun_sigma_y @@ -377,6 +379,8 @@ TEST_F(MathChebyshevTest, checkconverge) delete[] v; delete p_chetest; +#define __MPI +#endif } TEST_F(MathChebyshevTest, recurs) diff --git a/source/module_base/test/math_ylmreal_test.cpp b/source/module_base/test/math_ylmreal_test.cpp index c973d8cd28..891c948f7e 100644 --- a/source/module_base/test/math_ylmreal_test.cpp +++ b/source/module_base/test/math_ylmreal_test.cpp @@ -36,16 +36,6 @@ * */ - - -//mock functions of WARNING_QUIT and WARNING -namespace ModuleBase -{ - void WARNING_QUIT(const std::string &file,const std::string &description) {exit(1);} - void WARNING(const std::string &file,const std::string &description) {return ;} -} - - class YlmRealTest : public testing::Test { protected: diff --git a/source/module_base/test/opt_CG_test.cpp b/source/module_base/test/opt_CG_test.cpp index 4b324c7cbb..b8abeb5760 100644 --- a/source/module_base/test/opt_CG_test.cpp +++ b/source/module_base/test/opt_CG_test.cpp @@ -1,3 +1,6 @@ +#ifdef __MPI +#undef __MPI +#endif #include "gtest/gtest.h" #include "../opt_CG.h" #include "../opt_DCsrch.h" @@ -18,10 +21,10 @@ class CG_test : public testing::Test double residual = 10.; double tol = 1e-5; int final_iter = 0; - char *task = NULL; - double *Ap = NULL; - double *p = NULL; - double *x = NULL; + char *task = nullptr; + double *Ap = nullptr; + double *p = 
nullptr; + double *x = nullptr; void SetUp() { @@ -65,7 +68,8 @@ class CG_test : public testing::Test tools.le.get_Ap(tools.le.A, p, Ap); int ifPD = 0; step = cg.step_length(Ap, p, ifPD); - for (int i = 0; i < 3; ++i) x[i] += step * p[i]; + for (int i = 0; i < 3; ++i) { x[i] += step * p[i]; +} residual = cg.get_residual(); } } @@ -102,14 +106,16 @@ class CG_test : public testing::Test { tools.dfuncdx(x, gradient, func_label); residual = 0; - for (int i = 0; i<3 ;++i) residual += gradient[i] * gradient[i]; + for (int i = 0; i<3 ;++i) { residual += gradient[i] * gradient[i]; +} if (residual < tol) { final_iter = iter; break; } cg.next_direct(gradient, cg_label, p); - for (int i = 0; i < 3; ++i) temp_x[i] = x[i]; + for (int i = 0; i < 3; ++i) { temp_x[i] = x[i]; +} task[0] = 'S'; task[1] = 'T'; task[2] = 'A'; task[3] = 'R'; task[4] = 'T'; while (true) { @@ -118,7 +124,8 @@ class CG_test : public testing::Test ds.dcSrch(f, g, step, task); if (task[0] == 'F' && task[1] == 'G') { - for (int j = 0; j < 3; ++j) temp_x[j] = x[j] + step * p[j]; + for (int j = 0; j < 3; ++j) { temp_x[j] = x[j] + step * p[j]; +} continue; } else if (task[0] == 'C' && task[1] == 'O') @@ -134,7 +141,8 @@ class CG_test : public testing::Test break; } } - for (int i = 0; i < 3; ++i) x[i] += step * p[i]; + for (int i = 0; i < 3; ++i) { x[i] += step * p[i]; +} } delete[] temp_x; delete[] gradient; @@ -143,51 +151,71 @@ class CG_test : public testing::Test TEST_F(CG_test, Stand_Solve_LinearEq) { +#ifdef __MPI +#undef __MPI CG_Solve_LinearEq(); EXPECT_NEAR(x[0], 0.5, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], 1.6429086563584579739e-18, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 1.5, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 4); ASSERT_EQ(cg.get_iter(), 4); +#define __MPI +#endif } TEST_F(CG_test, PR_Solve_LinearEq) { +#ifdef __MPI +#undef __MPI Solve(1, 0); EXPECT_NEAR(x[0], 0.50000000000003430589, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], -3.4028335704761047964e-14, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 
1.5000000000000166533, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 3); ASSERT_EQ(cg.get_iter(), 3); +#define __MPI +#endif } TEST_F(CG_test, HZ_Solve_LinearEq) { +#ifdef __MPI +#undef __MPI Solve(2, 0); EXPECT_NEAR(x[0], 0.49999999999999944489, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], -9.4368957093138305936e-16, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 1.5000000000000011102, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 3); ASSERT_EQ(cg.get_iter(), 3); +#define __MPI +#endif } TEST_F(CG_test, PR_Min_Func) { +#ifdef __MPI +#undef __MPI Solve(1, 1); EXPECT_NEAR(x[0], 4.0006805979150792396, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], 2.0713759992720870429, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 9.2871067233169171118, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 18); ASSERT_EQ(cg.get_iter(), 18); +#define __MPI +#endif } TEST_F(CG_test, HZ_Min_Func) { +#ifdef __MPI +#undef __MPI Solve(2, 1); EXPECT_NEAR(x[0], 4.0006825378033568086, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], 2.0691732100663737803, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 9.2780872787668311474, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 18); ASSERT_EQ(cg.get_iter(), 18); +#define __MPI +#endif } // g++ -std=c++11 ../opt_CG.cpp ../opt_DCsrch.cpp ./CG_test.cpp ./test_tools.cpp -lgtest -lpthread -lgtest_main -o test.exe \ No newline at end of file diff --git a/source/module_base/test/opt_TN_test.cpp b/source/module_base/test/opt_TN_test.cpp index db523b53e9..1fc5b7f2d6 100644 --- a/source/module_base/test/opt_TN_test.cpp +++ b/source/module_base/test/opt_TN_test.cpp @@ -17,9 +17,9 @@ class TN_test : public testing::Test double tol = 1e-5; int final_iter = 0; int flag = 0; - char *task = NULL; - double *p = NULL; - double *x = NULL; + char *task = nullptr; + double *p = nullptr; + double *x = nullptr; void SetUp() { @@ -61,7 +61,8 @@ class TN_test : public testing::Test { tools.dfuncdx(x, gradient, func_label); residual = 0; - for (int i = 0; i<3 ;++i) residual += gradient[i] * gradient[i]; + for (int i = 0; i<3 ;++i) { residual += gradient[i] * gradient[i]; 
+} if (residual < tol) { final_iter = iter; @@ -75,7 +76,8 @@ class TN_test : public testing::Test { tn.next_direct(x, gradient, flag, p, &(tools.mf), &ModuleESolver::ESolver_OF::dfuncdx); } - for (int i = 0; i < 3; ++i) temp_x[i] = x[i]; + for (int i = 0; i < 3; ++i) { temp_x[i] = x[i]; +} task[0] = 'S'; task[1] = 'T'; task[2] = 'A'; task[3] = 'R'; task[4] = 'T'; while (true) { @@ -84,7 +86,8 @@ class TN_test : public testing::Test ds.dcSrch(f, g, step, task); if (task[0] == 'F' && task[1] == 'G') { - for (int j = 0; j < 3; ++j) temp_x[j] = x[j] + step * p[j]; + for (int j = 0; j < 3; ++j) { temp_x[j] = x[j] + step * p[j]; +} continue; } else if (task[0] == 'C' && task[1] == 'O') @@ -100,7 +103,8 @@ class TN_test : public testing::Test break; } } - for (int i = 0; i < 3; ++i) x[i] += step * p[i]; + for (int i = 0; i < 3; ++i) { x[i] += step * p[i]; +} } delete[] temp_x; delete[] gradient; @@ -110,20 +114,28 @@ class TN_test : public testing::Test TEST_F(TN_test, TN_Solve_LinearEq) { +#ifdef __MPI +#undef __MPI Solve(0); EXPECT_NEAR(x[0], 0.50000000000003430589, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], -3.4028335704761047964e-14, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 1.5000000000000166533, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 1); ASSERT_EQ(tn.get_iter(), 1); +#define __MPI +#endif } TEST_F(TN_test, TN_Min_Func) { +#ifdef __MPI +#undef __MPI Solve(1); EXPECT_NEAR(x[0], 4.0049968540891525137, DOUBLETHRESHOLD); EXPECT_NEAR(x[1], 2.1208751163987624722, DOUBLETHRESHOLD); EXPECT_NEAR(x[2], 9.4951527720891863993, DOUBLETHRESHOLD); ASSERT_EQ(final_iter, 6); ASSERT_EQ(tn.get_iter(), 6); +#define __MPI +#endif } \ No newline at end of file diff --git a/source/module_base/test/opt_test_tools.cpp b/source/module_base/test/opt_test_tools.cpp index 1c90b79bca..71e136b3ef 100644 --- a/source/module_base/test/opt_test_tools.cpp +++ b/source/module_base/test/opt_test_tools.cpp @@ -1,3 +1,6 @@ +#ifdef __MPI +#undef __MPI +#endif #include "./opt_test_tools.h" #include