Skip to content

Commit

Permalink
Merge branch 'develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
dzzz2001 authored Jan 10, 2025
2 parents 2acab4b + c898e52 commit e34e1b4
Show file tree
Hide file tree
Showing 90 changed files with 2,618 additions and 2,495 deletions.
19 changes: 10 additions & 9 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ VPATH=./src_global:\
./module_ri:\
./module_parameter:\
./module_lr:\
./module_lr/AX:\
./module_lr/ao_to_mo_transformer:\
./module_lr/dm_trans:\
./module_lr/operator_casida:\
./module_lr/potentials:\
Expand Down Expand Up @@ -189,23 +189,24 @@ OBJS_CELL=atom_pseudo.o\
check_atomic_stru.o\
update_cell.o\
bcast_cell.o\
read_stru.o\
read_atom_species.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_basic.o\
deepks_descriptor.o\
deepks_force.o\
deepks_fpre.o\
deepks_spre.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
deepks_vdelta.o\
deepks_vdpre.o\
deepks_hmat.o\
deepks_pdm.o\
deepks_phialpha.o\
LCAO_deepks_io.o\
LCAO_deepks_pdm.o\
LCAO_deepks_phialpha.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
LCAO_deepks_interface.o\
cal_gedm.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down Expand Up @@ -723,8 +724,8 @@ OBJS_TENSOR=tensor.o\

OBJS_LR=lr_util.o\
lr_util_hcontainer.o\
AX_parallel.o\
AX_serial.o\
ao_to_mo_parallel.o\
ao_to_mo_serial.o\
dm_trans_parallel.o\
dm_trans_serial.o\
dmr_complex.o\
Expand Down
129 changes: 125 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ namespace BlasUtils{
return CUBLAS_OP_N;
}

// Convert a BLAS-style side character into the matching cuBLAS side enum.
//   'R'             -> CUBLAS_SIDE_RIGHT
//   'L' / any other -> CUBLAS_SIDE_LEFT (left is the fallback)
cublasSideMode_t judge_side(const char& trans)
{
    const bool on_right = (trans == 'R');
    return on_right ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
}

// Convert a BLAS-style `uplo` character into the matching cuBLAS fill mode.
//   'U'        -> CUBLAS_FILL_MODE_UPPER
//   'L' or 'D' -> CUBLAS_FILL_MODE_LOWER
//   'F'        -> CUBLAS_FILL_MODE_FULL
//   any other  -> CUBLAS_FILL_MODE_FULL (fallback, unchanged from before)
//
// Reference BLAS names the lower triangle 'L'. Previously only 'D' was
// recognised, so an 'L' forwarded by the symm/hemm wrappers fell through to
// FULL instead of LOWER. 'D' is still accepted so existing callers keep working.
cublasFillMode_t judge_fill(const char& trans)
{
    switch (trans)
    {
    case 'U':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L': // standard BLAS spelling of "lower"
    case 'D': // legacy spelling kept for backward compatibility
        return CUBLAS_FILL_MODE_LOWER;
    case 'F':
    default:
        return CUBLAS_FILL_MODE_FULL;
    }
}

} // namespace BlasUtils

#endif
Expand Down Expand Up @@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
Expand All @@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
Expand All @@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
Expand All @@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag());
cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag());
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
#endif
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand All @@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag());
cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag());
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
#endif
}
}

// out = ||x||_2
Expand All @@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return snrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
return result;
#endif
}
return snrm2_( &n, X, &incX );
}

Expand All @@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dnrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
return result;
#endif
}
return dnrm2_( &n, X, &incX );
}

Expand All @@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dznrm2_( &n, X, &incX );
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
return result;
#endif
}
return dznrm2_( &n, X, &incX );
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
AddTest(
TARGET container_ops_uts
LIBS parameter ${math_libs}
SOURCES einsum_op_test.cpp linalg_op_test.cpp
SOURCES einsum_op_test.cpp linalg_op_test.cpp ../../kernels/lapack.cpp
)

target_link_libraries(container_ops_uts container base device)
2 changes: 2 additions & 0 deletions source/module_cell/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ add_library(
check_atomic_stru.cpp
update_cell.cpp
bcast_cell.cpp
read_stru.cpp
read_atom_species.cpp
)

if(ENABLE_COVERAGE)
Expand Down
12 changes: 4 additions & 8 deletions source/module_cell/module_neighbor/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,12 @@ class UcellTestPrepare
this->lmaxmax,
this->init_vel,
this->fixed_axes);
delete[] ucell->atom_label;
delete[] ucell->atom_mass;
delete[] ucell->pseudo_fn;
delete[] ucell->pseudo_type;

delete[] ucell->magnet.start_magnetization; //mag set here
ucell->atom_label = new std::string[ucell->ntype];
ucell->atom_mass = new double[ucell->ntype];
ucell->pseudo_fn = new std::string[ucell->ntype];
ucell->pseudo_type = new std::string[ucell->ntype];
ucell->atom_label.resize(ucell->ntype);
ucell->atom_mass.resize(ucell->ntype);
ucell->pseudo_fn.resize(ucell->ntype);
ucell->pseudo_type.resize(ucell->ntype);
ucell->orbital_fn.resize(ucell->ntype);
ucell->magnet.start_magnetization = new double[ucell->ntype]; //mag set here
ucell->magnet.ux_[0] = 0.0; // ux_ set here
Expand Down
Loading

0 comments on commit e34e1b4

Please sign in to comment.