From 84e63b6ca4318a787fe3330393829493c1aad0ee Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 02:46:32 +0000 Subject: [PATCH 01/49] remove Psi(const Psi& psi_in, const int nk_in, int nband_in); --- source/module_elecstate/cal_dm.h | 3 ++- .../module_elecstate/module_dm/cal_dm_psi.cpp | 3 ++- source/module_hamilt_general/operator.cpp | 3 ++- source/module_io/get_pchg_lcao.cpp | 6 +++-- source/module_io/write_dos_lcao.cpp | 4 ++- source/module_io/write_proj_band_lcao.cpp | 4 ++- source/module_psi/psi.cpp | 26 ------------------- source/module_psi/psi.h | 3 --- source/module_psi/test/psi_test.cpp | 2 +- 9 files changed, 17 insertions(+), 37 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5ac41aab9a..5344cabc1a 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,7 +27,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + // psi::Psi wg_wfc(wfc, 1); + psi::Psi wg_wfc(1, nbands_local, nbasis_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index 47fbfbf8c3..dc15f0635c 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,7 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + // psi::Psi wg_wfc(wfc, 1, ); + psi::Psi wg_wfc(1, nbands_local, nbasis_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..3ec209d0bf 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -156,7 +156,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const else if(hpsi_pointer == psi_pointer) { this->in_place = true; - this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(1, nbands_range, std::get<0>(info)->get_nbasis()); } else { diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 6e069fd017..4cd3b05024 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,7 +478,8 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - psi::Psi wg_wfc(*this->psi_gamma, 1); + // psi::Psi wg_wfc(*this->psi_gamma, 1); + psi::Psi wg_wfc(1, this->psi_gamma->get_nbands(), this->psi_gamma->get_nbasis()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { @@ -540,7 +541,8 @@ void IState_Charge::idmatrix(const int& ib, } this->psi_k->fix_k(ik); - psi::Psi> wg_wfc(*this->psi_k, 1); + // psi::Psi> wg_wfc(*this->psi_k, 1); + psi::Psi> wg_wfc(1, this->psi_k->get_nbands(), this->psi_k->get_nbasis()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index e475c77459..8f4c251f4c 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,7 +461,9 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + // psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index ccd7a0d4b0..18759f2e29 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -221,7 +221,9 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + // psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index fb8abc78cd..20940529f1 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -107,32 +107,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } -template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) -{ - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } - this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; - this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) - { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); - } -} - template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 6b374c8a70..2fe4f6cca6 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,9 +42,6 @@ class Psi // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); - // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); - // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index df22b5f885..fa3f357407 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -15,7 +15,7 @@ class TestPsi : public ::testing::Test const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); + // psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); psi::Psi>* psi_object5 = new psi::Psi>(psi_object31->get_pointer(), *psi_object31, ink, 0); }; From b15cd5c2413d02db421c336a8f4aad6c6949c04e Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 03:30:14 +0000 Subject: [PATCH 02/49] fix bug --- source/module_psi/psi.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 20940529f1..28b6f4c90b 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -208,8 +208,12 @@ template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); + + this->zero_out(); + this->nk = nks_in; this->nbands = nbands_in; this->nbasis = nbasis_in; From 9900bb77542ecb8bcd305a4f36e5243d50cc04b7 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 07:20:46 +0000 Subject: [PATCH 03/49] fix bug --- source/module_elecstate/cal_dm.h | 3 +- .../module_elecstate/module_dm/cal_dm_psi.cpp | 4 +-- source/module_hamilt_general/operator.cpp | 3 +- source/module_io/get_pchg_lcao.cpp | 4 +-- source/module_io/write_dos_lcao.cpp | 3 +- source/module_io/write_proj_band_lcao.cpp | 3 +- source/module_psi/psi.cpp | 32 +++++++++++++++++-- source/module_psi/psi.h | 4 +++ 8 files changed, 42 insertions(+), 14 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5344cabc1a..b28f685d20 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,8 +27,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - // psi::Psi wg_wfc(wfc, 1); - psi::Psi wg_wfc(1, nbands_local, nbasis_local); + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index dc15f0635c..cd868dcf9e 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,8 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - // psi::Psi wg_wfc(wfc, 1, ); - psi::Psi wg_wfc(1, nbands_local, nbasis_local); + + psi::Psi wg_wfc(wfc, 1, nbands_local); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 3ec209d0bf..a99e813e01 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -156,8 +156,7 @@ T* Operator::get_hpsi(const hpsi_info& info) const else if(hpsi_pointer == psi_pointer) { this->in_place = true; - // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); - this->hpsi = new psi::Psi(1, nbands_range, std::get<0>(info)->get_nbasis()); + this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); } else { diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 4cd3b05024..4b3013b581 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,8 +478,8 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - // psi::Psi wg_wfc(*this->psi_gamma, 1); - psi::Psi wg_wfc(1, this->psi_gamma->get_nbands(), this->psi_gamma->get_nbasis()); + + psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index 8f4c251f4c..df07cef1d6 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,8 +461,7 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - // psi::Psi> Dwfc(psi[0], 1); - psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + psi::Psi> Dwfc(*psi, 1, psi->get_nbands()); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index 18759f2e29..34843207b7 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -221,8 +221,7 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - // psi::Psi> Dwfc(psi[0], 1); - psi::Psi> Dwfc(1, psi->get_nbands(), psi->get_nbasis()); + psi::Psi> Dwfc(psi[0], 1, psi->get_nbands()); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 28b6f4c90b..85e2f416cb 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -107,6 +107,34 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } + +template +Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +{ + assert(nk_in <= psi_in.get_nk()); + if (nband_in == 0) + { + nband_in = psi_in.get_nbands(); + } + this->k_first = psi_in.get_k_first(); + this->device = psi_in.device; + this->resize(nk_in, nband_in, psi_in.get_nbasis()); + this->ngk = psi_in.ngk; + this->npol = psi_in.npol; + if (nband_in <= psi_in.get_nbands()) + { + // copy from Psi from psi_in(current_k, 0, 0), + // if size of k is 1, current_k in new Psi is psi_in.current_k + if (nk_in == 1) + { + // current_k for this Psi only keep the spin index same as the copied Psi + this->current_k = psi_in.get_current_k(); + } + synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + } +} + + template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { @@ -208,11 +236,11 @@ template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); - + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); - this->zero_out(); + // this->zero_out(); this->nk = nks_in; this->nbands = nbands_in; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 2fe4f6cca6..41ac645ce3 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,6 +42,10 @@ class Psi // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); + // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in + Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); + + // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); From a02b5d8d5eb139a9d5ff46e41441807c93db5722 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 25 Dec 2024 08:05:47 +0000 Subject: [PATCH 04/49] [pre-commit.ci lite] apply automatic fixes --- source/module_elecstate/cal_dm.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index b28f685d20..13f41bf455 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -41,7 +41,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); @@ -99,7 +100,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); std::complex* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); From 8e3a58fd7c60f60524f8301d74abb3645fefce47 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 16:37:04 +0800 Subject: [PATCH 05/49] remove device value in psi --- source/module_psi/psi.cpp | 21 +++++++-------------- source/module_psi/psi.h | 15 ++++----------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 85e2f416cb..7bd2996808 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -32,7 +32,6 @@ template Psi::Psi() { this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); } template @@ -52,8 +51,9 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); + this->resize(nk_in, nbd_in, nbs_in); + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); base_device::information::record_device_memory(this->ctx, @@ -76,7 +76,6 @@ Psi::Psi(T* psi_pointer, this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; @@ -96,7 +95,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; @@ -111,13 +109,10 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int template Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) { - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } + assert(nk_in <= psi_in.get_nk() && nk_in > 0); + assert(nband_in <= psi_in.get_nbands() && nband_in > 0); + this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; this->resize(nk_in, nband_in, psi_in.get_nbasis()); this->ngk = psi_in.ngk; this->npol = psi_in.npol; @@ -139,8 +134,6 @@ template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { this->k_first = psi_in.get_k_first(); - this->device = base_device::get_device_type(this->ctx); - assert(this->device == psi_in.device); assert(nk_in <= psi_in.get_nk()); if (nband_in == 0) { @@ -168,7 +161,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); base_device::memory::synchronize_memory_op()(this->ctx, psi_in.get_device(), @@ -193,7 +186,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); // Specifically, if the Device_in type is CPU and the Device type is GPU. diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 41ac645ce3..042fd865d7 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -43,9 +43,8 @@ class Psi Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); - - + Psi(const Psi& psi_in, const int nk_in, const int nband_in); + // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() // in this case, fix_k can not be used Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); @@ -69,13 +68,8 @@ class Psi // Constructor 8-2: a pointer version of constructor 3 // only used in operator.cpp call_act func - Psi(T* psi_pointer, - const int nk_in, - const int nbd_in, - const int nbs_in, - const bool k_first_in); + Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in); - // Destructor for deleting the psi array manually ~Psi(); @@ -141,8 +135,7 @@ class Psi private: T* psi = nullptr; // avoid using C++ STL - - base_device::AbacusDevice_t device = {}; // track the device type (CPU, GPU and SYCL are supported currented) + Device* ctx = {}; // an context identifier for obtaining the device variable // dimensions From c716bb7e15ef8f62347c765b5495a81cbbb043cd Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 17:44:17 +0800 Subject: [PATCH 06/49] update Psi(const Psi& psi_in, const int nk_in, int nband_in) --- source/module_psi/psi.cpp | 39 ++++++++++++++++++++++++++------------- source/module_psi/psi.h | 1 + 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 7bd2996808..cf2a926c22 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -105,7 +105,6 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } - template Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) { @@ -113,23 +112,37 @@ Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) assert(nband_in <= psi_in.get_nbands() && nband_in > 0); this->k_first = psi_in.get_k_first(); - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) + this->allocate_inside = true; + + this->nk = nk_in; + this->nbands = nband_in; + this->nbasis = psi_in.get_nbasis(); + + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, + this->psi, + (static_cast(this->nk) * static_cast(this->nbands) + * static_cast(this->nbasis)), + "no_record"); + synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = this->nbasis; + this->psi_current = this->psi; + this->psi_bias = 0; + + if (this->nk != psi_in.get_nk()) { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); + this->ngk = nullptr; + } + else + { + this->ngk = psi_in.ngk; } } - template Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 042fd865d7..88f143df2e 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -131,6 +131,7 @@ class Psi // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; + int npol = 1; private: From 1c2f523affda8da23db65558b88266951ee143f2 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 12:21:23 +0000 Subject: [PATCH 07/49] update get_ngk usage --- source/module_hamilt_general/operator.cpp | 2 +- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 2 +- source/module_psi/psi.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..db54852331 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -73,7 +73,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node); + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), is_first_node); break; } }; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 858e6b3fd5..1a1196b864 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -47,7 +47,7 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_ngk(this->ik); + const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index cf2a926c22..f4ba33eedd 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -106,7 +106,7 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int } template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) { assert(nk_in <= psi_in.get_nk() && nk_in > 0); assert(nband_in <= psi_in.get_nbands() && nband_in > 0); From 1fb8851c4e7be7e5494c4dc41031867181dd617b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 25 Dec 2024 13:33:37 +0000 Subject: [PATCH 08/49] fix bug about ngk --- source/module_hamilt_general/operator.cpp | 5 +++++ source/module_hsolver/hsolver_pw.cpp | 27 ++++++++++++++--------- source/module_psi/psi.cpp | 3 ++- source/module_psi/psi.h | 1 + 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index db54852331..6f7a7cc77e 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -65,6 +65,11 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); + + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + switch (op->get_act_type()) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 0c1ad2e8b8..97f32aa587 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,10 +378,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } + const int cur_nbasis = psi.get_current_nbas(); + if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -391,12 +393,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -415,7 +419,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -426,7 +430,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector); + ngk_vector, + cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); @@ -486,11 +491,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -507,11 +512,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -558,11 +563,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f4ba33eedd..3fcb347790 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -69,6 +69,7 @@ Psi::Psi(T* psi_pointer, const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in) { this->k_first = k_first_in; @@ -79,7 +80,7 @@ Psi::Psi(T* psi_pointer, this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; + this->current_nbasis = current_nbasis_in; this->psi_current = this->psi = psi_pointer; this->allocate_inside = false; // Currently only GPU's implementation is supported for device recording! diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 88f143df2e..860112f066 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -64,6 +64,7 @@ class Psi const int nbd_in, const int nbs_in, const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in = true); // Constructor 8-2: a pointer version of constructor 3 From 1a9aea99c3278074504e01c3e1b66b6e825a0aa9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 25 Dec 2024 14:18:30 +0000 Subject: [PATCH 09/49] [pre-commit.ci lite] apply automatic fixes --- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 1a1196b864..94a671372b 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -23,7 +23,8 @@ Velocity::Velocity ModuleBase::WARNING_QUIT("Velocity", "Constuctor of Operator::Velocity is failed, please check your code!"); } this->tpiba = ucell_in -> tpiba; - if(this->nonlocal) this->ppcell->initgradq_vnl(*this->ucell); + if(this->nonlocal) { this->ppcell->initgradq_vnl(*this->ucell); +} } void Velocity::init(const int ik_in) From 9a3a9f07f1a89a050e6f76547428ba8262a16d8b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 11:13:38 +0800 Subject: [PATCH 10/49] fix bug --- source/module_hamilt_general/operator.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 6f7a7cc77e..f36691b950 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -66,9 +66,11 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - // std::cout << "op->ik : " << op->ik << std::endl; - // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + std::cout << "op->ik : " << op->ik << std::endl; + std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + + std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; @@ -78,7 +80,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), is_first_node); + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; } }; From 093c3f213ec9d943d800d1b5ca3c29e1024c011b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 11:30:52 +0800 Subject: [PATCH 11/49] format operator --- source/module_hamilt_general/operator.cpp | 109 ++++++++++++---------- source/module_hamilt_general/operator.h | 76 ++++++++------- 2 files changed, 103 insertions(+), 82 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index f36691b950..a8c95955f7 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -4,28 +4,31 @@ using namespace hamilt; - -template -Operator::Operator(){} - -template -Operator::~Operator() +template +Operator::Operator() { - if(this->hpsi != nullptr) { delete this->hpsi; } + +template +Operator::~Operator() +{ + if (this->hpsi != nullptr) + { + delete this->hpsi; + } Operator* last = this->next_op; Operator* last_sub = this->next_sub_op; - while(last != nullptr || last_sub != nullptr) + while (last != nullptr || last_sub != nullptr) { - if(last_sub != nullptr) - {//delete sub_chain first + if (last_sub != nullptr) + { // delete sub_chain first Operator* node_delete = last_sub; last_sub = last_sub->next_sub_op; node_delete->next_sub_op = nullptr; delete node_delete; } else - {//delete main chain if sub_chain is deleted + { // delete main chain if sub_chain is deleted Operator* node_delete = last; last_sub = last->next_sub_op; node_delete->next_sub_op = nullptr; @@ -36,7 +39,7 @@ Operator::~Operator() } } -template +template typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { using syncmem_op = base_device::memory::synchronize_memory_op; @@ -46,12 +49,12 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp T* tmhpsi = this->get_hpsi(input); const T* tmpsi_in = std::get<0>(psi_info); - //if range in hpsi_info is illegal, the first return of to_range() would be nullptr + // if range in hpsi_info is illegal, the first return of to_range() would be nullptr if (tmpsi_in == nullptr) { ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); } - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return + // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return T* hpsi_pointer = std::get<2>(input); if (this->in_place) { @@ -62,28 +65,31 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { - // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - std::cout << "op->ik : " << op->ik << std::endl; - std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + + // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - - - switch (op->get_act_type()) { case 2: op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas() / psi_input->npol, is_first_node); + op->act(nbands, + psi_input->get_nbasis(), + psi_input->npol, + tmpsi_in, + this->hpsi->get_pointer(), + psi_input->get_current_nbas() / psi_input->npol, + is_first_node); break; } - }; + }; ModuleBase::timer::tick("Operator", "hPsi"); call_act(this, true); // first node @@ -98,39 +104,43 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } - -template -void Operator::init(const int ik_in) +template +void Operator::init(const int ik_in) { this->ik = ik_in; - if(this->next_op != nullptr) { + if (this->next_op != nullptr) + { this->next_op->init(ik_in); } } -template -void Operator::add(Operator* next) +template +void Operator::add(Operator* next) { - if(next==nullptr) { return; -} + if (next == nullptr) + { + return; + } next->is_first_node = false; - if(next->next_op != nullptr) { this->add(next->next_op); -} + if (next->next_op != nullptr) + { + this->add(next->next_op); + } Operator* last = this; - //loop to end of the chain - while(last->next_op != nullptr) + // loop to end of the chain + while (last->next_op != nullptr) { - if(next->cal_type==last->cal_type) + if (next->cal_type == last->cal_type) { break; } last = last->next_op; } - if(next->cal_type == last->cal_type) + if (next->cal_type == last->cal_type) { - //insert next to sub chain of current node + // insert next to sub chain of current node Operator* sub_last = last; - while(sub_last->next_sub_op != nullptr) + while (sub_last->next_sub_op != nullptr) { sub_last = sub_last->next_sub_op; } @@ -143,24 +153,24 @@ void Operator::add(Operator* next) } } -template +template T* Operator::get_hpsi(const hpsi_info& info) const { const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1); - //in_place call of hPsi, hpsi inputs as new psi, - //create a new hpsi and delete old hpsi later + // in_place call of hPsi, hpsi inputs as new psi, + // create a new hpsi and delete old hpsi later T* hpsi_pointer = std::get<2>(info); const T* psi_pointer = std::get<0>(info)->get_pointer(); - if(this->hpsi != nullptr) + if (this->hpsi != nullptr) { delete this->hpsi; this->hpsi = nullptr; } - if(!hpsi_pointer) + if (!hpsi_pointer) { ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr"); } - else if(hpsi_pointer == psi_pointer) + else if (hpsi_pointer == psi_pointer) { this->in_place = true; this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); @@ -170,7 +180,7 @@ T* Operator::get_hpsi(const hpsi_info& info) const this->in_place = false; this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); } - + hpsi_pointer = this->hpsi->get_pointer(); size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis(); // ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size); @@ -179,7 +189,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const return hpsi_pointer; } -namespace hamilt { +namespace hamilt +{ template class Operator; template class Operator, base_device::DEVICE_CPU>; template class Operator; @@ -190,4 +201,4 @@ template class Operator, base_device::DEVICE_GPU>; template class Operator; template class Operator, base_device::DEVICE_GPU>; #endif -} +} // namespace hamilt diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 6cf29122fe..80ed065ccc 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -1,19 +1,19 @@ #ifndef OPERATOR_H #define OPERATOR_H -#include - #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_psi/psi.h" +#include + namespace hamilt { enum class calculation_type { no, - pw_ekinetic, + pw_ekinetic, pw_nonlocal, pw_veff, pw_meta, @@ -28,49 +28,54 @@ enum class calculation_type lcao_tddft_velocity, }; -// Basic class for operator module, +// Basic class for operator module, // it is designed for "O|psi>" and "" // Operator "O" might have several different types, which should be calculated one by one. // In basic class , function add() is designed for combine all operators together with a chain. template class Operator { - public: + public: Operator(); virtual ~Operator(); - //this is the core function for Operator - // do H|psi> from input |psi> , + // this is the core function for Operator + // do H|psi> from input |psi> , /// as default, different operators donate hPsi independently - /// run this->act function for the first operator and run all act() for other nodes in chain table + /// run this->act function for the first operator and run all act() for other nodes in chain table /// if this procedure is not suitable for your operator, just override this function. - /// output of hpsi would be first member of the returned tuple + /// output of hpsi would be first member of the returned tuple typedef std::tuple*, const psi::Range, T*> hpsi_info; - virtual hpsi_info hPsi(hpsi_info& input)const; + + virtual hpsi_info hPsi(hpsi_info& input) const; virtual void init(const int ik_in); virtual void add(Operator* next); - virtual int get_ik() const { return this->ik; } + virtual int get_ik() const + { + return this->ik; + } - ///do operation : |hpsi_choosed> = V|psi_choosed> - ///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi - /// interface type 1: pointer-only (default) - /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! + /// do operation : |hpsi_choosed> = V|psi_choosed> + /// V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + /// interface type 1: pointer-only (default) + /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! virtual void act(const int nbands, - const int nbasis, - const int npol, - const T* tmpsi_in, - T* tmhpsi, - const int ngk_ik = 0, - const bool is_first_node = false)const {}; + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const {}; /// developer-friendly interfaces for act() function /// interface type 2: input and change the Psi-type HPsi // virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out) const {}; virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out, const int nbands) const {}; + /// interface type 3: return a Psi-type HPsi // virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; @@ -78,36 +83,41 @@ class Operator /// type 1 (default): pointer-only /// act(const T* psi_in, T* psi_out) - /// type 2: use the `Psi`class + /// type 2: use the `Psi`class /// act(const Psi& psi_in, Psi& psi_out) - int get_act_type() const { return this->act_type; } -protected: + int get_act_type() const + { + return this->act_type; + } + + protected: int ik = 0; - int act_type = 1; ///< determine which act() interface would be called in hPsi() + int act_type = 1; ///< determine which act() interface would be called in hPsi() mutable bool in_place = false; - //calculation type, only different type can be in main chain table + // calculation type, only different type can be in main chain table enum calculation_type cal_type; Operator* next_sub_op = nullptr; bool is_first_node = true; - //if this Operator is first node in chain table, hpsi would not be empty + // if this Operator is first node in chain table, hpsi would not be empty mutable psi::Psi* hpsi = nullptr; /*This function would analyze hpsi_info and choose how to arrange hpsi storage In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer; - if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. + if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. two cases would occurred: - 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory - 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case + 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary + memory + 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case */ - T* get_hpsi(const hpsi_info& info)const; + T* get_hpsi(const hpsi_info& info) const; - Device *ctx = {}; + Device* ctx = {}; using set_memory_op = base_device::memory::set_memory_op; }; -}//end namespace hamilt +} // end namespace hamilt #endif \ No newline at end of file From af1b7bc7028155c868ac11038a8e22c60088e1c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Thu, 26 Dec 2024 04:11:55 +0000 Subject: [PATCH 12/49] [pre-commit.ci lite] apply automatic fixes --- source/module_io/write_proj_band_lcao.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index 34843207b7..b5660f7da5 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -25,8 +25,9 @@ void ModuleIO::write_proj_band_lcao( const double* sk = dynamic_cast*>(p_ham)->getSk(); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -103,14 +104,16 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; out << "" << PARAM.inp.nspin << "" << std::endl; - if (PARAM.inp.nspin == 4) + if (PARAM.inp.nspin == 4) { out << "" << std::setw(2) << PARAM.globalv.nlocal / 2 << "" << std::endl; - else + } else { out << "" << std::setw(2) << PARAM.globalv.nlocal << "" << std::endl; +} out << "" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; out << "" << std::endl; @@ -139,9 +142,9 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; for (int ib = 0; ib < PARAM.inp.nbands; ib++) { - if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) { out << std::setw(13) << weight(is, ib * PARAM.globalv.nlocal + w); - else if (PARAM.inp.nspin == 4) + } else if (PARAM.inp.nspin == 4) { int w0 = w - s0; out << std::setw(13) @@ -178,8 +181,9 @@ void ModuleIO::write_proj_band_lcao( ModuleBase::timer::tick("ModuleIO", "write_proj_band_lcao"); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -302,8 +306,9 @@ void ModuleIO::write_proj_band_lcao( for (int ik = 0; ik < nks; ik++) { - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(ik + is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; } out << "" << std::endl; From 16687c3baf5a7e18ce079419ec5019ea52b0fd0f Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 07:03:05 +0000 Subject: [PATCH 13/49] fix bug --- source/module_hamilt_general/operator.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a8c95955f7..dce5335db4 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -74,6 +74,10 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; + // std::cout << "psi_input->npol : " << psi_input->npol << std::endl; + + + switch (op->get_act_type()) { case 2: @@ -85,7 +89,10 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_current_nbas() / psi_input->npol, + psi_input->get_ngk(op->ik), + // 0, + // psi_input->get_current_nbas(), + // psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; } From 35d26d65720b0e2b7ea8222cbe0556897ece5902 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 15:55:12 +0800 Subject: [PATCH 14/49] fix bug --- source/module_hamilt_general/operator.cpp | 4 ++-- source/module_hsolver/hsolver_pw.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index dce5335db4..fbd1525805 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -89,9 +89,9 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_ngk(op->ik), + // psi_input->get_ngk(op->ik), // 0, - // psi_input->get_current_nbas(), + psi_input->get_current_nbas(), // psi_input->get_current_nbas() / psi_input->npol, is_first_node); break; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 97f32aa587..a885296c62 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -378,7 +378,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ngk_vector[i] = ngk_pointer[i]; } - const int cur_nbasis = psi.get_current_nbas(); + const int cur_nbasis = psi.get_ngk(psi.get_current_k()); if (this->method == "cg") { From 3096085262c07e1edb4de90031c89e29481df527 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 09:13:00 +0000 Subject: [PATCH 15/49] fix bug --- source/module_hamilt_general/operator.cpp | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index fbd1525805..4a830489f4 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -68,16 +68,6 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - // std::cout << "op->ik : " << op->ik << std::endl; - // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; - - // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - - // std::cout << "psi_input->npol : " << psi_input->npol << std::endl; - - - switch (op->get_act_type()) { case 2: @@ -89,10 +79,8 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - // psi_input->get_ngk(op->ik), - // 0, - psi_input->get_current_nbas(), - // psi_input->get_current_nbas() / psi_input->npol, + psi_input->get_ngk(op->ik), + // psi_input->get_current_nbas(), is_first_node); break; } From a3817e4983c0147697ab4f46c066ec24a97678b6 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 26 Dec 2024 10:06:41 +0000 Subject: [PATCH 16/49] fix bug --- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 94a671372b..0d49cadaa0 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -48,7 +48,7 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_current_nbas(); + const int npw = psi_in->get_ngk(this->ik); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; From 67fda40768601950c3b7ad75f46efb0c88f2f690 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 27 Dec 2024 03:06:42 +0000 Subject: [PATCH 17/49] add get_cur_effective_basis func --- source/module_elecstate/elecstate_pw.cpp | 4 ++-- .../module_elecstate/elecstate_pw_cal_tau.cpp | 2 +- .../hamilt_stodft/sto_iter.cpp | 2 +- source/module_psi/psi.cpp | 20 +++++++++++++++++++ source/module_psi/psi.h | 2 ++ 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index a22c87bcf7..08ba4750bd 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -154,7 +154,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) this->init_rho_data(); int ik = psi.get_current_k(); - int npw = psi.get_current_nbas(); + int npw = psi.get_cur_effective_basis(); int current_spin = 0; if (PARAM.inp.nspin == 2) { @@ -258,7 +258,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) psi.fix_k(ik); const T* psi_now = psi.get_pointer(); const int currect_spin = this->klist->isk[ik]; - const int npw = psi.get_current_nbas(); + const int npw = psi.get_cur_effective_basis(); // get |beta> if (this->ppcell->nkb > 0) diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp index fd07f834af..451aa9688a 100644 --- a/source/module_elecstate/elecstate_pw_cal_tau.cpp +++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp @@ -15,7 +15,7 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) for (int ik = 0; ik < psi.get_nk(); ++ik) { psi.fix_k(ik); - int npw = psi.get_current_nbas(); + int npw = psi.get_cur_effective_basis(); int current_spin = 0; if (PARAM.inp.nspin == 2) { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index bcfbd2da61..407879d24f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -60,7 +60,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, if (PARAM.inp.nbands > 0) { const int nchipk = stowf.nchip[ik]; - const int npw = psi.get_current_nbas(); + const int npw = psi.get_cur_effective_basis(); const int npwx = psi.get_nbasis(); stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 3fcb347790..cbb0d4cf34 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -295,6 +295,26 @@ const int& Psi::get_psi_bias() const return this->psi_bias; } +template +const int& Psi::get_cur_effective_basis() const +{ + if (this->npol == 1) + { + if (this->ngk != nullptr) + { + return this->ngk[this->current_k]; + } + else + { + return this->current_nbasis; + } + } + else + { + return this->nbasis; + } +} + template const int& Psi::get_nk() const { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 860112f066..54bcd31fa2 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -127,6 +127,8 @@ class Psi // return psi_bias const int& get_psi_bias() const; + const int& get_cur_effective_basis() const; + // mark void zero_out(); From 0b0604c5c945faa55407bb7585ce61097acbab28 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 27 Dec 2024 06:05:47 +0000 Subject: [PATCH 18/49] fix bug --- source/module_psi/psi.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index cbb0d4cf34..be045742af 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -300,11 +300,11 @@ const int& Psi::get_cur_effective_basis() const { if (this->npol == 1) { - if (this->ngk != nullptr) - { - return this->ngk[this->current_k]; - } - else + // if (this->ngk != nullptr) + // { + // return this->ngk[this->current_k]; + // } + // else { return this->current_nbasis; } From d5634b3af6fc4b044a6f8239878179cc3c0e7a19 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 27 Dec 2024 15:02:10 +0800 Subject: [PATCH 19/49] update get_cur_effective_basis --- source/module_hsolver/diago_iter_assist.cpp | 10 +++++----- source/module_hsolver/hsolver_pw.cpp | 6 +++--- source/module_hsolver/test/diago_cg_float_test.cpp | 4 ++-- source/module_hsolver/test/diago_cg_real_test.cpp | 4 ++-- source/module_hsolver/test/diago_cg_test.cpp | 4 ++-- source/module_hsolver/test/diago_david_float_test.cpp | 2 +- source/module_hsolver/test/diago_david_real_test.cpp | 2 +- source/module_hsolver/test/diago_david_test.cpp | 2 +- source/module_psi/psi.cpp | 9 +-------- 9 files changed, 18 insertions(+), 25 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 5ec443ab4e..29cd923c0d 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -49,7 +49,7 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt* setmem_complex_op()(ctx, scc, 0, nstart * nstart); setmem_complex_op()(ctx, vcc, 0, nstart * nstart); - const int dmin = psi.get_current_nbas(); + const int dmin = psi.get_cur_effective_basis(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -167,7 +167,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* const int nstart = psi_nr; const int n_band = evc.get_nbands(); const int dmax = evc.get_nbasis(); - const int dmin = evc.get_current_nbas(); + const int dmin = evc.get_cur_effective_basis(); // skip the diagonalization if the operators are not allocated if (pHamilt->ops == nullptr) @@ -264,7 +264,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* spsi = temp; // do sPsi for all bands - pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_current_nbas(), psi_temp.get_nbands()); + pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_cur_effective_basis(), psi_temp.get_nbands()); gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); delmem_complex_op()(ctx, temp); @@ -423,7 +423,7 @@ void DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt setmem_complex_op()(ctx, hcc, 0, nstart * nstart); setmem_complex_op()(ctx, scc, 0, nstart * nstart); - const int dmin = psi.get_current_nbas(); + const int dmin = psi.get_cur_effective_basis(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -549,7 +549,7 @@ void DiagoIterAssist::diag_subspace_psi(const T* hcc, DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc); { // code block to calculate tar_mat - const int dmin = evc.get_current_nbas(); + const int dmin = evc.get_cur_effective_basis(); const int dmax = evc.get_nbasis(); T* temp = nullptr; resmem_complex_op()(ctx, temp, nstart * dmax, "DiagSub::temp"); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index a885296c62..a9d48ee769 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -480,7 +480,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ct::DeviceTypeToEnum::value, ct::TensorShape({static_cast(pre_condition.size())})) .to_device() - .slice({0}, {psi.get_current_nbas()}); + .slice({0}, {psi.get_cur_effective_basis()}); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor); // TODO: Double check tensormap's potential problem @@ -530,7 +530,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, Diago_DavSubspace dav_subspace(pre_condition, psi.get_nbands(), - psi.get_k_first() ? psi.get_current_nbas() + psi.get_k_first() ? psi.get_cur_effective_basis() : psi.get_nk() * psi.get_nbasis(), PARAM.inp.pw_diag_ndim, this->diag_thr, @@ -556,7 +556,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int david_maxiter = this->diag_iter_max; // dimensions of matrix to be solved - const int dim = psi.get_current_nbas(); /// dimension of matrix + const int dim = psi.get_cur_effective_basis(); /// dimension of matrix const int nband = psi.get_nbands(); /// number of eigenpairs sought const int ld_psi = psi.get_nbasis(); /// leading dimension of psi diff --git a/source/module_hsolver/test/diago_cg_float_test.cpp b/source/module_hsolver/test/diago_cg_float_test.cpp index 47fac4ef01..0500424b92 100644 --- a/source/module_hsolver/test/diago_cg_float_test.cpp +++ b/source/module_hsolver/test/diago_cg_float_test.cpp @@ -182,7 +182,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_FLOAT, @@ -192,7 +192,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_FLOAT, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_real_test.cpp b/source/module_hsolver/test/diago_cg_real_test.cpp index 97872c316d..f6aa978620 100644 --- a/source/module_hsolver/test/diago_cg_real_test.cpp +++ b/source/module_hsolver/test/diago_cg_real_test.cpp @@ -185,7 +185,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -195,7 +195,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_test.cpp b/source/module_hsolver/test/diago_cg_test.cpp index 08912bc428..5d144ae9fb 100644 --- a/source/module_hsolver/test/diago_cg_test.cpp +++ b/source/module_hsolver/test/diago_cg_test.cpp @@ -176,7 +176,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -186,7 +186,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_david_float_test.cpp b/source/module_hsolver/test/diago_david_float_test.cpp index c3feeea246..37930da8e6 100644 --- a/source/module_hsolver/test/diago_david_float_test.cpp +++ b/source/module_hsolver/test/diago_david_float_test.cpp @@ -90,7 +90,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas() ; + const int dim = phi.get_cur_effective_basis() ; const int nband = phi.get_nbands(); const int ld_psi =phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_real_test.cpp b/source/module_hsolver/test/diago_david_real_test.cpp index a1c4dee958..2a0103fe49 100644 --- a/source/module_hsolver/test/diago_david_real_test.cpp +++ b/source/module_hsolver/test/diago_david_real_test.cpp @@ -89,7 +89,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas(); + const int dim = phi.get_cur_effective_basis(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_test.cpp b/source/module_hsolver/test/diago_david_test.cpp index 71005a78b9..542deeb663 100644 --- a/source/module_hsolver/test/diago_david_test.cpp +++ b/source/module_hsolver/test/diago_david_test.cpp @@ -92,7 +92,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas(); + const int dim = phi.get_cur_effective_basis(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index be045742af..bfb60f49cc 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -300,14 +300,7 @@ const int& Psi::get_cur_effective_basis() const { if (this->npol == 1) { - // if (this->ngk != nullptr) - // { - // return this->ngk[this->current_k]; - // } - // else - { - return this->current_nbasis; - } + return this->current_nbasis; } else { From 0339ba3eb1e7f81cbfdce1c2d346e180af6b723a Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 27 Dec 2024 08:58:41 +0000 Subject: [PATCH 20/49] check bugs --- source/module_psi/psi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index bfb60f49cc..e582dee5e6 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -341,7 +341,7 @@ void Psi::fix_k(const int ik) const { assert(ik >= 0); this->current_k = ik; - if (this->ngk != nullptr && this->npol != 2) + if (this->ngk != nullptr) { this->current_nbasis = this->ngk[ik]; } From 190e74a47db4513b584073089138f5b2148c9334 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 27 Dec 2024 10:12:02 +0000 Subject: [PATCH 21/49] update Constructor 8-1 --- source/module_hamilt_general/operator.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 1 + source/module_psi/psi.cpp | 17 ++++++++++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 4a830489f4..ae04b1f4f0 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -79,8 +79,8 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_ngk(op->ik), - // psi_input->get_current_nbas(), + // psi_input->get_ngk(op->ik), + psi_input->get_current_nbas(), is_first_node); break; } diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 0d49cadaa0..3115b49798 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -49,6 +49,7 @@ void Velocity::act { ModuleBase::timer::tick("Operator", "Velocity"); const int npw = psi_in->get_ngk(this->ik); + // const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index e582dee5e6..276bc6369a 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -73,16 +73,23 @@ Psi::Psi(T* psi_pointer, const bool k_first_in) { this->k_first = k_first_in; - this->ngk = ngk_vector_in.data(); - this->current_b = 0; - this->current_k = 0; this->npol = PARAM.globalv.npol; + this->allocate_inside = false; + + this->ngk = ngk_vector_in.data(); + + this->psi = psi_pointer; + this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; + + this->current_k = 0; + this->current_b = 0; this->current_nbasis = current_nbasis_in; - this->psi_current = this->psi = psi_pointer; - this->allocate_inside = false; + this->psi_current = psi_pointer; + this->psi_bias = 0; + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } From cf0cad7447568f19806d6ad56b040f4ddb5a1d7a Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sat, 28 Dec 2024 06:31:10 +0000 Subject: [PATCH 22/49] fix bug --- source/module_hamilt_general/operator.cpp | 10 ++++++ source/module_hsolver/diago_iter_assist.cpp | 4 +-- source/module_psi/psi.cpp | 38 +++++++++++++++++++++ source/module_psi/psi.h | 5 +++ 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index ae04b1f4f0..cf33a77111 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -68,6 +68,16 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // a "psi" with the bands of needed range psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); + + // if (psi_input->get_ngk(op->ik) != psi_input->get_current_nbas()) + // { + // std::cout << "op->ik : " << op->ik << std::endl; + // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; + // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; + + // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; + // } + switch (op->get_act_type()) { case 2: diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 29cd923c0d..a5bb076c6d 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,7 +199,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, 1, psi_nc, dmin, true); T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -246,7 +246,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, nstart, psi_nc, dmin, true); T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 276bc6369a..6893042343 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -113,6 +113,44 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } +// Constructor 8-3: 2D Psi version 3 +template +Psi::Psi(const int nk_in, + const int nbd_in, + const int nbs_in, + const int current_nbasis_in, + const bool k_first_in) +{ + + // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. + assert(nk_in == 1); + + this->k_first = k_first_in; + this->npol = PARAM.globalv.npol; + this->allocate_inside = true; + + this->ngk = nullptr; + assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0); + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); + + this->nk = nk_in; + this->nbands = nbd_in; + this->nbasis = nbs_in; + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = current_nbasis_in; + this->psi_current = this->psi; + this->psi_bias = 0; + + // Currently only GPU's implementation is supported for device recording! + base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); + base_device::information::record_device_memory(this->ctx, + GlobalV::ofs_device, + "Psi->resize()", + sizeof(T) * nk_in * nbd_in * nbs_in); +} + template Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 54bcd31fa2..c07840bf7c 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -70,7 +70,12 @@ class Psi // Constructor 8-2: a pointer version of constructor 3 // only used in operator.cpp call_act func Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in); + + + // Constructor 8-3: 2D Psi version 3 + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); + // Destructor for deleting the psi array manually ~Psi(); From 30b1aa498a51839f3b44278fcc1a7aa76bf9908d Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sat, 28 Dec 2024 07:36:37 +0000 Subject: [PATCH 23/49] fix bug --- .../module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 3115b49798..bf4b18135e 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -48,8 +48,8 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_ngk(this->ik); - // const int npw = psi_in->get_current_nbas(); + // const int npw = psi_in->get_ngk(this->ik); + const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; From fcc167f445919db78c6cada499866359382ffddb Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sat, 28 Dec 2024 08:58:45 +0000 Subject: [PATCH 24/49] fix bug --- source/module_hamilt_general/operator.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 16 ++++++++++++++-- source/module_hsolver/diago_iter_assist.cpp | 4 ++-- source/module_psi/psi.cpp | 3 ++- source/module_psi/psi.h | 7 ++++++- 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index cf33a77111..424a4f4ce9 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -89,8 +89,8 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - // psi_input->get_ngk(op->ik), - psi_input->get_current_nbas(), + psi_input->get_ngk(op->ik), + // psi_input->get_current_nbas(), is_first_node); break; } diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index bf4b18135e..81ebbded47 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -48,8 +48,20 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - // const int npw = psi_in->get_ngk(this->ik); - const int npw = psi_in->get_current_nbas(); + + // if (psi_in->get_ngk(this->ik) != psi_in->get_current_nbas()) + // { + // std::cout << "op->ik : " << this->ik << std::endl; + // std::cout << "get_ngk(op->ik) : " << psi_in->get_ngk(this->ik) << std::endl; + // std::cout << "get_current_nbas() : " << psi_in->get_current_nbas() << std::endl; + + // std::cout << "ik : " << this->ik << std::endl; + // } + + + const int npw = psi_in->get_ngk(this->ik); + // const int npw = psi_in->get_current_nbas(); + const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index a5bb076c6d..a05c35ab8c 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,7 +199,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, dmin, true); + psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -246,7 +246,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, dmin, true); + psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 6893042343..0e24d17b4b 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -118,6 +118,7 @@ template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, + const int* ngk_in, const int current_nbasis_in, const bool k_first_in) { @@ -129,7 +130,7 @@ Psi::Psi(const int nk_in, this->npol = PARAM.globalv.npol; this->allocate_inside = true; - this->ngk = nullptr; + this->ngk = ngk_in; assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0); resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index c07840bf7c..449e37bbef 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -73,7 +73,12 @@ class Psi // Constructor 8-3: 2D Psi version 3 - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); + Psi(const int nk_in, + const int nbd_in, + const int nbs_in, + const int* ngk_in, + const int current_nbasis_in, + const bool k_first_in); // Destructor for deleting the psi array manually From 8cf13631de6b7f90bef04ebd9a566ef106933061 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 03:08:00 +0000 Subject: [PATCH 25/49] fix bug maybe --- source/module_hsolver/diago_iter_assist.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index a05c35ab8c..171b0ada8d 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,7 +199,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); + // psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); + psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -246,7 +247,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); + // psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); + psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space From 7e099719d9c5500bf00381c9a9308d612dfdfdc4 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 04:26:15 +0000 Subject: [PATCH 26/49] fix bug --- source/module_hsolver/diago_iter_assist.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 171b0ada8d..670e313317 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,8 +199,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - // psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); + // psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -247,8 +247,9 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - // psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); + // psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); + T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space @@ -266,7 +267,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* spsi = temp; // do sPsi for all bands - pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_cur_effective_basis(), psi_temp.get_nbands()); + pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_nbasis(), psi_temp.get_nbands()); gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); delmem_complex_op()(ctx, temp); From 6ff1b3a0cff19b3146d81123f6f45909c3b49a15 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 05:45:03 +0000 Subject: [PATCH 27/49] check correct --- source/module_hamilt_general/operator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 424a4f4ce9..cf33a77111 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -89,8 +89,8 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - psi_input->get_ngk(op->ik), - // psi_input->get_current_nbas(), + // psi_input->get_ngk(op->ik), + psi_input->get_current_nbas(), is_first_node); break; } From 7e44e9f29bf22ea2b5b44b57ef8fcacb7212cc91 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 06:48:57 +0000 Subject: [PATCH 28/49] check 1 --- source/module_hamilt_general/operator.cpp | 18 ++++------ source/module_hsolver/diago_iter_assist.cpp | 6 ++-- source/module_hsolver/hsolver_pw.cpp | 37 +++++++++------------ source/module_psi/psi.cpp | 31 ++++------------- source/module_psi/psi.h | 7 ---- 5 files changed, 31 insertions(+), 68 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index cf33a77111..ff29f64cf8 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -66,17 +66,12 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { // a "psi" with the bands of needed range - psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - - - // if (psi_input->get_ngk(op->ik) != psi_input->get_current_nbas()) - // { - // std::cout << "op->ik : " << op->ik << std::endl; - // std::cout << "psi_input->get_ngk(op->ik) : " << psi_input->get_ngk(op->ik) << std::endl; - // std::cout << "psi_input->get_current_nbas() : " << psi_input->get_current_nbas() << std::endl; - - // std::cout << "psi_input->ik : " << psi_input->get_nk() << std::endl; - // } + psi::Psi psi_wrapper(const_cast(tmpsi_in), + 1, + nbands, + psi_input->get_nbasis(), + psi_input->get_nbasis(), + true); switch (op->get_act_type()) { @@ -89,7 +84,6 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), - // psi_input->get_ngk(op->ik), psi_input->get_current_nbas(), is_first_node); break; diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 670e313317..ff0cb59ffc 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,8 +199,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0), dmin, true); - // psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, 1, psi_nc, dmin, true); + T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -247,7 +247,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0), dmin, true); + psi::Psi psi_temp(1, nstart, psi_nc, dmin, true); // psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); T* ppsi = psi_temp.get_pointer(); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index a9d48ee769..24cf6742d1 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -310,7 +310,11 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, #endif /// solve eigenvector and eigenvalue for H(k) - this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks); + this->hamiltSolvePsiK(pHamilt, + psi, + precondition, + eigenvalues.data() + ik * psi.get_nbands(), + this->wfc_basis->nks); if (skip_charge) { @@ -370,20 +374,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool}; #endif - auto ngk_pointer = psi.get_ngk_pointer(); - - std::vector ngk_vector(nk_nums, 0); - for (int i = 0; i < nk_nums; i++) - { - ngk_vector[i] = ngk_pointer[i]; - } - const int cur_nbasis = psi.get_ngk(psi.get_current_k()); if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -393,13 +389,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector, cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector, cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, @@ -419,7 +413,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -430,7 +424,6 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector, cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; @@ -491,11 +484,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -512,11 +505,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -557,17 +550,17 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // dimensions of matrix to be solved const int dim = psi.get_cur_effective_basis(); /// dimension of matrix - const int nband = psi.get_nbands(); /// number of eigenpairs sought - const int ld_psi = psi.get_nbasis(); /// leading dimension of psi + const int nband = psi.get_nbands(); /// number of eigenpairs sought + const int ld_psi = psi.get_nbasis(); /// leading dimension of psi // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector, cur_nbasis); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 0e24d17b4b..fd539b1fac 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -68,15 +68,18 @@ Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, - const std::vector& ngk_vector_in, const int current_nbasis_in, const bool k_first_in) { + + // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. + assert(nk_in == 1); + this->k_first = k_first_in; this->npol = PARAM.globalv.npol; this->allocate_inside = false; - this->ngk = ngk_vector_in.data(); + this->ngk = nullptr; this->psi = psi_pointer; @@ -94,31 +97,11 @@ Psi::Psi(T* psi_pointer, base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } -// Constructor 8-2: -template -Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in) -{ - this->k_first = k_first_in; - this->ngk = nullptr; - this->current_b = 0; - this->current_k = 0; - this->npol = PARAM.globalv.npol; - this->nk = nk_in; - this->nbands = nbd_in; - this->nbasis = nbs_in; - this->current_nbasis = nbs_in; - this->psi_current = this->psi = psi_pointer; - this->allocate_inside = false; - // Currently only GPU's implementation is supported for device recording! - base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); -} - // Constructor 8-3: 2D Psi version 3 template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, - const int* ngk_in, const int current_nbasis_in, const bool k_first_in) { @@ -130,7 +113,7 @@ Psi::Psi(const int nk_in, this->npol = PARAM.globalv.npol; this->allocate_inside = true; - this->ngk = ngk_in; + this->ngk = nullptr; assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0); resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); @@ -138,7 +121,7 @@ Psi::Psi(const int nk_in, this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_k = 0; + this->current_k = 0; this->current_b = 0; this->current_nbasis = current_nbasis_in; this->psi_current = this->psi; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 449e37bbef..82cda7fa66 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -63,20 +63,13 @@ class Psi const int nk_in, const int nbd_in, const int nbs_in, - const std::vector& ngk_vector_in, const int current_nbasis_in, const bool k_first_in = true); - // Constructor 8-2: a pointer version of constructor 3 - // only used in operator.cpp call_act func - Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in); - - // Constructor 8-3: 2D Psi version 3 Psi(const int nk_in, const int nbd_in, const int nbs_in, - const int* ngk_in, const int current_nbasis_in, const bool k_first_in); From f4f958c7981d7d12e8443c7f30637966b285da25 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 07:53:58 +0000 Subject: [PATCH 29/49] fix unit test --- source/module_io/test/write_wfc_nao_test.cpp | 2 +- source/module_lr/utils/lr_util.hpp | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/source/module_io/test/write_wfc_nao_test.cpp b/source/module_io/test/write_wfc_nao_test.cpp index 1b39f34a0e..c0effc8711 100644 --- a/source/module_io/test/write_wfc_nao_test.cpp +++ b/source/module_io/test/write_wfc_nao_test.cpp @@ -167,7 +167,7 @@ class WriteWfcLcaoTest : public testing::Test TEST_F(WriteWfcLcaoTest, WriteWfcLcao) { // create a psi object - psi::Psi my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, true); + psi::Psi my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, nbasis_local, true); PARAM.sys.global_out_dir = "./"; ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1); diff --git a/source/module_lr/utils/lr_util.hpp b/source/module_lr/utils/lr_util.hpp index 8fdc1b9b96..1ae5c59dc2 100644 --- a/source/module_lr/utils/lr_util.hpp +++ b/source/module_lr/utils/lr_util.hpp @@ -104,14 +104,19 @@ namespace LR_Util /// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy template - psi::Psi k1_to_bfirst_wrapper(const psi::Psi& psi_kfirst, int nk_in, int nbasis_in) + psi::Psi c(const psi::Psi& psi_kfirst, int nk_in, int nbasis_in) { assert(psi_kfirst.get_nk() == 1); assert(nk_in * nbasis_in == psi_kfirst.get_nbasis()); int ib_now = psi_kfirst.get_current_b(); psi_kfirst.fix_b(0); // for get_pointer() to get the head pointer - psi::Psi psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, false); + psi::Psi psi_bfirst(psi_kfirst.get_pointer(), + nk_in, + psi_kfirst.get_nbands(), + nbasis_in, + nbasis_in, + false); psi_kfirst.fix_b(ib_now); return psi_bfirst; } @@ -124,7 +129,12 @@ namespace LR_Util int ik_now = psi_bfirst.get_current_k(); psi_bfirst.fix_kb(0, 0); // for get_pointer() to get the head pointer - psi::Psi psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), true); + psi::Psi psi_kfirst(psi_bfirst.get_pointer(), + 1, + psi_bfirst.get_nbands(), + psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), + psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), + true); psi_bfirst.fix_kb(ik_now, ib_now); return psi_kfirst; } From 444f21ecca9b0061f188d056fade83f5f85ae502 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 08:44:06 +0000 Subject: [PATCH 30/49] fix unit bug --- source/module_lr/utils/lr_util.hpp | 2 +- source/module_psi/psi.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_lr/utils/lr_util.hpp b/source/module_lr/utils/lr_util.hpp index 1ae5c59dc2..310ae702c5 100644 --- a/source/module_lr/utils/lr_util.hpp +++ b/source/module_lr/utils/lr_util.hpp @@ -104,7 +104,7 @@ namespace LR_Util /// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy template - psi::Psi c(const psi::Psi& psi_kfirst, int nk_in, int nbasis_in) + psi::Psi k1_to_bfirst_wrapper(const psi::Psi& psi_kfirst, int nk_in, int nbasis_in) { assert(psi_kfirst.get_nk() == 1); assert(nk_in * nbasis_in == psi_kfirst.get_nbasis()); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index fd539b1fac..2acb4c2711 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -73,7 +73,7 @@ Psi::Psi(T* psi_pointer, { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. - assert(nk_in == 1); + // assert(nk_in == 1); this->k_first = k_first_in; this->npol = PARAM.globalv.npol; From 588a335da0195755d6860cb0c1caf828ee53b1f7 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 29 Dec 2024 13:01:13 +0000 Subject: [PATCH 31/49] update get_ngk func --- source/module_psi/psi.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 2acb4c2711..529f18df83 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -484,10 +484,7 @@ int Psi::get_current_nbas() const template const int& Psi::get_ngk(const int ik_in) const { - if (!this->ngk) - { - return this->nbasis; - } + assert(this->ngk != nullptr); return this->ngk[ik_in]; } From 76893ee6e5015e8b450fe2c5dbd03acc3affafd2 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 10:19:40 +0000 Subject: [PATCH 32/49] remove get-ngk in velocity-pw --- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 4 +- .../hamilt_stodft/sto_elecond.cpp | 28 ++++++------ source/module_hsolver/diago_iter_assist.cpp | 1 - source/module_hsolver/hsolver_pw.cpp | 2 +- source/module_io/write_vxc_lip.hpp | 2 +- source/module_psi/psi.cpp | 30 +++++++++---- source/module_psi/psi.h | 7 ++- source/module_psi/test/psi_test.cpp | 44 ------------------- 8 files changed, 45 insertions(+), 73 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 81ebbded47..3a01ebf715 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -59,8 +59,8 @@ void Velocity::act // } - const int npw = psi_in->get_ngk(this->ik); - // const int npw = psi_in->get_current_nbas(); + // const int npw = psi_in->get_ngk(this->ik); + const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp index 8692a586ee..c0e5d61d3f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp @@ -172,9 +172,9 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, const int allbands = bandinfo[5]; const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands; - psi::Psi> right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); - psi::Psi> f_rightchi(1, perbands_sto, npwx, p_kv->ngk.data()); - psi::Psi> f_right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> right_hchi(1, perbands_sto, npwx, npw, true); + psi::Psi> f_rightchi(1, perbands_sto, npwx, npw, true); + psi::Psi> f_right_hchi(1, perbands_sto, npwx, npw, true); this->p_hamilt_sto->hPsi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto); this->p_hamilt_sto->hPsi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto); @@ -206,8 +206,8 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, } #endif - psi::Psi> f_batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); - psi::Psi> f_batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); + psi::Psi> f_batch_vchi(1, bsize_psi * ndim, npwx, npw, true); + psi::Psi> f_batch_vhchi(1, bsize_psi * ndim, npwx, npw, true); std::vector> tmpj(ndim * allbands_sto * perbands_sto); // 1. (<\psi|J|\chi>)^T @@ -663,19 +663,19 @@ void Sto_EleCond::sKG(const int& smear_type, //----------------------------------------------------------- //------------------- allocate ------------------------- size_t ks_memory_cost = perbands_ks * npwx * sizeof(std::complex); - psi::Psi> kspsi(1, perbands_ks, npwx, p_kv->ngk.data()); - psi::Psi> vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data()); + psi::Psi> kspsi(1, perbands_ks, npwx, npw, true); + psi::Psi> vkspsi(1, perbands_ks * ndim, npwx, npw, true); std::vector> expmtmf_fact(perbands_ks), expmtf_fact(perbands_ks); - psi::Psi> f_kspsi(1, perbands_ks, npwx, p_kv->ngk.data()); + psi::Psi> f_kspsi(1, perbands_ks, npwx, npw, true); ModuleBase::Memory::record("SDFT::kspsi", ks_memory_cost); - psi::Psi> f_vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data()); + psi::Psi> f_vkspsi(1, perbands_ks * ndim, npwx, npw, true); ModuleBase::Memory::record("SDFT::vkspsi", ks_memory_cost); psi::Psi>* kspsi_all = &f_kspsi; size_t sto_memory_cost = perbands_sto * npwx * sizeof(std::complex); - psi::Psi> sfchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> sfchi(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::sfchi", sto_memory_cost); - psi::Psi> smfchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> smfchi(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::smfchi", sto_memory_cost); #ifdef __MPI psi::Psi> chi_all, hchi_all, psi_all; @@ -702,8 +702,8 @@ void Sto_EleCond::sKG(const int& smear_type, const int nbatch_psi = npart_sto; const int bsize_psi = ceil(double(perbands_sto) / nbatch_psi); - psi::Psi> batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); - psi::Psi> batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); + psi::Psi> batch_vchi(1, bsize_psi * ndim, npwx, npw, true); + psi::Psi> batch_vhchi(1, bsize_psi * ndim, npwx, npw, true); ModuleBase::Memory::record("SDFT::batchjpsi", 3 * bsize_psi * ndim * npwx * sizeof(std::complex)); //------------------- sqrt(f)|psi> sqrt(1-f)|psi> --------------- @@ -781,7 +781,7 @@ void Sto_EleCond::sKG(const int& smear_type, std::vector> j1r(ndim * dim_jmatrix), j2r(ndim * dim_jmatrix); ModuleBase::Memory::record("SDFT::j1r", sizeof(std::complex) * ndim * dim_jmatrix); ModuleBase::Memory::record("SDFT::j2r", sizeof(std::complex) * ndim * dim_jmatrix); - psi::Psi> tmphchil(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> tmphchil(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::tmphchil/r", sto_memory_cost * 2); //------------------------ t loop -------------------------- diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index ff0cb59ffc..f2fa909fc2 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -248,7 +248,6 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { psi::Psi psi_temp(1, nstart, psi_nc, dmin, true); - // psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 24cf6742d1..21c9fc9bfc 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -374,7 +374,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool}; #endif - const int cur_nbasis = psi.get_ngk(psi.get_current_k()); + const int cur_nbasis = psi.get_current_nbas(); if (this->method == "cg") { diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index 205fdbb057..f671e726e6 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -122,7 +122,7 @@ namespace ModuleIO // const ModuleBase::matrix vr_localxc = potxc->get_veff_smooth(); // 2. allocate xc operator - psi::Psi hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_ngk_pointer()); + psi::Psi hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), kv.ngk, true); hpsi_localxc.zero_out(); // std::cout << "hpsi.nk=" << hpsi_localxc.get_nk() << std::endl; // std::cout << "hpsi.nbands=" << hpsi_localxc.get_nbands() << std::endl; diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 529f18df83..cc2aa931f4 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -62,6 +62,26 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i sizeof(T) * nk_in * nbd_in * nbs_in); } + +template +Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in) +{ + this->k_first = k_first_in; + this->ngk = ngk_in.data(); + this->current_b = 0; + this->current_k = 0; + this->npol = PARAM.globalv.npol; + + this->resize(nk_in, nbd_in, nbs_in); + + // Currently only GPU's implementation is supported for device recording! + base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); + base_device::information::record_device_memory(this->ctx, + GlobalV::ofs_device, + "Psi->resize()", + sizeof(T) * nk_in * nbd_in * nbs_in); +} + // Constructor 8-1: template Psi::Psi(T* psi_pointer, @@ -195,7 +215,7 @@ Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nban template Psi::Psi(const Psi& psi_in) { - this->ngk = psi_in.get_ngk_pointer(); + this->ngk = psi_in.ngk; this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); @@ -220,7 +240,7 @@ template template Psi::Psi(const Psi& psi_in) { - this->ngk = psi_in.get_ngk_pointer(); + this->ngk = psi_in.ngk; this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); @@ -300,12 +320,6 @@ T* Psi::get_pointer(const int& ikb) const return this->psi_current + ikb * this->nbasis; } -template -const int* Psi::get_ngk_pointer() const -{ - return this->ngk; -} - template const bool& Psi::get_k_first() const { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 82cda7fa66..f83cf992ae 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,6 +42,8 @@ class Psi // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); + Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); + // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in Psi(const Psi& psi_in, const int nk_in, const int nband_in); @@ -121,12 +123,13 @@ class Psi int get_current_nbas() const; const int& get_ngk(const int ik_in) const; - // return ngk array of psi - const int* get_ngk_pointer() const; + // return k_first const bool& get_k_first() const; + // return device type of psi const Device* get_device() const; + // return psi_bias const int& get_psi_bias() const; diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index fa3f357407..a79ba4354a 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -63,26 +63,6 @@ TEST_F(TestPsi, get_val) EXPECT_EQ(psi_object14->get_psi_bias(), 0); } -// TEST_F(TestPsi, get_ngk) -// { -// psi::Psi>* psi_object21 = new psi::Psi>(&ngk[0]); -// psi::Psi* psi_object22 = new psi::Psi(&ngk[0]); -// psi::Psi>* psi_object23 = new psi::Psi>(&ngk[0]); -// psi::Psi* psi_object24 = new psi::Psi(&ngk[0]); - -// EXPECT_EQ(psi_object21->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object21->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object22->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object22->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object23->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object23->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object24->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object24->get_ngk_pointer()[0], ngk[0]); -// } - TEST_F(TestPsi, get_pointer_op_zero_complex_double) { for (int i = 0; i < ink; i++) @@ -331,30 +311,6 @@ TEST_F(TestPsi, band_first) EXPECT_EQ(std::get<0>(psi_band_32->to_range(illegal_range1)), nullptr); EXPECT_EQ(std::get<1>(psi_band_32->to_range(illegal_range2)), 0); - // pointer constructor - // band-first to k-first - // psi::Psi psi_band_32_k(psi_band_32->get_pointer(), psi_band_32->get_nk(), psi_band_32->get_nbands(), psi_band_32->get_nbasis(), psi_band_32->get_ngk_pointer(), true); - // k-first to band-first - // psi::Psi psi_band_32_b(psi_band_32_k.get_pointer(), psi_band_32_k.get_nk(), psi_band_32_k.get_nbands(), psi_band_32_k.get_nbasis(), psi_band_32_k.get_ngk_pointer(), false); - // EXPECT_EQ(psi_band_32_k.get_nk(), ink); - // EXPECT_EQ(psi_band_32_k.get_nbands(), inbands); - // EXPECT_EQ(psi_band_32_k.get_nbasis(), inbasis); - // EXPECT_EQ(psi_band_32_b.get_nk(), ink); - // EXPECT_EQ(psi_band_32_b.get_nbands(), inbands); - // EXPECT_EQ(psi_band_32_b.get_nbasis(), inbasis); - // for (int ik = 0;ik < ink;++ik) - // { - // for (int ib = 0;ib < inbands;++ib) - // { - // psi_band_32->fix_kb(ik, ib); - // psi_band_32_k.fix_kb(ik, ib); - // psi_band_32_b.fix_kb(ik, ib); - // EXPECT_EQ(psi_band_32->get_psi_bias(), (ib * ink + ik) * inbasis); - // EXPECT_EQ(psi_band_32_k.get_psi_bias(), (ik * inbands + ib) * inbasis); - // EXPECT_EQ(psi_band_32_b.get_psi_bias(), (ib * ink + ik) * inbasis); - // } - // } - delete psi_band_c64; delete psi_band_64; delete psi_band_c32; From 462857f6c13265b37d3a2e03f5c27c08cdc0e66c Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 10:27:47 +0000 Subject: [PATCH 33/49] fix bug --- source/module_psi/psi.cpp | 8 +++++++- source/module_psi/psi.h | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index cc2aa931f4..9ae0f974f6 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -240,7 +240,7 @@ template template Psi::Psi(const Psi& psi_in) { - this->ngk = psi_in.ngk; + this->ngk = psi_in.get_ngk_pointer(); this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); @@ -332,6 +332,12 @@ const Device* Psi::get_device() const return this->ctx; } +template +const int* Psi::get_ngk_pointer() const +{ + return this->ngk; +} + template const int& Psi::get_psi_bias() const { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index f83cf992ae..5a6218a367 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -124,6 +124,8 @@ class Psi const int& get_ngk(const int ik_in) const; + const int* get_ngk_pointer() const; + // return k_first const bool& get_k_first() const; From 7f66c7da6524707fd2c5f6d8e03cab1ecd3cba21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 11:22:42 +0000 Subject: [PATCH 34/49] [pre-commit.ci lite] apply automatic fixes --- source/module_io/write_vxc_lip.hpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index f671e726e6..d57c8f2ccd 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -59,24 +59,27 @@ namespace ModuleIO assert(nbands >= 0); #endif std::vector e(nbands, 0.0); - for (int i = 0; i < nbands; ++i) + for (int i = 0; i < nbands; ++i) { e[i] = get_real(mat_mo[i * nbands + i]); +} return e; } template FPTYPE all_band_energy(const int ik, const int nbands, const std::vector>& mat_mo, const ModuleBase::matrix& wg) { FPTYPE e = 0.0; - for (int i = 0; i < nbands; ++i) + for (int i = 0; i < nbands; ++i) { e += get_real(mat_mo[i * nbands + i]) * (FPTYPE)wg(ik, i); +} return e; } template FPTYPE all_band_energy(const int ik, const std::vector& orbital_energy, const ModuleBase::matrix& wg) { FPTYPE e = 0.0; - for (int i = 0; i < orbital_energy.size(); ++i) + for (int i = 0; i < orbital_energy.size(); ++i) { e += orbital_energy[i] * (FPTYPE)wg(ik, i); +} return e; } @@ -170,9 +173,11 @@ namespace ModuleIO #if((defined __LCAO)&&(defined __EXX) && !(defined __CUDA)&& !(defined __ROCM)) if (GlobalC::exx_info.info_global.cal_exx) { - for (int n = 0; n < naos; ++n) - for (int m = 0; m < naos; ++m) + for (int n = 0; n < naos; ++n) { + for (int m = 0; m < naos; ++m) { vexx_k_ao[n * naos + m] += (T)GlobalC::exx_info.info_global.hybrid_alpha * exx_lip.get_exx_matrix()[ik][m][n]; +} +} std::vector vexx_k_mo = cVc(vexx_k_ao.data(), &(exx_lip.get_hvec()(ik, 0, 0)), naos, nbands); Parallel_Reduce::reduce_pool(vexx_k_mo.data(), nbands * nbands); e_orb_exx.emplace_back(orbital_energy(ik, nbands, vexx_k_mo)); From ed7387e1c3b2f93e5402def72217e7a839614fd3 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 11:50:02 +0000 Subject: [PATCH 35/49] fix 186_PW_SKG_ALL bug --- source/module_psi/psi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 9ae0f974f6..90f6d809cc 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -134,7 +134,7 @@ Psi::Psi(const int nk_in, this->allocate_inside = true; this->ngk = nullptr; - assert(nk_in > 0 && nbd_in > 0 && nbs_in > 0); + assert(nk_in > 0 && nbd_in >= 0 && nbs_in > 0); resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); this->nk = nk_in; From 0906e22210d43196d7dbca407d59f1c1ecb8b95e Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 12:52:31 +0000 Subject: [PATCH 36/49] format source/module_io/unk_overlap_pw.cpp --- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 11 - source/module_io/unk_overlap_pw.cpp | 241 +++++++++--------- 2 files changed, 118 insertions(+), 134 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 3a01ebf715..a694db196e 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -49,17 +49,6 @@ void Velocity::act { ModuleBase::timer::tick("Operator", "Velocity"); - // if (psi_in->get_ngk(this->ik) != psi_in->get_current_nbas()) - // { - // std::cout << "op->ik : " << this->ik << std::endl; - // std::cout << "get_ngk(op->ik) : " << psi_in->get_ngk(this->ik) << std::endl; - // std::cout << "get_current_nbas() : " << psi_in->get_current_nbas() << std::endl; - - // std::cout << "ik : " << this->ik << std::endl; - // } - - - // const int npw = psi_in->get_ngk(this->ik); const int npw = psi_in->get_current_nbas(); const int max_npw = psi_in->get_nbasis() / psi_in->npol; diff --git a/source/module_io/unk_overlap_pw.cpp b/source/module_io/unk_overlap_pw.cpp index d0d1d7c706..d4f0d2a85b 100644 --- a/source/module_io/unk_overlap_pw.cpp +++ b/source/module_io/unk_overlap_pw.cpp @@ -1,16 +1,16 @@ #include "unk_overlap_pw.h" -#include "module_parameter/parameter.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_parameter/parameter.h" unkOverlap_pw::unkOverlap_pw() { - //GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl; + // GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl; } unkOverlap_pw::~unkOverlap_pw() { - //GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl; + // GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl; } std::complex unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw, @@ -20,50 +20,44 @@ std::complex unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw, const int iband_R, const psi::Psi>* evc) { - - std::complex result(0.0,0.0); - const int number_pw = wfcpw->npw; - std::complex *unk_L = new std::complex[number_pw]; - std::complex *unk_R = new std::complex[number_pw]; - ModuleBase::GlobalFunc::ZEROS(unk_L,number_pw); - ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw); - - - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) - { - unk_L[wfcpw->getigl2ig(ik_L,igl)] = evc[0](ik_L, iband_L, igl); - } - - for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) - { - unk_R[wfcpw->getigl2ig(ik_R,igl)] = evc[0](ik_R, iband_R, igl); - } - - - for (int iG = 0; iG < number_pw; iG++) - { - - result = result + conj(unk_L[iG]) * unk_R[iG]; - - } + std::complex result(0.0, 0.0); + const int number_pw = wfcpw->npw; + std::complex* unk_L = new std::complex[number_pw]; + std::complex* unk_R = new std::complex[number_pw]; + ModuleBase::GlobalFunc::ZEROS(unk_L, number_pw); + ModuleBase::GlobalFunc::ZEROS(unk_R, number_pw); + + for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + { + unk_L[wfcpw->getigl2ig(ik_L, igl)] = evc[0](ik_L, iband_L, igl); + } + + for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) + { + unk_R[wfcpw->getigl2ig(ik_R, igl)] = evc[0](ik_R, iband_R, igl); + } + + for (int iG = 0; iG < number_pw; iG++) + { + + result = result + conj(unk_L[iG]) * unk_R[iG]; + } #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - delete[] unk_L; - delete[] unk_R; - return result; - - + delete[] unk_L; + delete[] unk_R; + return result; } std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, @@ -75,24 +69,24 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, const psi::Psi>* evc, const ModuleBase::Vector3 G) { - // (1) set value - std::complex result(0.0,0.0); + // (1) set value + std::complex result(0.0, 0.0); std::complex* psi_r = new std::complex[wfcpw->nmaxgr]; std::complex* phase = new std::complex[rhopw->nmaxgr]; // get the phase value in realspace for (int ig = 0; ig < rhopw->nmaxgr; ig++) { - ModuleBase::Vector3 delta_G = rhopw->gdirect[ig] - G; - if (delta_G.norm2() < 1e-10) // rhopw->gdirect[ig] == G - { - phase[ig] = std::complex(1.0,0.0); - break; - } - } - - // (2) fft and get value - rhopw->recip2real(phase, phase); + ModuleBase::Vector3 delta_G = rhopw->gdirect[ig] - G; + if (delta_G.norm2() < 1e-10) // rhopw->gdirect[ig] == G + { + phase[ig] = std::complex(1.0, 0.0); + break; + } + } + + // (2) fft and get value + rhopw->recip2real(phase, phase); wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_r, ik_L); for (int ir = 0; ir < rhopw->nmaxgr; ir++) @@ -110,17 +104,17 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - - delete[] psi_r; - delete[] phase; + + delete[] psi_r; + delete[] phase; return result; } @@ -133,18 +127,18 @@ std::complex unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf const int npwx, const psi::Psi>* evc) { - - std::complex result(0.0,0.0); + + std::complex result(0.0, 0.0); const int number_pw = wfcpw->npw; std::complex* unk_L = new std::complex[number_pw * PARAM.globalv.npol]; std::complex* unk_R = new std::complex[number_pw * PARAM.globalv.npol]; - ModuleBase::GlobalFunc::ZEROS(unk_L,number_pw*PARAM.globalv.npol); - ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw*PARAM.globalv.npol); - - for(int i = 0; i < PARAM.globalv.npol; i++) - { - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) - { + ModuleBase::GlobalFunc::ZEROS(unk_L, number_pw * PARAM.globalv.npol); + ModuleBase::GlobalFunc::ZEROS(unk_R, number_pw * PARAM.globalv.npol); + + for (int i = 0; i < PARAM.globalv.npol; i++) + { + for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + { unk_L[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_L, iband_L, igl + i * npwx); } @@ -154,32 +148,29 @@ std::complex unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf } } - for (int iG = 0; iG < number_pw*PARAM.globalv.npol; iG++) - { + for (int iG = 0; iG < number_pw * PARAM.globalv.npol; iG++) + { - result = result + conj(unk_L[iG]) * unk_R[iG]; + result = result + conj(unk_L[iG]) * unk_R[iG]; + } - } - #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - delete[] unk_L; - delete[] unk_R; - return result; - - + delete[] unk_L; + delete[] unk_R; + return result; } -//here G is in direct coordinate +// here G is in direct coordinate std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rhopw, const ModulePW::PW_Basis_K* wfcpw, const int ik_L, @@ -189,32 +180,32 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho const psi::Psi>* evc, const ModuleBase::Vector3 G) { - // (1) set value - std::complex result(0.0,0.0); - std::complex *phase =new std::complex[rhopw->nmaxgr]; + // (1) set value + std::complex result(0.0, 0.0); + std::complex* phase = new std::complex[rhopw->nmaxgr]; std::complex* psi_up = new std::complex[wfcpw->nmaxgr]; std::complex* psi_down = new std::complex[wfcpw->nmaxgr]; const int npwx = wfcpw->npwk_max; // get the phase value in realspace for (int ig = 0; ig < rhopw->npw; ig++) - { - if (rhopw->gdirect[ig] == G) - { - phase[ig] = std::complex(1.0,0.0); - break; - } - } - - // (2) fft and get value - rhopw->recip2real(phase, phase); + { + if (rhopw->gdirect[ig] == G) + { + phase[ig] = std::complex(1.0, 0.0); + break; + } + } + + // (2) fft and get value + rhopw->recip2real(phase, phase); wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psi_up[ir] = psi_up[ir] * phase[ir]; - psi_down[ir] = psi_down[ir] * phase[ir]; + psi_down[ir] = psi_down[ir] * phase[ir]; } // (3) calculate the overlap in ik_L and ik_R @@ -223,27 +214,31 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho for (int i = 0; i < PARAM.globalv.npol; i++) { - for(int ig = 0; ig < evc->get_ngk(ik_R); ig++) - { - if( i == 0 ) { result = result + conj( psi_up[ig] ) * evc[0](ik_R, iband_R, ig); -} - if( i == 1 ) { result = result + conj( psi_down[ig] ) * evc[0](ik_R, iband_R, ig + npwx); -} - } - } - + for (int ig = 0; ig < evc->get_ngk(ik_R); ig++) + { + if (i == 0) + { + result = result + conj(psi_up[ig]) * evc[0](ik_R, iband_R, ig); + } + if (i == 1) + { + result = result + conj(psi_down[ig]) * evc[0](ik_R, iband_R, ig + npwx); + } + } + } + #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - - delete[] psi_up; - delete[] psi_down; + + delete[] psi_up; + delete[] psi_down; return result; } From c2cb0df0ab6aff5b4da69ad8dc2356e15ec88fbb Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 13:07:34 +0000 Subject: [PATCH 37/49] update Constructor in psi --- source/module_esolver/esolver_ks_lcao.cpp | 2 +- source/module_esolver/esolver_of.cpp | 6 +++++- source/module_esolver/esolver_of_tool.cpp | 6 +++++- source/module_io/get_pchg_lcao.cpp | 8 ++++++-- source/module_lr/esolver_lrtd_lcao.cpp | 12 +++++++++--- source/module_lr/hamilt_casida.cpp | 8 ++++++-- source/module_psi/psi.cpp | 2 +- source/module_psi/psi.h | 2 +- 8 files changed, 34 insertions(+), 12 deletions(-) diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index e5e0684e9e..939fc1ae84 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -1080,7 +1080,7 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep) //! initialize the gradients of Etotal with respect to occupation numbers and wfc, //! and set all elements to 0. ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true); - psi::Psi dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis()); + psi::Psi dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis(), this->kv.ngk, true); dE_dWfc.zero_out(); double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc); diff --git a/source/module_esolver/esolver_of.cpp b/source/module_esolver/esolver_of.cpp index 8c30664573..ce51a73e1b 100644 --- a/source/module_esolver/esolver_of.cpp +++ b/source/module_esolver/esolver_of.cpp @@ -220,7 +220,11 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell) // Refresh the arrays delete this->psi_; - this->psi_ = new psi::Psi(1, PARAM.inp.nspin, this->pw_rho->nrxx); + this->psi_ = new psi::Psi(1, + PARAM.inp.nspin, + this->pw_rho->nrxx, + this->pw_rho->nrxx, + true); for (int is = 0; is < PARAM.inp.nspin; ++is) { this->pphi_[is] = this->psi_->get_pointer(is); diff --git a/source/module_esolver/esolver_of_tool.cpp b/source/module_esolver/esolver_of_tool.cpp index 750598ab2e..e430347215 100644 --- a/source/module_esolver/esolver_of_tool.cpp +++ b/source/module_esolver/esolver_of_tool.cpp @@ -71,7 +71,11 @@ void ESolver_OF::init_elecstate(UnitCell& ucell) void ESolver_OF::allocate_array() { // Initialize the "wavefunction", which is sqrt(rho) - this->psi_ = new psi::Psi(1, PARAM.inp.nspin, this->pw_rho->nrxx); + this->psi_ = new psi::Psi(1, + PARAM.inp.nspin, + this->pw_rho->nrxx, + this->pw_rho->nrxx, + true); ModuleBase::Memory::record("OFDFT::Psi", sizeof(double) * PARAM.inp.nspin * this->pw_rho->nrxx); this->pphi_ = new double*[PARAM.inp.nspin]; for (int is = 0; is < PARAM.inp.nspin; ++is) diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 4b3013b581..9e902e6392 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -541,8 +541,12 @@ void IState_Charge::idmatrix(const int& ib, } this->psi_k->fix_k(ik); - // psi::Psi> wg_wfc(*this->psi_k, 1); - psi::Psi> wg_wfc(1, this->psi_k->get_nbands(), this->psi_k->get_nbasis()); + + psi::Psi> wg_wfc(1, + this->psi_k->get_nbands(), + this->psi_k->get_nbasis(), + this->psi_k->get_nbasis(), + true); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_lr/esolver_lrtd_lcao.cpp b/source/module_lr/esolver_lrtd_lcao.cpp index 38fb300594..7ecdfa2948 100644 --- a/source/module_lr/esolver_lrtd_lcao.cpp +++ b/source/module_lr/esolver_lrtd_lcao.cpp @@ -181,7 +181,11 @@ LR::ESolver_LR::ESolver_LR(ModuleESolver::ESolver_KS_LCAO&& ks_sol if (this->nbands == PARAM.inp.nbands) { move_gs(); } else // copy the part of ground state info according to paraC_ { - this->psi_ks = new psi::Psi(this->kv.get_nks(), this->paraC_.get_col_size(), this->paraC_.get_row_size()); + this->psi_ks = new psi::Psi(this->kv.get_nks(), + this->paraC_.get_col_size(), + this->paraC_.get_row_size(), + this->kv.ngk, + true); this->eig_ks.create(this->kv.get_nks(), this->nbands); const int start_band = this->nocc_max - *std::max_element(nocc.begin(), nocc.end()); for (int ik = 0;ik < this->kv.get_nks();++ik) @@ -289,8 +293,10 @@ LR::ESolver_LR::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu // now ModuleIO::read_wfc_nao needs `Parallel_Orbitals` and can only read all the bands // it need improvement to read only the bands needed this->psi_ks = new psi::Psi(this->kv.get_nks(), - this->paraMat_.ncol_bands, - this->paraMat_.get_row_size()); + this->paraMat_.ncol_bands, + this->paraMat_.get_row_size(), + this->kv.ngk, + true); this->read_ks_wfc(); if (nspin == 2) { diff --git a/source/module_lr/hamilt_casida.cpp b/source/module_lr/hamilt_casida.cpp index a29429d5be..2cc4382b91 100644 --- a/source/module_lr/hamilt_casida.cpp +++ b/source/module_lr/hamilt_casida.cpp @@ -18,13 +18,17 @@ namespace LR {//calculate A^{ai} for each bj int bj = j * nv + b; //global int kbj = ik * npairs + bj; //global - psi::Psi X_bj(1, 1, this->nk * px.get_local_size()); // k1-first, like in iterative solver + psi::Psi X_bj(1, 1, this->nk * px.get_local_size(), this->nk * px.get_local_size(), true); // k1-first, like in iterative solver X_bj.zero_out(); // X_bj(0, 0, lj * px.get_row_size() + lb) = this->one(); int lj = px.global2local_col(j); int lb = px.global2local_row(b); if (px.in_this_processor(b, j)) { X_bj(0, 0, ik * px.get_local_size() + lj * px.get_row_size() + lb) = this->one(); } - psi::Psi A_aibj(1, 1, this->nk * px.get_local_size()); // k1-first + psi::Psi A_aibj(1, + 1, + this->nk * px.get_local_size(), + this->nk * px.get_local_size(), + true); // k1-first A_aibj.zero_out(); this->cal_dm_trans(0, X_bj.get_pointer()); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 90f6d809cc..2dc14ab9a7 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -93,7 +93,7 @@ Psi::Psi(T* psi_pointer, { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. - // assert(nk_in == 1); + // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func this->k_first = k_first_in; this->npol = PARAM.globalv.npol; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 5a6218a367..f384ed17c4 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -40,7 +40,7 @@ class Psi Psi(); // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); From 5a86f4576fe34ac105e7c827766d1c54b9c6980d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 13:22:07 +0000 Subject: [PATCH 38/49] [pre-commit.ci lite] apply automatic fixes --- source/module_lr/hamilt_casida.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/source/module_lr/hamilt_casida.cpp b/source/module_lr/hamilt_casida.cpp index 2cc4382b91..5d7958295b 100644 --- a/source/module_lr/hamilt_casida.cpp +++ b/source/module_lr/hamilt_casida.cpp @@ -12,8 +12,8 @@ namespace LR const int ldim = nk * px.get_local_size(); int npairs = no * nv; std::vector Amat_full(this->nk * npairs * this->nk * npairs, 0.0); - for (int ik = 0;ik < this->nk;++ik) - for (int j = 0;j < no;++j) + for (int ik = 0;ik < this->nk;++ik) { + for (int j = 0;j < no;++j) { for (int b = 0;b < nv;++b) {//calculate A^{ai} for each bj int bj = j * nv + b; //global @@ -41,12 +41,15 @@ namespace LR // reduce ai for a fixed bj A_aibj.fix_kb(0, 0); #ifdef __MPI - for (int ik_ai = 0;ik_ai < this->nk;++ik_ai) + for (int ik_ai = 0;ik_ai < this->nk;++ik_ai) { LR_Util::gather_2d_to_full(px, &A_aibj.get_pointer()[ik_ai * px.get_local_size()], Amat_full.data() + kbj * this->nk * npairs /*col, bj*/ + ik_ai * npairs/*row, ai*/, false, nv, no); +} #endif } +} +} // output Amat std::cout << "Full A matrix: (elements < 1e-10 is set to 0)" << std::endl; LR_Util::print_value(Amat_full.data(), nk * npairs, nk * npairs); From 8935299e78f00a6b41af46bc9ac0d1a88165ae54 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 13:47:11 +0000 Subject: [PATCH 39/49] debug unit test --- .../hamilt_stodft/test/test_sto_tool.cpp | 8 +++--- source/module_lr/AX/test/AX_test.cpp | 23 +++++++++++----- .../module_lr/dm_trans/test/dm_trans_test.cpp | 26 ++++++++++++++----- .../utils/test/lr_util_algorithms_test.cpp | 2 +- source/module_psi/test/psi_test.cpp | 2 +- 5 files changed, 42 insertions(+), 19 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp index 2e429dfa3f..f343c93c81 100644 --- a/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp @@ -68,8 +68,8 @@ TEST_F(TestStoTool, parallel_distribution) TEST_F(TestStoTool, convert_psi) { - psi::Psi> psi_in(1, 1, 10); - psi::Psi> psi_out(1, 1, 10); + psi::Psi> psi_in(1, 1, 10, 10, true); + psi::Psi> psi_out(1, 1, 10, 10, true); for (int i = 0; i < 10; ++i) { psi_in.get_pointer()[i] = std::complex(i, i); @@ -83,8 +83,8 @@ TEST_F(TestStoTool, convert_psi) TEST_F(TestStoTool, gatherchi) { - psi::Psi> chi(1, 1, 10); - psi::Psi> chi_all(1, 1, 10); + psi::Psi> chi(1, 1, 10, 10, true); + psi::Psi> chi_all(1, 1, 10, 10, true); int npwx = 10; int nrecv_sto[4] = {1, 2, 3, 4}; int displs_sto[4] = {0, 1, 3, 6}; diff --git a/source/module_lr/AX/test/AX_test.cpp b/source/module_lr/AX/test/AX_test.cpp index b137f78ea0..697c71bc7e 100644 --- a/source/module_lr/AX/test/AX_test.cpp +++ b/source/module_lr/AX/test/AX_test.cpp @@ -70,7 +70,8 @@ TEST_F(AXTest, DoubleSerial) int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data(), size_v); } @@ -91,7 +92,8 @@ TEST_F(AXTest, ComplexSerial) int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data>(), size_v); } @@ -113,7 +115,9 @@ TEST_F(AXTest, DoubleParallel) std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() })); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector ngk_temp(s.nks, pc.get_row_size()); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true); Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); @@ -139,7 +143,9 @@ TEST_F(AXTest, DoubleParallel) } // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector ngk_temp_1(s.nks, s.naos); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data(), V_full.at(isk).data(), false, s.naos, s.naos); @@ -165,7 +171,9 @@ TEST_F(AXTest, ComplexParallel) std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() })); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector ngk_temp_1(s.nks, pc.get_row_size()); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true); Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); @@ -187,7 +195,10 @@ TEST_F(AXTest, ComplexParallel) } // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos); + + + std::vector ngk_temp_2(s.nks, s.naos); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data>(), V_full.at(isk).data>(), false, s.naos, s.naos); diff --git a/source/module_lr/dm_trans/test/dm_trans_test.cpp b/source/module_lr/dm_trans/test/dm_trans_test.cpp index b78bbb50bc..06e412155f 100644 --- a/source/module_lr/dm_trans/test/dm_trans_test.cpp +++ b/source/module_lr/dm_trans/test/dm_trans_test.cpp @@ -66,7 +66,9 @@ TEST_F(DMTransTest, DoubleSerial) for (int istate = 0;istate < nstate;++istate) { int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector temp(s.nks, s.naos); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); set_rand(c.get_pointer(), size_c); X.fix_b(istate); const std::vector& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt); @@ -85,7 +87,9 @@ TEST_F(DMTransTest, ComplexSerial) for (int istate = 0;istate < nstate;++istate) { int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector temp(s.nks, s.naos); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); set_rand(c.get_pointer(), size_c); X.fix_b(istate); const std::vector& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt); @@ -105,10 +109,14 @@ TEST_F(DMTransTest, DoubleParallel) // X: nvirt*nocc in para2d, nocc*nvirt in psi (row-para and constructed: nvirt) Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc); - psi::Psi X(s.nks, nstate, px.get_local_size(), nullptr, false); + + std::vector temp_1(s.nks, px.get_local_size()); + psi::Psi X(s.nks, nstate, px.get_local_size(), temp_1.data(), false); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector temp_2(s.nks, pc.get_row_size()); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt); @@ -147,7 +155,8 @@ TEST_F(DMTransTest, DoubleParallel) LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data(), dm_gather[isk].data(), false, s.naos, s.naos); // compare to global matrix - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); @@ -173,7 +182,9 @@ TEST_F(DMTransTest, ComplexParallel) psi::Psi> X(s.nks, nstate, px.get_local_size(), nullptr, false); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector temp(s.nks, pc.get_row_size()); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt); @@ -206,7 +217,8 @@ TEST_F(DMTransTest, ComplexParallel) LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data>(), dm_gather[isk].data>(), false, s.naos, s.naos); // compare to global matrix - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos); + std::vector ngk_temp_2(s.nks, s.naos); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); diff --git a/source/module_lr/utils/test/lr_util_algorithms_test.cpp b/source/module_lr/utils/test/lr_util_algorithms_test.cpp index f33105d32b..3318e526cb 100644 --- a/source/module_lr/utils/test/lr_util_algorithms_test.cpp +++ b/source/module_lr/utils/test/lr_util_algorithms_test.cpp @@ -9,7 +9,7 @@ TEST(LR_Util, PsiWrapper) int nbands = 5; int nbasis = 6; - psi::Psi k1(1, nbands, nk * nbasis); + psi::Psi k1(1, nbands, nk * nbasis, nk * nbasis, true); for (int i = 0;i < nbands * nk * nbasis;++i)k1.get_pointer()[i] = i; k1.fix_b(2); diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index a79ba4354a..f3be8fb84a 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -99,7 +99,7 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double) // cover all lines in fix_k func psi_object31->fix_k(2); EXPECT_EQ(psi_object31->get_psi_bias(), 0); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, inbasis, true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; From bfdfc922584ac529439e8088c1d159b27a0d5c6d Mon Sep 17 00:00:00 2001 From: haozhihan Date: Mon, 30 Dec 2024 14:16:44 +0000 Subject: [PATCH 40/49] fix ri test bug --- source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp | 4 ++-- source/module_ri/exx_lip.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp b/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp index 71b4fdcbd7..7ae95a8918 100644 --- a/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp +++ b/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp @@ -23,7 +23,7 @@ UnitCell::~UnitCell() { TEST(RI_Benchmark, SlicePsi) { const int nk = 1, nbands = 2, nbasis = 3; - psi::Psi psi(nk, nbands, nbasis); + psi::Psi psi(nk, nbands, nbasis, nbasis, true); for (int i = 0; i < nk * nbands * nbasis; i++) { psi.get_pointer()[i] = i; } @@ -50,7 +50,7 @@ TEST(RI_Benchmark, CalCsMO) for (int i = 0;i < nabf * nao * nao;++i) { Cs_ao[0][{0, { 0, 0, 0 }}].ptr()[i] = static_cast(i); } const UnitCell ucell; - psi::Psi psi_ks(1, 2, 2); + psi::Psi psi_ks(1, 2, 2, 2, true); for (int i = 0;i < 4;++i) { psi_ks.get_pointer()[i] = static_cast(i); } RI_Benchmark::TLRI Cs_a_mo = RI_Benchmark::cal_Cs_mo(ucell, Cs_ao, psi_ks, nocc, nvirt, false); std::vector Cs_a_mo_ref = { 11,31 }; diff --git a/source/module_ri/exx_lip.hpp b/source/module_ri/exx_lip.hpp index 06456b6f80..bd94316027 100644 --- a/source/module_ri/exx_lip.hpp +++ b/source/module_ri/exx_lip.hpp @@ -114,7 +114,7 @@ Exx_Lip::Exx_Lip(const Exx_Info::Exx_Info_Lip& info_in, #endif this->k_pack->wf_wg.create(this->k_pack->kv_ptr->get_nks(),PARAM.inp.nbands); - this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal); + this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk.data(), true); // this->k_pack->hvec_array = new ModuleBase::ComplexMatrix[this->k_pack->kv_ptr->get_nks()]; // for( int ik=0; ikk_pack->kv_ptr->get_nks(); ++ik) // { From c55e20d1a6b4a6970e1f0ad82727640af3621fc6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 15:00:41 +0000 Subject: [PATCH 41/49] [pre-commit.ci lite] apply automatic fixes --- source/module_hsolver/diago_iter_assist.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index f2fa909fc2..33986955c3 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -213,7 +213,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* { // psi_temp is one band psi, psi is all bands psi, the range always is 1 for the only band in psi_temp syncmem_complex_op()(ctx, ctx, ppsi, psi + i * psi_nc, psi_nc); - psi::Range band_by_band_range(1, 0, 0, 0); + psi::Range band_by_band_range(true, 0, 0, 0); hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi); // H|Psi> to get hpsi for target band @@ -258,7 +258,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* hpsi = temp; // do hPsi for all bands - psi::Range all_bands_range(1, 0, 0, nstart - 1); + psi::Range all_bands_range(true, 0, 0, nstart - 1); hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); pHamilt->ops->hPsi(hpsi_in); @@ -588,8 +588,9 @@ bool DiagoIterAssist::test_exit_cond(const int& ntry, const int& notc //================================================================ bool scf = true; - if (PARAM.inp.calculation == "nscf") + if (PARAM.inp.calculation == "nscf") { scf = false; +} // If ntry <=5, try to do it better, if ntry > 5, exit. const bool f1 = (ntry <= 5); From 49987f86755b2016db5271832242e81972318e86 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 01:26:12 +0000 Subject: [PATCH 42/49] fix psi-ut bug --- source/module_psi/test/psi_test.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index f3be8fb84a..52567c5aa9 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -99,7 +99,9 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double) // cover all lines in fix_k func psi_object31->fix_k(2); EXPECT_EQ(psi_object31->get_psi_bias(), 0); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, inbasis, true); + + std::vector temp(ink, inbasis); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp.data(), true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; From 2def09e25e3fe03c15af9a4de2bc19e6f9f8d5d3 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 02:47:35 +0000 Subject: [PATCH 43/49] remove Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) --- source/module_hamilt_general/operator.cpp | 15 +++++++-- source/module_lr/utils/lr_util.hpp | 6 +++- source/module_psi/psi.cpp | 38 +++++++++++------------ source/module_psi/psi.h | 6 ++-- source/module_psi/test/psi_test.cpp | 3 -- 5 files changed, 40 insertions(+), 28 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index ff29f64cf8..24191f15f8 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -61,7 +61,12 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size()); syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size()); delete this->hpsi; - this->hpsi = new psi::Psi(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); + this->hpsi = new psi::Psi(hpsi_pointer, + 1, + nbands / psi_input->npol, + psi_input->get_nbasis(), + psi_input->get_nbasis(), + true); } auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { @@ -177,7 +182,13 @@ T* Operator::get_hpsi(const hpsi_info& info) const else { this->in_place = false; - this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); + // this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(hpsi_pointer, + 1, + nbands_range, + std::get<0>(info)->get_nbasis(), + std::get<0>(info)->get_nbasis(), + true); } hpsi_pointer = this->hpsi->get_pointer(); diff --git a/source/module_lr/utils/lr_util.hpp b/source/module_lr/utils/lr_util.hpp index 310ae702c5..5bbedf645f 100644 --- a/source/module_lr/utils/lr_util.hpp +++ b/source/module_lr/utils/lr_util.hpp @@ -99,7 +99,11 @@ namespace LR_Util template psi::Psi get_psi_spin(const psi::Psi& psi_in, const int& is, const int& nk) { - return psi::Psi(&psi_in(is * nk, 0, 0), psi_in, nk, psi_in.get_nbands()); + return psi::Psi(&psi_in(is * nk, 0, 0), + nk, + psi_in.get_nbands(), + psi_in.get_nbasis(), + true); } /// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 2dc14ab9a7..e84ef3ba6d 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -93,7 +93,7 @@ Psi::Psi(T* psi_pointer, { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. - // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func + // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func this->k_first = k_first_in; this->npol = PARAM.globalv.npol; @@ -193,24 +193,24 @@ Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) } } -template -Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) -{ - this->k_first = psi_in.get_k_first(); - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } - this->ngk = psi_in.ngk; - this->npol = psi_in.npol; - this->nk = nk_in; - this->nbands = nband_in; - this->nbasis = psi_in.nbasis; - this->psi_current = psi_pointer; - this->allocate_inside = false; - this->psi = psi_pointer; -} +// template +// Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) +// { +// this->k_first = psi_in.get_k_first(); +// assert(nk_in <= psi_in.get_nk()); +// if (nband_in == 0) +// { +// nband_in = psi_in.get_nbands(); +// } +// this->ngk = psi_in.ngk; +// this->npol = psi_in.npol; +// this->nk = nk_in; +// this->nbands = nband_in; +// this->nbasis = psi_in.nbasis; +// this->psi_current = psi_pointer; +// this->allocate_inside = false; +// this->psi = psi_pointer; +// } template Psi::Psi(const Psi& psi_in) diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index f384ed17c4..3b0f732713 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -47,9 +47,9 @@ class Psi // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in Psi(const Psi& psi_in, const int nk_in, const int nband_in); - // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() - // in this case, fix_k can not be used - Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); + // // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() + // // in this case, fix_k can not be used + // Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in); // Constructor 6: initialize a new psi from the given psi_in Psi(const Psi& psi_in); diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index 52567c5aa9..0b42df63c7 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -14,9 +14,6 @@ class TestPsi : public ::testing::Test const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - - // psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); - psi::Psi>* psi_object5 = new psi::Psi>(psi_object31->get_pointer(), *psi_object31, ink, 0); }; TEST_F(TestPsi, get_val) From d23dfe52bf421798b0cfeb9cef9b2c140f12bc83 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 05:13:03 +0000 Subject: [PATCH 44/49] remove useless code --- source/module_psi/psi.cpp | 19 ------------------- source/module_psi/psi.h | 26 ++++++++++++-------------- 2 files changed, 12 insertions(+), 33 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index e84ef3ba6d..251f0fe065 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -193,25 +193,6 @@ Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) } } -// template -// Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) -// { -// this->k_first = psi_in.get_k_first(); -// assert(nk_in <= psi_in.get_nk()); -// if (nband_in == 0) -// { -// nband_in = psi_in.get_nbands(); -// } -// this->ngk = psi_in.ngk; -// this->npol = psi_in.npol; -// this->nk = nk_in; -// this->nbands = nband_in; -// this->nbasis = psi_in.nbasis; -// this->psi_current = psi_pointer; -// this->allocate_inside = false; -// this->psi = psi_pointer; -// } - template Psi::Psi(const Psi& psi_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 3b0f732713..eaabb978b8 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -36,31 +36,25 @@ template class Psi { public: - // Constructor 1: basic + // Constructor 0: basic Psi(); - // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later + // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); + // Constructor 1-2: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); - // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, const int nband_in); - - // // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() - // // in this case, fix_k can not be used - // Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in); - - // Constructor 6: initialize a new psi from the given psi_in + // Constructor 2-1: initialize a new psi from the given psi_in Psi(const Psi& psi_in); - // Constructor 7: initialize a new psi from the given psi_in with a different class template + // Constructor 2-2: initialize a new psi from the given psi_in with a different class template // in this case, psi_in may have a different device type. template Psi(const Psi& psi_in); - // Constructor 8-1: a pointer version of constructor 3 - // only used in hsolver-pw function pointer. + // Constructor 3-1: 2D Psi version + // used in hsolver-pw function pointer and somewhere. Psi(T* psi_pointer, const int nk_in, const int nbd_in, @@ -68,13 +62,17 @@ class Psi const int current_nbasis_in, const bool k_first_in = true); - // Constructor 8-3: 2D Psi version 3 + // Constructor 3-2: 2D Psi version Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); + + // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in + Psi(const Psi& psi_in, const int nk_in, const int nband_in); + // Destructor for deleting the psi array manually ~Psi(); From 24adafe514334a33732fc3803400d8a6a674e1d0 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 05:29:49 +0000 Subject: [PATCH 45/49] update Psi(const Psi& psi_in, const int nk_in, const int nband_in); --- source/module_hamilt_general/operator.cpp | 9 +++++++-- source/module_psi/psi.cpp | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 24191f15f8..008d5e30e3 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -177,12 +177,17 @@ T* Operator::get_hpsi(const hpsi_info& info) const else if (hpsi_pointer == psi_pointer) { this->in_place = true; - this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(1, + nbands_range, + std::get<0>(info)->get_nbasis(), + std::get<0>(info)->get_nbasis(), + true); } else { this->in_place = false; - // this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(hpsi_pointer, 1, nbands_range, diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 251f0fe065..20f61866b7 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -158,7 +158,7 @@ Psi::Psi(const int nk_in, template Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) { - assert(nk_in <= psi_in.get_nk() && nk_in > 0); + assert(nk_in == 1); assert(nband_in <= psi_in.get_nbands() && nband_in > 0); this->k_first = psi_in.get_k_first(); From a920924927e6e3e6b21faf5e8f2b206351153af5 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 07:16:59 +0000 Subject: [PATCH 46/49] remove Psi(const Psi& psi_in, const int nk_in, const int nband_in); --- source/module_elecstate/cal_dm.h | 7 +- .../module_elecstate/module_dm/cal_dm_psi.cpp | 17 ++++- source/module_io/get_pchg_lcao.cpp | 8 ++- source/module_io/write_dos_lcao.cpp | 9 ++- source/module_io/write_proj_band_lcao.cpp | 8 ++- source/module_psi/psi.cpp | 71 ++++++++++--------- source/module_psi/psi.h | 30 ++++---- 7 files changed, 90 insertions(+), 60 deletions(-) diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 13f41bf455..56aad08f3c 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1, nbands_local); + psi::Psi wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size()); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index cd868dcf9e..21d91e5225 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,8 +32,14 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - - psi::Psi wg_wfc(wfc, 1, nbands_local); + + psi::Psi wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size()); + int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) @@ -90,7 +96,12 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); //dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr); + psi::Psi> wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + const std::complex* pwfc = wfc.get_pointer(); std::complex* pwg_wfc = wg_wfc.get_pointer(); #ifdef _OPENMP diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 9e902e6392..3cea8a3940 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -479,7 +479,13 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); + // psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); + psi::Psi wg_wfc(1, + this->psi_gamma->get_nbands(), + this->psi_gamma->get_nbasis(), + this->psi_gamma->get_nbasis(), + true); + wg_wfc.set_all_psi(this->psi_gamma->get_pointer(), wg_wfc.size()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index df07cef1d6..015c5bc1c1 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,12 +461,17 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - psi::Psi> Dwfc(*psi, 1, psi->get_nbands()); + + psi::Psi> Dwfc(1, + psi->get_nbands(), + psi->get_nbasis(), + psi->get_nbasis(), + true); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { - p_dwfc[index] = conj(p_dwfc[index]); + p_dwfc[index] = conj(psi->get_pointer()[index]); } for (int i = 0; i < PARAM.inp.nbands; ++i) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index b5660f7da5..47d4907b5b 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -225,12 +225,16 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1, psi->get_nbands()); + psi::Psi> Dwfc(1, + psi->get_nbands(), + psi->get_nbasis(), + psi->get_nbasis(), + true); std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { - p_dwfc[index] = conj(p_dwfc[index]); + p_dwfc[index] = conj(psi->get_pointer()[index]); } for (int i = 0; i < PARAM.inp.nbands; ++i) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 20f61866b7..60778277f5 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -62,9 +62,12 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i sizeof(T) * nk_in * nbd_in * nbs_in); } - template -Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in) +Psi::Psi(const int nk_in, + const int nbd_in, + const int nbs_in, + const std::vector& ngk_in, + const bool k_first_in) { this->k_first = k_first_in; this->ngk = ngk_in.data(); @@ -155,43 +158,36 @@ Psi::Psi(const int nk_in, sizeof(T) * nk_in * nbd_in * nbs_in); } -template -Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) -{ - assert(nk_in == 1); - assert(nband_in <= psi_in.get_nbands() && nband_in > 0); +// template +// Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) +// { +// assert(nk_in == 1); +// assert(nband_in <= psi_in.get_nbands() && nband_in > 0); - this->k_first = psi_in.get_k_first(); - this->npol = psi_in.npol; - this->allocate_inside = true; +// this->k_first = psi_in.get_k_first(); +// this->npol = psi_in.npol; +// this->allocate_inside = true; - this->nk = nk_in; - this->nbands = nband_in; - this->nbasis = psi_in.get_nbasis(); +// this->nk = nk_in; +// this->nbands = nband_in; +// this->nbasis = psi_in.get_nbasis(); - // This function will delete the psi array first(if psi exist), then malloc a new memory for it. - resize_memory_op()(this->ctx, - this->psi, - (static_cast(this->nk) * static_cast(this->nbands) - * static_cast(this->nbasis)), - "no_record"); - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); +// // This function will delete the psi array first(if psi exist), then malloc a new memory for it. +// resize_memory_op()(this->ctx, +// this->psi, +// (static_cast(this->nk) * static_cast(this->nbands) +// * static_cast(this->nbasis)), +// "no_record"); +// synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); - this->current_k = 0; - this->current_b = 0; - this->current_nbasis = this->nbasis; - this->psi_current = this->psi; - this->psi_bias = 0; +// this->current_k = 0; +// this->current_b = 0; +// this->current_nbasis = this->nbasis; +// this->psi_current = this->psi; +// this->psi_bias = 0; - if (this->nk != psi_in.get_nk()) - { - this->ngk = nullptr; - } - else - { - this->ngk = psi_in.ngk; - } -} +// this->ngk = nullptr; +// } template Psi::Psi(const Psi& psi_in) @@ -269,6 +265,13 @@ Psi::Psi(const Psi& psi_in) this->psi_current = this->psi + psi_in.get_psi_bias(); } +template +void Psi::set_all_psi(const T* another_pointer, const std::size_t size_in) +{ + assert(size_in == this->size()); + synchronize_memory_op()(this->ctx, this->ctx, this->psi, another_pointer, this->size()); +} + template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index eaabb978b8..bc9c88bf49 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -42,7 +42,7 @@ class Psi // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); - // Constructor 1-2: + // Constructor 1-2: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); // Constructor 2-1: initialize a new psi from the given psi_in @@ -63,20 +63,19 @@ class Psi const bool k_first_in = true); // Constructor 3-2: 2D Psi version - Psi(const int nk_in, - const int nbd_in, - const int nbs_in, - const int current_nbasis_in, - const bool k_first_in); + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); + // // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in + // Psi(const Psi& psi_in, const int nk_in, const int nband_in); - // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, const int nband_in); - - // Destructor for deleting the psi array manually ~Psi(); + void set_all_psi(const T* another_pointer, const std::size_t size_in); + + // mark + void zero_out(); + // allocate psi for three dimensions void resize(const int nks_in, const int nbands_in, const int nbasis_in); @@ -129,24 +128,21 @@ class Psi // return device type of psi const Device* get_device() const; - + // return psi_bias const int& get_psi_bias() const; const int& get_cur_effective_basis() const; - // mark - void zero_out(); - // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; - + int npol = 1; private: T* psi = nullptr; // avoid using C++ STL - - Device* ctx = {}; // an context identifier for obtaining the device variable + + Device* ctx = {}; // an context identifier for obtaining the device variable // dimensions int nk = 1; // number of k points From b9d01600060f6504d1fd92d9fec412f877894298 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 09:22:29 +0000 Subject: [PATCH 47/49] refactor psi code --- source/module_psi/psi.cpp | 88 +++++++++++++++++++-------------------- source/module_psi/psi.h | 12 +++--- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 60778277f5..826309a9f2 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -28,6 +28,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_ range_2 = range_2_in; } +// Constructor 0: basic template Psi::Psi() { @@ -43,16 +44,31 @@ Psi::~Psi() } } +// Constructor 1-1: template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { + assert(nk_in > 0); + assert(nbd_in > 0); + assert(nbs_in > 0); + this->k_first = k_first_in; - this->ngk = ngk_in; - this->current_b = 0; - this->current_k = 0; this->npol = PARAM.globalv.npol; + this->allocate_inside = true; - this->resize(nk_in, nbd_in, nbs_in); + this->ngk = ngk_in; // modify later + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); + + this->nk = nk_in; + this->nbands = nbd_in; + this->nbasis = nbs_in; + + this->current_b = 0; + this->current_k = 0; + this->current_nbasis = nbs_in; + this->psi_current = this->psi; + this->psi_bias = 0; // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); @@ -62,6 +78,7 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i sizeof(T) * nk_in * nbd_in * nbs_in); } +// Constructor 1-2: template Psi::Psi(const int nk_in, const int nbd_in, @@ -69,13 +86,27 @@ Psi::Psi(const int nk_in, const std::vector& ngk_in, const bool k_first_in) { + assert(nk_in > 0); + assert(nbd_in > 0); + assert(nbs_in > 0); + this->k_first = k_first_in; - this->ngk = ngk_in.data(); - this->current_b = 0; - this->current_k = 0; this->npol = PARAM.globalv.npol; + this->allocate_inside = true; + + this->ngk = ngk_in.data(); // modify later + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); - this->resize(nk_in, nbd_in, nbs_in); + this->nk = nk_in; + this->nbands = nbd_in; + this->nbasis = nbs_in; + + this->current_b = 0; + this->current_k = 0; + this->current_nbasis = nbs_in; + this->psi_current = this->psi; + this->psi_bias = 0; // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); @@ -85,7 +116,7 @@ Psi::Psi(const int nk_in, sizeof(T) * nk_in * nbd_in * nbs_in); } -// Constructor 8-1: +// Constructor 3-1: 2D Psi version template Psi::Psi(T* psi_pointer, const int nk_in, @@ -94,7 +125,6 @@ Psi::Psi(T* psi_pointer, const int current_nbasis_in, const bool k_first_in) { - // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func @@ -103,7 +133,6 @@ Psi::Psi(T* psi_pointer, this->allocate_inside = false; this->ngk = nullptr; - this->psi = psi_pointer; this->nk = nk_in; @@ -120,7 +149,7 @@ Psi::Psi(T* psi_pointer, base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } -// Constructor 8-3: 2D Psi version 3 +// Constructor 3-2: 2D Psi version template Psi::Psi(const int nk_in, const int nbd_in, @@ -128,7 +157,6 @@ Psi::Psi(const int nk_in, const int current_nbasis_in, const bool k_first_in) { - // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. assert(nk_in == 1); @@ -158,37 +186,7 @@ Psi::Psi(const int nk_in, sizeof(T) * nk_in * nbd_in * nbs_in); } -// template -// Psi::Psi(const Psi& psi_in, const int nk_in, const int nband_in) -// { -// assert(nk_in == 1); -// assert(nband_in <= psi_in.get_nbands() && nband_in > 0); - -// this->k_first = psi_in.get_k_first(); -// this->npol = psi_in.npol; -// this->allocate_inside = true; - -// this->nk = nk_in; -// this->nbands = nband_in; -// this->nbasis = psi_in.get_nbasis(); - -// // This function will delete the psi array first(if psi exist), then malloc a new memory for it. -// resize_memory_op()(this->ctx, -// this->psi, -// (static_cast(this->nk) * static_cast(this->nbands) -// * static_cast(this->nbasis)), -// "no_record"); -// synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); - -// this->current_k = 0; -// this->current_b = 0; -// this->current_nbasis = this->nbasis; -// this->psi_current = this->psi; -// this->psi_bias = 0; - -// this->ngk = nullptr; -// } - +// Constructor 2-1: template Psi::Psi(const Psi& psi_in) { @@ -213,6 +211,8 @@ Psi::Psi(const Psi& psi_in) this->psi_current = this->psi + psi_in.get_psi_bias(); } + +// Constructor 2-2: template template Psi::Psi(const Psi& psi_in) diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index bc9c88bf49..caa3fa6ac9 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -65,22 +65,24 @@ class Psi // Constructor 3-2: 2D Psi version Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); - // // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - // Psi(const Psi& psi_in, const int nk_in, const int nband_in); - // Destructor for deleting the psi array manually ~Psi(); + // set psi value func 1 void set_all_psi(const T* another_pointer, const std::size_t size_in); - // mark + // set psi value func 2 void zero_out(); + // size_t size() const {return this->psi.size();} + size_t size() const; + // allocate psi for three dimensions void resize(const int nks_in, const int nbands_in, const int nbasis_in); // get the pointer for the 1st index T* get_pointer() const; + // get the pointer for the 2nd index (iband for k_first = true, ik for k_first = false) T* get_pointer(const int& ikb) const; @@ -88,8 +90,6 @@ class Psi const int& get_nk() const; const int& get_nbands() const; const int& get_nbasis() const; - // size_t size() const {return this->psi.size();} - size_t size() const; /// if k_first=true: choose k-point index , then Psi(iband, ibasis) can reach Psi(ik, iband, ibasis) /// if k_first=false: choose k-point index, then Psi(ibasis) can reach Psi(iband, ik, ibasis) From 5373bdc90581fa5492075e131b1df13c093bd70b Mon Sep 17 00:00:00 2001 From: haozhihan Date: Tue, 31 Dec 2024 09:53:50 +0000 Subject: [PATCH 48/49] fix sdft bug --- source/module_psi/psi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 826309a9f2..0372159f0f 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -49,7 +49,7 @@ template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { assert(nk_in > 0); - assert(nbd_in > 0); + assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU assert(nbs_in > 0); this->k_first = k_first_in; From 0145476deb68224d2cfe43fbc06baec8fba51523 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Sun, 5 Jan 2025 06:58:36 +0000 Subject: [PATCH 49/49] change to get_current_ngk --- source/module_elecstate/elecstate_pw.cpp | 4 ++-- source/module_elecstate/elecstate_pw_cal_tau.cpp | 2 +- source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp | 2 +- source/module_hsolver/diago_iter_assist.cpp | 8 ++++---- source/module_hsolver/hsolver_pw.cpp | 6 +++--- source/module_hsolver/test/diago_cg_float_test.cpp | 4 ++-- source/module_hsolver/test/diago_cg_real_test.cpp | 4 ++-- source/module_hsolver/test/diago_cg_test.cpp | 4 ++-- source/module_hsolver/test/diago_david_float_test.cpp | 2 +- source/module_hsolver/test/diago_david_real_test.cpp | 2 +- source/module_hsolver/test/diago_david_test.cpp | 2 +- source/module_psi/psi.cpp | 2 +- source/module_psi/psi.h | 2 +- 13 files changed, 22 insertions(+), 22 deletions(-) diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 69bc14aba4..f55f2ec447 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -183,7 +183,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) this->init_rho_data(); int ik = psi.get_current_k(); - int npw = psi.get_cur_effective_basis(); + int npw = psi.get_current_ngk(); int current_spin = 0; if (PARAM.inp.nspin == 2) { @@ -287,7 +287,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) psi.fix_k(ik); const T* psi_now = psi.get_pointer(); const int currect_spin = this->klist->isk[ik]; - const int npw = psi.get_cur_effective_basis(); + const int npw = psi.get_current_ngk(); // get |beta> if (this->ppcell->nkb > 0) diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp index 451aa9688a..ad8c9ce42f 100644 --- a/source/module_elecstate/elecstate_pw_cal_tau.cpp +++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp @@ -15,7 +15,7 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) for (int ik = 0; ik < psi.get_nk(); ++ik) { psi.fix_k(ik); - int npw = psi.get_cur_effective_basis(); + int npw = psi.get_current_ngk(); int current_spin = 0; if (PARAM.inp.nspin == 2) { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 407879d24f..ec4aa26c1c 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -60,7 +60,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, if (PARAM.inp.nbands > 0) { const int nchipk = stowf.nchip[ik]; - const int npw = psi.get_cur_effective_basis(); + const int npw = psi.get_current_ngk(); const int npwx = psi.get_nbasis(); stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 33986955c3..bdb60ffaff 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -49,7 +49,7 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt* setmem_complex_op()(ctx, scc, 0, nstart * nstart); setmem_complex_op()(ctx, vcc, 0, nstart * nstart); - const int dmin = psi.get_cur_effective_basis(); + const int dmin = psi.get_current_ngk(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -167,7 +167,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* const int nstart = psi_nr; const int n_band = evc.get_nbands(); const int dmax = evc.get_nbasis(); - const int dmin = evc.get_cur_effective_basis(); + const int dmin = evc.get_current_ngk(); // skip the diagonalization if the operators are not allocated if (pHamilt->ops == nullptr) @@ -425,7 +425,7 @@ void DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt setmem_complex_op()(ctx, hcc, 0, nstart * nstart); setmem_complex_op()(ctx, scc, 0, nstart * nstart); - const int dmin = psi.get_cur_effective_basis(); + const int dmin = psi.get_current_ngk(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -551,7 +551,7 @@ void DiagoIterAssist::diag_subspace_psi(const T* hcc, DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc); { // code block to calculate tar_mat - const int dmin = evc.get_cur_effective_basis(); + const int dmin = evc.get_current_ngk(); const int dmax = evc.get_nbasis(); T* temp = nullptr; resmem_complex_op()(ctx, temp, nstart * dmax, "DiagSub::temp"); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 21c9fc9bfc..dbfca81061 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -473,7 +473,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ct::DeviceTypeToEnum::value, ct::TensorShape({static_cast(pre_condition.size())})) .to_device() - .slice({0}, {psi.get_cur_effective_basis()}); + .slice({0}, {psi.get_current_ngk()}); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor); // TODO: Double check tensormap's potential problem @@ -523,7 +523,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, Diago_DavSubspace dav_subspace(pre_condition, psi.get_nbands(), - psi.get_k_first() ? psi.get_cur_effective_basis() + psi.get_k_first() ? psi.get_current_ngk() : psi.get_nk() * psi.get_nbasis(), PARAM.inp.pw_diag_ndim, this->diag_thr, @@ -549,7 +549,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int david_maxiter = this->diag_iter_max; // dimensions of matrix to be solved - const int dim = psi.get_cur_effective_basis(); /// dimension of matrix + const int dim = psi.get_current_ngk(); /// dimension of matrix const int nband = psi.get_nbands(); /// number of eigenpairs sought const int ld_psi = psi.get_nbasis(); /// leading dimension of psi diff --git a/source/module_hsolver/test/diago_cg_float_test.cpp b/source/module_hsolver/test/diago_cg_float_test.cpp index 0500424b92..29fadb84bc 100644 --- a/source/module_hsolver/test/diago_cg_float_test.cpp +++ b/source/module_hsolver/test/diago_cg_float_test.cpp @@ -182,7 +182,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_FLOAT, @@ -192,7 +192,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_FLOAT, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_real_test.cpp b/source/module_hsolver/test/diago_cg_real_test.cpp index f6aa978620..7ee33a7e99 100644 --- a/source/module_hsolver/test/diago_cg_real_test.cpp +++ b/source/module_hsolver/test/diago_cg_real_test.cpp @@ -185,7 +185,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -195,7 +195,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_test.cpp b/source/module_hsolver/test/diago_cg_test.cpp index 5d144ae9fb..5783c74c12 100644 --- a/source/module_hsolver/test/diago_cg_test.cpp +++ b/source/module_hsolver/test/diago_cg_test.cpp @@ -176,7 +176,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_cur_effective_basis()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -186,7 +186,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_cur_effective_basis())})).slice({0}, {psi_local.get_cur_effective_basis()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_david_float_test.cpp b/source/module_hsolver/test/diago_david_float_test.cpp index 37930da8e6..0f05717511 100644 --- a/source/module_hsolver/test/diago_david_float_test.cpp +++ b/source/module_hsolver/test/diago_david_float_test.cpp @@ -90,7 +90,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_cur_effective_basis() ; + const int dim = phi.get_current_ngk() ; const int nband = phi.get_nbands(); const int ld_psi =phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_real_test.cpp b/source/module_hsolver/test/diago_david_real_test.cpp index 2a0103fe49..634b2ab83b 100644 --- a/source/module_hsolver/test/diago_david_real_test.cpp +++ b/source/module_hsolver/test/diago_david_real_test.cpp @@ -89,7 +89,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_cur_effective_basis(); + const int dim = phi.get_current_ngk(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_test.cpp b/source/module_hsolver/test/diago_david_test.cpp index 542deeb663..01c0e62a42 100644 --- a/source/module_hsolver/test/diago_david_test.cpp +++ b/source/module_hsolver/test/diago_david_test.cpp @@ -92,7 +92,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_cur_effective_basis(); + const int dim = phi.get_current_ngk(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 0372159f0f..7942b412c9 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -329,7 +329,7 @@ const int& Psi::get_psi_bias() const } template -const int& Psi::get_cur_effective_basis() const +const int& Psi::get_current_ngk() const { if (this->npol == 1) { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index caa3fa6ac9..d8a994377a 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -132,7 +132,7 @@ class Psi // return psi_bias const int& get_psi_bias() const; - const int& get_cur_effective_basis() const; + const int& get_current_ngk() const; // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const;