diff --git a/source/module_elecstate/cal_dm.h b/source/module_elecstate/cal_dm.h index 5ac41aab9a..56aad08f3c 100644 --- a/source/module_elecstate/cal_dm.h +++ b/source/module_elecstate/cal_dm.h @@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, //dm.fix_k(ik); dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + psi::Psi wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size()); int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) @@ -41,7 +46,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); @@ -99,7 +105,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) continue; + if (ib_global >= wg.nc) { continue; +} const double wg_local = wg(ik, ib_global); std::complex* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 5558856289..f55f2ec447 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -183,7 +183,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) this->init_rho_data(); int ik = psi.get_current_k(); - int npw = psi.get_current_nbas(); + int npw = psi.get_current_ngk(); int current_spin = 0; if (PARAM.inp.nspin == 2) { @@ -287,7 +287,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) psi.fix_k(ik); const T* psi_now = psi.get_pointer(); const int currect_spin = this->klist->isk[ik]; - const int npw = psi.get_current_nbas(); + const int npw = psi.get_current_ngk(); // get |beta> if (this->ppcell->nkb > 0) diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp index fd07f834af..ad8c9ce42f 100644 --- a/source/module_elecstate/elecstate_pw_cal_tau.cpp +++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp @@ -15,7 +15,7 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) for (int ik = 0; ik < psi.get_nk(); ++ik) { psi.fix_k(ik); - int npw = psi.get_current_nbas(); + int npw = psi.get_current_ngk(); int current_spin = 0; if (PARAM.inp.nspin == 2) { diff --git a/source/module_elecstate/module_dm/cal_dm_psi.cpp b/source/module_elecstate/module_dm/cal_dm_psi.cpp index 47fbfbf8c3..21d91e5225 100644 --- a/source/module_elecstate/module_dm/cal_dm_psi.cpp +++ b/source/module_elecstate/module_dm/cal_dm_psi.cpp @@ -32,7 +32,14 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); // dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi wg_wfc(wfc, 1); + + psi::Psi wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size()); + int ib_global = 0; for (int ib_local = 0; ib_local < nbands_local; ++ib_local) @@ -89,7 +96,12 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // dm.fix_k(ik); //dm[ik].create(ParaV->ncol, ParaV->nrow); // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); - psi::Psi> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr); + psi::Psi> wg_wfc(1, + wfc.get_nbands(), + wfc.get_nbasis(), + wfc.get_nbasis(), + true); + const std::complex* pwfc = wfc.get_pointer(); std::complex* pwg_wfc = wg_wfc.get_pointer(); #ifdef _OPENMP diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index b9fb62e853..8c87cc352b 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -1083,7 +1083,7 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep) //! initialize the gradients of Etotal with respect to occupation numbers and wfc, //! and set all elements to 0. ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true); - psi::Psi dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis()); + psi::Psi dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis(), this->kv.ngk, true); dE_dWfc.zero_out(); double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc); diff --git a/source/module_esolver/esolver_of.cpp b/source/module_esolver/esolver_of.cpp index fc8d60c6d0..eb86e8f6ef 100644 --- a/source/module_esolver/esolver_of.cpp +++ b/source/module_esolver/esolver_of.cpp @@ -222,7 +222,11 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell) // Refresh the arrays delete this->psi_; - this->psi_ = new psi::Psi(1, PARAM.inp.nspin, this->pw_rho->nrxx); + this->psi_ = new psi::Psi(1, + PARAM.inp.nspin, + this->pw_rho->nrxx, + this->pw_rho->nrxx, + true); for (int is = 0; is < PARAM.inp.nspin; ++is) { this->pphi_[is] = this->psi_->get_pointer(is); diff --git a/source/module_esolver/esolver_of_tool.cpp b/source/module_esolver/esolver_of_tool.cpp index 750598ab2e..e430347215 100644 --- a/source/module_esolver/esolver_of_tool.cpp +++ b/source/module_esolver/esolver_of_tool.cpp @@ -71,7 +71,11 @@ void ESolver_OF::init_elecstate(UnitCell& ucell) void ESolver_OF::allocate_array() { // Initialize the "wavefunction", which is sqrt(rho) - this->psi_ = new psi::Psi(1, PARAM.inp.nspin, this->pw_rho->nrxx); + this->psi_ = new psi::Psi(1, + PARAM.inp.nspin, + this->pw_rho->nrxx, + this->pw_rho->nrxx, + true); ModuleBase::Memory::record("OFDFT::Psi", sizeof(double) * PARAM.inp.nspin * this->pw_rho->nrxx); this->pphi_ = new double*[PARAM.inp.nspin]; for (int is = 0; is < PARAM.inp.nspin; ++is) diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a99e813e01..008d5e30e3 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -4,28 +4,31 @@ using namespace hamilt; - -template -Operator::Operator(){} - -template -Operator::~Operator() +template +Operator::Operator() { - if(this->hpsi != nullptr) { delete this->hpsi; } + +template +Operator::~Operator() +{ + if (this->hpsi != nullptr) + { + delete this->hpsi; + } Operator* last = this->next_op; Operator* last_sub = this->next_sub_op; - while(last != nullptr || last_sub != nullptr) + while (last != nullptr || last_sub != nullptr) { - if(last_sub != nullptr) - {//delete sub_chain first + if (last_sub != nullptr) + { // delete sub_chain first Operator* node_delete = last_sub; last_sub = last_sub->next_sub_op; node_delete->next_sub_op = nullptr; delete node_delete; } else - {//delete main chain if sub_chain is deleted + { // delete main chain if sub_chain is deleted Operator* node_delete = last; last_sub = last->next_sub_op; node_delete->next_sub_op = nullptr; @@ -36,7 +39,7 @@ Operator::~Operator() } } -template +template typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { using syncmem_op = base_device::memory::synchronize_memory_op; @@ -46,37 +49,51 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp T* tmhpsi = this->get_hpsi(input); const T* tmpsi_in = std::get<0>(psi_info); - //if range in hpsi_info is illegal, the first return of to_range() would be nullptr + // if range in hpsi_info is illegal, the first return of to_range() would be nullptr if (tmpsi_in == nullptr) { ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); } - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return + // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return T* hpsi_pointer = std::get<2>(input); if (this->in_place) { // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size()); syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size()); delete this->hpsi; - this->hpsi = new psi::Psi(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); + this->hpsi = new psi::Psi(hpsi_pointer, + 1, + nbands / psi_input->npol, + psi_input->get_nbasis(), + psi_input->get_nbasis(), + true); } auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void { - // a "psi" with the bands of needed range - psi::Psi psi_wrapper(const_cast(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true); - - + psi::Psi psi_wrapper(const_cast(tmpsi_in), + 1, + nbands, + psi_input->get_nbasis(), + psi_input->get_nbasis(), + true); + switch (op->get_act_type()) { case 2: op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node); + op->act(nbands, + psi_input->get_nbasis(), + psi_input->npol, + tmpsi_in, + this->hpsi->get_pointer(), + psi_input->get_current_nbas(), + is_first_node); break; } - }; + }; ModuleBase::timer::tick("Operator", "hPsi"); call_act(this, true); // first node @@ -91,39 +108,43 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } - -template -void Operator::init(const int ik_in) +template +void Operator::init(const int ik_in) { this->ik = ik_in; - if(this->next_op != nullptr) { + if (this->next_op != nullptr) + { this->next_op->init(ik_in); } } -template -void Operator::add(Operator* next) +template +void Operator::add(Operator* next) { - if(next==nullptr) { return; -} + if (next == nullptr) + { + return; + } next->is_first_node = false; - if(next->next_op != nullptr) { this->add(next->next_op); -} + if (next->next_op != nullptr) + { + this->add(next->next_op); + } Operator* last = this; - //loop to end of the chain - while(last->next_op != nullptr) + // loop to end of the chain + while (last->next_op != nullptr) { - if(next->cal_type==last->cal_type) + if (next->cal_type == last->cal_type) { break; } last = last->next_op; } - if(next->cal_type == last->cal_type) + if (next->cal_type == last->cal_type) { - //insert next to sub chain of current node + // insert next to sub chain of current node Operator* sub_last = last; - while(sub_last->next_sub_op != nullptr) + while (sub_last->next_sub_op != nullptr) { sub_last = sub_last->next_sub_op; } @@ -136,34 +157,45 @@ void Operator::add(Operator* next) } } -template +template T* Operator::get_hpsi(const hpsi_info& info) const { const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1); - //in_place call of hPsi, hpsi inputs as new psi, - //create a new hpsi and delete old hpsi later + // in_place call of hPsi, hpsi inputs as new psi, + // create a new hpsi and delete old hpsi later T* hpsi_pointer = std::get<2>(info); const T* psi_pointer = std::get<0>(info)->get_pointer(); - if(this->hpsi != nullptr) + if (this->hpsi != nullptr) { delete this->hpsi; this->hpsi = nullptr; } - if(!hpsi_pointer) + if (!hpsi_pointer) { ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr"); } - else if(hpsi_pointer == psi_pointer) + else if (hpsi_pointer == psi_pointer) { this->in_place = true; - this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + // this->hpsi = new psi::Psi(std::get<0>(info)[0], 1, nbands_range); + this->hpsi = new psi::Psi(1, + nbands_range, + std::get<0>(info)->get_nbasis(), + std::get<0>(info)->get_nbasis(), + true); } else { this->in_place = false; - this->hpsi = new psi::Psi(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range); + + this->hpsi = new psi::Psi(hpsi_pointer, + 1, + nbands_range, + std::get<0>(info)->get_nbasis(), + std::get<0>(info)->get_nbasis(), + true); } - + hpsi_pointer = this->hpsi->get_pointer(); size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis(); // ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size); @@ -172,7 +204,8 @@ T* Operator::get_hpsi(const hpsi_info& info) const return hpsi_pointer; } -namespace hamilt { +namespace hamilt +{ template class Operator; template class Operator, base_device::DEVICE_CPU>; template class Operator; @@ -183,4 +216,4 @@ template class Operator, base_device::DEVICE_GPU>; template class Operator; template class Operator, base_device::DEVICE_GPU>; #endif -} +} // namespace hamilt diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 6cf29122fe..80ed065ccc 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -1,19 +1,19 @@ #ifndef OPERATOR_H #define OPERATOR_H -#include - #include "module_base/global_function.h" #include "module_base/tool_quit.h" #include "module_psi/psi.h" +#include + namespace hamilt { enum class calculation_type { no, - pw_ekinetic, + pw_ekinetic, pw_nonlocal, pw_veff, pw_meta, @@ -28,49 +28,54 @@ enum class calculation_type lcao_tddft_velocity, }; -// Basic class for operator module, +// Basic class for operator module, // it is designed for "O|psi>" and "" // Operator "O" might have several different types, which should be calculated one by one. // In basic class , function add() is designed for combine all operators together with a chain. template class Operator { - public: + public: Operator(); virtual ~Operator(); - //this is the core function for Operator - // do H|psi> from input |psi> , + // this is the core function for Operator + // do H|psi> from input |psi> , /// as default, different operators donate hPsi independently - /// run this->act function for the first operator and run all act() for other nodes in chain table + /// run this->act function for the first operator and run all act() for other nodes in chain table /// if this procedure is not suitable for your operator, just override this function. - /// output of hpsi would be first member of the returned tuple + /// output of hpsi would be first member of the returned tuple typedef std::tuple*, const psi::Range, T*> hpsi_info; - virtual hpsi_info hPsi(hpsi_info& input)const; + + virtual hpsi_info hPsi(hpsi_info& input) const; virtual void init(const int ik_in); virtual void add(Operator* next); - virtual int get_ik() const { return this->ik; } + virtual int get_ik() const + { + return this->ik; + } - ///do operation : |hpsi_choosed> = V|psi_choosed> - ///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi - /// interface type 1: pointer-only (default) - /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! + /// do operation : |hpsi_choosed> = V|psi_choosed> + /// V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + /// interface type 1: pointer-only (default) + /// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!! virtual void act(const int nbands, - const int nbasis, - const int npol, - const T* tmpsi_in, - T* tmhpsi, - const int ngk_ik = 0, - const bool is_first_node = false)const {}; + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const {}; /// developer-friendly interfaces for act() function /// interface type 2: input and change the Psi-type HPsi // virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out) const {}; virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out, const int nbands) const {}; + /// interface type 3: return a Psi-type HPsi // virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; @@ -78,36 +83,41 @@ class Operator /// type 1 (default): pointer-only /// act(const T* psi_in, T* psi_out) - /// type 2: use the `Psi`class + /// type 2: use the `Psi`class /// act(const Psi& psi_in, Psi& psi_out) - int get_act_type() const { return this->act_type; } -protected: + int get_act_type() const + { + return this->act_type; + } + + protected: int ik = 0; - int act_type = 1; ///< determine which act() interface would be called in hPsi() + int act_type = 1; ///< determine which act() interface would be called in hPsi() mutable bool in_place = false; - //calculation type, only different type can be in main chain table + // calculation type, only different type can be in main chain table enum calculation_type cal_type; Operator* next_sub_op = nullptr; bool is_first_node = true; - //if this Operator is first node in chain table, hpsi would not be empty + // if this Operator is first node in chain table, hpsi would not be empty mutable psi::Psi* hpsi = nullptr; /*This function would analyze hpsi_info and choose how to arrange hpsi storage In hpsi_info, if the third parameter hpsi_pointer is set, which indicates memory of hpsi is arranged by developer; - if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. + if hpsi_pointer is not set(nullptr), which indicates memory of hpsi is arranged by Operator, this case is rare. two cases would occurred: - 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary memory - 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case + 1. hpsi_pointer != nullptr && psi_pointer == hpsi_pointer , psi would be replaced by hpsi, hpsi need a temporary + memory + 2. hpsi_pointer != nullptr && psi_pointer != hpsi_pointer , this is the commonly case */ - T* get_hpsi(const hpsi_info& info)const; + T* get_hpsi(const hpsi_info& info) const; - Device *ctx = {}; + Device* ctx = {}; using set_memory_op = base_device::memory::set_memory_op; }; -}//end namespace hamilt +} // end namespace hamilt #endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 412534df6a..f49a9b6702 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -23,7 +23,8 @@ Velocity::Velocity this->ucell = ucell_in; this->nonlocal = nonlocal_in; this->tpiba = ucell_in -> tpiba; - if(this->nonlocal) this->ppcell->initgradq_vnl(*this->ucell); + if(this->nonlocal) { this->ppcell->initgradq_vnl(*this->ucell); +} } void Velocity::init(const int ik_in) @@ -47,7 +48,9 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_ngk(this->ik); + + const int npw = psi_in->get_current_nbas(); + const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp index 8692a586ee..c0e5d61d3f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp @@ -172,9 +172,9 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, const int allbands = bandinfo[5]; const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands; - psi::Psi> right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); - psi::Psi> f_rightchi(1, perbands_sto, npwx, p_kv->ngk.data()); - psi::Psi> f_right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> right_hchi(1, perbands_sto, npwx, npw, true); + psi::Psi> f_rightchi(1, perbands_sto, npwx, npw, true); + psi::Psi> f_right_hchi(1, perbands_sto, npwx, npw, true); this->p_hamilt_sto->hPsi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto); this->p_hamilt_sto->hPsi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto); @@ -206,8 +206,8 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, } #endif - psi::Psi> f_batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); - psi::Psi> f_batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); + psi::Psi> f_batch_vchi(1, bsize_psi * ndim, npwx, npw, true); + psi::Psi> f_batch_vhchi(1, bsize_psi * ndim, npwx, npw, true); std::vector> tmpj(ndim * allbands_sto * perbands_sto); // 1. (<\psi|J|\chi>)^T @@ -663,19 +663,19 @@ void Sto_EleCond::sKG(const int& smear_type, //----------------------------------------------------------- //------------------- allocate ------------------------- size_t ks_memory_cost = perbands_ks * npwx * sizeof(std::complex); - psi::Psi> kspsi(1, perbands_ks, npwx, p_kv->ngk.data()); - psi::Psi> vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data()); + psi::Psi> kspsi(1, perbands_ks, npwx, npw, true); + psi::Psi> vkspsi(1, perbands_ks * ndim, npwx, npw, true); std::vector> expmtmf_fact(perbands_ks), expmtf_fact(perbands_ks); - psi::Psi> f_kspsi(1, perbands_ks, npwx, p_kv->ngk.data()); + psi::Psi> f_kspsi(1, perbands_ks, npwx, npw, true); ModuleBase::Memory::record("SDFT::kspsi", ks_memory_cost); - psi::Psi> f_vkspsi(1, perbands_ks * ndim, npwx, p_kv->ngk.data()); + psi::Psi> f_vkspsi(1, perbands_ks * ndim, npwx, npw, true); ModuleBase::Memory::record("SDFT::vkspsi", ks_memory_cost); psi::Psi>* kspsi_all = &f_kspsi; size_t sto_memory_cost = perbands_sto * npwx * sizeof(std::complex); - psi::Psi> sfchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> sfchi(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::sfchi", sto_memory_cost); - psi::Psi> smfchi(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> smfchi(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::smfchi", sto_memory_cost); #ifdef __MPI psi::Psi> chi_all, hchi_all, psi_all; @@ -702,8 +702,8 @@ void Sto_EleCond::sKG(const int& smear_type, const int nbatch_psi = npart_sto; const int bsize_psi = ceil(double(perbands_sto) / nbatch_psi); - psi::Psi> batch_vchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); - psi::Psi> batch_vhchi(1, bsize_psi * ndim, npwx, p_kv->ngk.data()); + psi::Psi> batch_vchi(1, bsize_psi * ndim, npwx, npw, true); + psi::Psi> batch_vhchi(1, bsize_psi * ndim, npwx, npw, true); ModuleBase::Memory::record("SDFT::batchjpsi", 3 * bsize_psi * ndim * npwx * sizeof(std::complex)); //------------------- sqrt(f)|psi> sqrt(1-f)|psi> --------------- @@ -781,7 +781,7 @@ void Sto_EleCond::sKG(const int& smear_type, std::vector> j1r(ndim * dim_jmatrix), j2r(ndim * dim_jmatrix); ModuleBase::Memory::record("SDFT::j1r", sizeof(std::complex) * ndim * dim_jmatrix); ModuleBase::Memory::record("SDFT::j2r", sizeof(std::complex) * ndim * dim_jmatrix); - psi::Psi> tmphchil(1, perbands_sto, npwx, p_kv->ngk.data()); + psi::Psi> tmphchil(1, perbands_sto, npwx, npw, true); ModuleBase::Memory::record("SDFT::tmphchil/r", sto_memory_cost * 2); //------------------------ t loop -------------------------- diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index bcfbd2da61..ec4aa26c1c 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -60,7 +60,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, if (PARAM.inp.nbands > 0) { const int nchipk = stowf.nchip[ik]; - const int npw = psi.get_current_nbas(); + const int npw = psi.get_current_ngk(); const int npwx = psi.get_nbasis(); stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); diff --git a/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp index 2e429dfa3f..f343c93c81 100644 --- a/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/test/test_sto_tool.cpp @@ -68,8 +68,8 @@ TEST_F(TestStoTool, parallel_distribution) TEST_F(TestStoTool, convert_psi) { - psi::Psi> psi_in(1, 1, 10); - psi::Psi> psi_out(1, 1, 10); + psi::Psi> psi_in(1, 1, 10, 10, true); + psi::Psi> psi_out(1, 1, 10, 10, true); for (int i = 0; i < 10; ++i) { psi_in.get_pointer()[i] = std::complex(i, i); @@ -83,8 +83,8 @@ TEST_F(TestStoTool, convert_psi) TEST_F(TestStoTool, gatherchi) { - psi::Psi> chi(1, 1, 10); - psi::Psi> chi_all(1, 1, 10); + psi::Psi> chi(1, 1, 10, 10, true); + psi::Psi> chi_all(1, 1, 10, 10, true); int npwx = 10; int nrecv_sto[4] = {1, 2, 3, 4}; int displs_sto[4] = {0, 1, 3, 6}; diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 5ec443ab4e..bdb60ffaff 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -49,7 +49,7 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt* setmem_complex_op()(ctx, scc, 0, nstart * nstart); setmem_complex_op()(ctx, vcc, 0, nstart * nstart); - const int dmin = psi.get_current_nbas(); + const int dmin = psi.get_current_ngk(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -167,7 +167,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* const int nstart = psi_nr; const int n_band = evc.get_nbands(); const int dmax = evc.get_nbasis(); - const int dmin = evc.get_current_nbas(); + const int dmin = evc.get_current_ngk(); // skip the diagonalization if the operators are not allocated if (pHamilt->ops == nullptr) @@ -199,7 +199,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, 1, psi_nc, dmin, true); + T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -212,7 +213,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* { // psi_temp is one band psi, psi is all bands psi, the range always is 1 for the only band in psi_temp syncmem_complex_op()(ctx, ctx, ppsi, psi + i * psi_nc, psi_nc); - psi::Range band_by_band_range(1, 0, 0, 0); + psi::Range band_by_band_range(true, 0, 0, 0); hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi); // H|Psi> to get hpsi for target band @@ -246,7 +247,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, nstart, psi_nc, dmin, true); + T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space @@ -256,7 +258,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* hpsi = temp; // do hPsi for all bands - psi::Range all_bands_range(1, 0, 0, nstart - 1); + psi::Range all_bands_range(true, 0, 0, nstart - 1); hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); pHamilt->ops->hPsi(hpsi_in); @@ -264,7 +266,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* spsi = temp; // do sPsi for all bands - pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_current_nbas(), psi_temp.get_nbands()); + pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_nbasis(), psi_temp.get_nbands()); gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); delmem_complex_op()(ctx, temp); @@ -423,7 +425,7 @@ void DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt setmem_complex_op()(ctx, hcc, 0, nstart * nstart); setmem_complex_op()(ctx, scc, 0, nstart * nstart); - const int dmin = psi.get_current_nbas(); + const int dmin = psi.get_current_ngk(); const int dmax = psi.get_nbasis(); T* temp = nullptr; @@ -549,7 +551,7 @@ void DiagoIterAssist::diag_subspace_psi(const T* hcc, DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc); { // code block to calculate tar_mat - const int dmin = evc.get_current_nbas(); + const int dmin = evc.get_current_ngk(); const int dmax = evc.get_nbasis(); T* temp = nullptr; resmem_complex_op()(ctx, temp, nstart * dmax, "DiagSub::temp"); @@ -586,8 +588,9 @@ bool DiagoIterAssist::test_exit_cond(const int& ntry, const int& notc //================================================================ bool scf = true; - if (PARAM.inp.calculation == "nscf") + if (PARAM.inp.calculation == "nscf") { scf = false; +} // If ntry <=5, try to do it better, if ntry > 5, exit. const bool f1 = (ntry <= 5); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 0c1ad2e8b8..dbfca81061 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -310,7 +310,11 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, #endif /// solve eigenvector and eigenvalue for H(k) - this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks); + this->hamiltSolvePsiK(pHamilt, + psi, + precondition, + eigenvalues.data() + ik * psi.get_nbands(), + this->wfc_basis->nks); if (skip_charge) { @@ -370,18 +374,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool}; #endif - auto ngk_pointer = psi.get_ngk_pointer(); - - std::vector ngk_vector(nk_nums, 0); - for (int i = 0; i < nk_nums; i++) - { - ngk_vector[i] = ngk_pointer[i]; - } + const int cur_nbasis = psi.get_current_nbas(); if (this->method == "cg") { // wrap the subspace_func into a lambda function - auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) { + auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); @@ -391,12 +389,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, psi_in.shape().dim_size(0), psi_in.shape().dim_size(1), - ngk_vector); + cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out.data(), 1, psi_out.shape().dim_size(0), psi_out.shape().dim_size(1), - ngk_vector); + cur_nbasis); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -415,7 +413,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] @@ -426,7 +424,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, 1, ndim == 1 ? 1 : psi_in.shape().dim_size(0), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_vector); + cur_nbasis); psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); @@ -475,7 +473,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ct::DeviceTypeToEnum::value, ct::TensorShape({static_cast(pre_condition.size())})) .to_device() - .slice({0}, {psi.get_current_nbas()}); + .slice({0}, {psi.get_current_ngk()}); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor); // TODO: Double check tensormap's potential problem @@ -486,11 +484,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int nbasis = psi.get_nbasis(); // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -507,11 +505,11 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, else if (this->method == "dav_subspace") { // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); @@ -525,7 +523,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, Diago_DavSubspace dav_subspace(pre_condition, psi.get_nbands(), - psi.get_k_first() ? psi.get_current_nbas() + psi.get_k_first() ? psi.get_current_ngk() : psi.get_nk() * psi.get_nbasis(), PARAM.inp.pw_diag_ndim, this->diag_thr, @@ -551,18 +549,18 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int david_maxiter = this->diag_iter_max; // dimensions of matrix to be solved - const int dim = psi.get_current_nbas(); /// dimension of matrix - const int nband = psi.get_nbands(); /// number of eigenpairs sought - const int ld_psi = psi.get_nbasis(); /// leading dimension of psi + const int dim = psi.get_current_ngk(); /// dimension of matrix + const int nband = psi.get_nbands(); /// number of eigenpairs sought + const int ld_psi = psi.get_nbasis(); /// leading dimension of psi // Davidson matrix-blockvector functions /// wrap hpsi into lambda function, Matrix \times blockvector // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { ModuleBase::timer::tick("David", "hpsi_func"); // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, ngk_vector); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); psi::Range bands_range(true, 0, 0, nvec - 1); diff --git a/source/module_hsolver/test/diago_cg_float_test.cpp b/source/module_hsolver/test/diago_cg_float_test.cpp index 47fac4ef01..29fadb84bc 100644 --- a/source/module_hsolver/test/diago_cg_float_test.cpp +++ b/source/module_hsolver/test/diago_cg_float_test.cpp @@ -182,7 +182,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_FLOAT, @@ -192,7 +192,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_FLOAT, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_real_test.cpp b/source/module_hsolver/test/diago_cg_real_test.cpp index 97872c316d..7ee33a7e99 100644 --- a/source/module_hsolver/test/diago_cg_real_test.cpp +++ b/source/module_hsolver/test/diago_cg_real_test.cpp @@ -185,7 +185,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -195,7 +195,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_cg_test.cpp b/source/module_hsolver/test/diago_cg_test.cpp index 08912bc428..5783c74c12 100644 --- a/source/module_hsolver/test/diago_cg_test.cpp +++ b/source/module_hsolver/test/diago_cg_test.cpp @@ -176,7 +176,7 @@ class DiagoCGPrepare psi_local.get_pointer(), ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_nbas()}); + ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); auto eigen_tensor = ct::TensorMap( en, ct::DataType::DT_DOUBLE, @@ -186,7 +186,7 @@ class DiagoCGPrepare precondition_local, ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()}); + ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); diff --git a/source/module_hsolver/test/diago_david_float_test.cpp b/source/module_hsolver/test/diago_david_float_test.cpp index c3feeea246..0f05717511 100644 --- a/source/module_hsolver/test/diago_david_float_test.cpp +++ b/source/module_hsolver/test/diago_david_float_test.cpp @@ -90,7 +90,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas() ; + const int dim = phi.get_current_ngk() ; const int nband = phi.get_nbands(); const int ld_psi =phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_real_test.cpp b/source/module_hsolver/test/diago_david_real_test.cpp index a1c4dee958..634b2ab83b 100644 --- a/source/module_hsolver/test/diago_david_real_test.cpp +++ b/source/module_hsolver/test/diago_david_real_test.cpp @@ -89,7 +89,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas(); + const int dim = phi.get_current_ngk(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_hsolver/test/diago_david_test.cpp b/source/module_hsolver/test/diago_david_test.cpp index 71005a78b9..01c0e62a42 100644 --- a/source/module_hsolver/test/diago_david_test.cpp +++ b/source/module_hsolver/test/diago_david_test.cpp @@ -92,7 +92,7 @@ class DiagoDavPrepare const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif - const int dim = phi.get_current_nbas(); + const int dim = phi.get_current_ngk(); const int nband = phi.get_nbands(); const int ld_psi = phi.get_nbasis(); hsolver::DiagoDavid> dav(precondition, nband, dim, order, false, comm_info); diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 6e069fd017..3cea8a3940 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -478,7 +478,14 @@ void IState_Charge::idmatrix(const int& ib, // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw); this->psi_gamma->fix_k(is); - psi::Psi wg_wfc(*this->psi_gamma, 1); + + // psi::Psi wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands()); + psi::Psi wg_wfc(1, + this->psi_gamma->get_nbands(), + this->psi_gamma->get_nbasis(), + this->psi_gamma->get_nbasis(), + true); + wg_wfc.set_all_psi(this->psi_gamma->get_pointer(), wg_wfc.size()); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { @@ -540,7 +547,12 @@ void IState_Charge::idmatrix(const int& ib, } this->psi_k->fix_k(ik); - psi::Psi> wg_wfc(*this->psi_k, 1); + + psi::Psi> wg_wfc(1, + this->psi_k->get_nbands(), + this->psi_k->get_nbasis(), + this->psi_k->get_nbasis(), + true); for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir) { diff --git a/source/module_io/test/write_wfc_nao_test.cpp b/source/module_io/test/write_wfc_nao_test.cpp index 1b39f34a0e..c0effc8711 100644 --- a/source/module_io/test/write_wfc_nao_test.cpp +++ b/source/module_io/test/write_wfc_nao_test.cpp @@ -167,7 +167,7 @@ class WriteWfcLcaoTest : public testing::Test TEST_F(WriteWfcLcaoTest, WriteWfcLcao) { // create a psi object - psi::Psi my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, true); + psi::Psi my_psi(psi_local_double.data(), nk, nbands_local, nbasis_local, nbasis_local, true); PARAM.sys.global_out_dir = "./"; ModuleIO::write_wfc_nao(2, my_psi, ekb, wg, kvec_c, pv, -1); diff --git a/source/module_io/unk_overlap_pw.cpp b/source/module_io/unk_overlap_pw.cpp index d0d1d7c706..d4f0d2a85b 100644 --- a/source/module_io/unk_overlap_pw.cpp +++ b/source/module_io/unk_overlap_pw.cpp @@ -1,16 +1,16 @@ #include "unk_overlap_pw.h" -#include "module_parameter/parameter.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_parameter/parameter.h" unkOverlap_pw::unkOverlap_pw() { - //GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl; + // GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl; } unkOverlap_pw::~unkOverlap_pw() { - //GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl; + // GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl; } std::complex unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw, @@ -20,50 +20,44 @@ std::complex unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw, const int iband_R, const psi::Psi>* evc) { - - std::complex result(0.0,0.0); - const int number_pw = wfcpw->npw; - std::complex *unk_L = new std::complex[number_pw]; - std::complex *unk_R = new std::complex[number_pw]; - ModuleBase::GlobalFunc::ZEROS(unk_L,number_pw); - ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw); - - - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) - { - unk_L[wfcpw->getigl2ig(ik_L,igl)] = evc[0](ik_L, iband_L, igl); - } - - for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) - { - unk_R[wfcpw->getigl2ig(ik_R,igl)] = evc[0](ik_R, iband_R, igl); - } - - - for (int iG = 0; iG < number_pw; iG++) - { - - result = result + conj(unk_L[iG]) * unk_R[iG]; - - } + std::complex result(0.0, 0.0); + const int number_pw = wfcpw->npw; + std::complex* unk_L = new std::complex[number_pw]; + std::complex* unk_R = new std::complex[number_pw]; + ModuleBase::GlobalFunc::ZEROS(unk_L, number_pw); + ModuleBase::GlobalFunc::ZEROS(unk_R, number_pw); + + for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + { + unk_L[wfcpw->getigl2ig(ik_L, igl)] = evc[0](ik_L, iband_L, igl); + } + + for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) + { + unk_R[wfcpw->getigl2ig(ik_R, igl)] = evc[0](ik_R, iband_R, igl); + } + + for (int iG = 0; iG < number_pw; iG++) + { + + result = result + conj(unk_L[iG]) * unk_R[iG]; + } #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - delete[] unk_L; - delete[] unk_R; - return result; - - + delete[] unk_L; + delete[] unk_R; + return result; } std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, @@ -75,24 +69,24 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, const psi::Psi>* evc, const ModuleBase::Vector3 G) { - // (1) set value - std::complex result(0.0,0.0); + // (1) set value + std::complex result(0.0, 0.0); std::complex* psi_r = new std::complex[wfcpw->nmaxgr]; std::complex* phase = new std::complex[rhopw->nmaxgr]; // get the phase value in realspace for (int ig = 0; ig < rhopw->nmaxgr; ig++) { - ModuleBase::Vector3 delta_G = rhopw->gdirect[ig] - G; - if (delta_G.norm2() < 1e-10) // rhopw->gdirect[ig] == G - { - phase[ig] = std::complex(1.0,0.0); - break; - } - } - - // (2) fft and get value - rhopw->recip2real(phase, phase); + ModuleBase::Vector3 delta_G = rhopw->gdirect[ig] - G; + if (delta_G.norm2() < 1e-10) // rhopw->gdirect[ig] == G + { + phase[ig] = std::complex(1.0, 0.0); + break; + } + } + + // (2) fft and get value + rhopw->recip2real(phase, phase); wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_r, ik_L); for (int ir = 0; ir < rhopw->nmaxgr; ir++) @@ -110,17 +104,17 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - - delete[] psi_r; - delete[] phase; + + delete[] psi_r; + delete[] phase; return result; } @@ -133,18 +127,18 @@ std::complex unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf const int npwx, const psi::Psi>* evc) { - - std::complex result(0.0,0.0); + + std::complex result(0.0, 0.0); const int number_pw = wfcpw->npw; std::complex* unk_L = new std::complex[number_pw * PARAM.globalv.npol]; std::complex* unk_R = new std::complex[number_pw * PARAM.globalv.npol]; - ModuleBase::GlobalFunc::ZEROS(unk_L,number_pw*PARAM.globalv.npol); - ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw*PARAM.globalv.npol); - - for(int i = 0; i < PARAM.globalv.npol; i++) - { - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) - { + ModuleBase::GlobalFunc::ZEROS(unk_L, number_pw * PARAM.globalv.npol); + ModuleBase::GlobalFunc::ZEROS(unk_R, number_pw * PARAM.globalv.npol); + + for (int i = 0; i < PARAM.globalv.npol; i++) + { + for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + { unk_L[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_L, iband_L, igl + i * npwx); } @@ -154,32 +148,29 @@ std::complex unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf } } - for (int iG = 0; iG < number_pw*PARAM.globalv.npol; iG++) - { + for (int iG = 0; iG < number_pw * PARAM.globalv.npol; iG++) + { - result = result + conj(unk_L[iG]) * unk_R[iG]; + result = result + conj(unk_L[iG]) * unk_R[iG]; + } - } - #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - delete[] unk_L; - delete[] unk_R; - return result; - - + delete[] unk_L; + delete[] unk_R; + return result; } -//here G is in direct coordinate +// here G is in direct coordinate std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rhopw, const ModulePW::PW_Basis_K* wfcpw, const int ik_L, @@ -189,32 +180,32 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho const psi::Psi>* evc, const ModuleBase::Vector3 G) { - // (1) set value - std::complex result(0.0,0.0); - std::complex *phase =new std::complex[rhopw->nmaxgr]; + // (1) set value + std::complex result(0.0, 0.0); + std::complex* phase = new std::complex[rhopw->nmaxgr]; std::complex* psi_up = new std::complex[wfcpw->nmaxgr]; std::complex* psi_down = new std::complex[wfcpw->nmaxgr]; const int npwx = wfcpw->npwk_max; // get the phase value in realspace for (int ig = 0; ig < rhopw->npw; ig++) - { - if (rhopw->gdirect[ig] == G) - { - phase[ig] = std::complex(1.0,0.0); - break; - } - } - - // (2) fft and get value - rhopw->recip2real(phase, phase); + { + if (rhopw->gdirect[ig] == G) + { + phase[ig] = std::complex(1.0, 0.0); + break; + } + } + + // (2) fft and get value + rhopw->recip2real(phase, phase); wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psi_up[ir] = psi_up[ir] * phase[ir]; - psi_down[ir] = psi_down[ir] * phase[ir]; + psi_down[ir] = psi_down[ir] * phase[ir]; } // (3) calculate the overlap in ik_L and ik_R @@ -223,27 +214,31 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho for (int i = 0; i < PARAM.globalv.npol; i++) { - for(int ig = 0; ig < evc->get_ngk(ik_R); ig++) - { - if( i == 0 ) { result = result + conj( psi_up[ig] ) * evc[0](ik_R, iband_R, ig); -} - if( i == 1 ) { result = result + conj( psi_down[ig] ) * evc[0](ik_R, iband_R, ig + npwx); -} - } - } - + for (int ig = 0; ig < evc->get_ngk(ik_R); ig++) + { + if (i == 0) + { + result = result + conj(psi_up[ig]) * evc[0](ik_R, iband_R, ig); + } + if (i == 1) + { + result = result + conj(psi_down[ig]) * evc[0](ik_R, iband_R, ig + npwx); + } + } + } + #ifdef __MPI // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1. - double in_date_real = result.real(); - double in_date_imag = result.imag(); - double out_date_real = 0.0; - double out_date_imag = 0.0; - MPI_Allreduce(&in_date_real , &out_date_real , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - MPI_Allreduce(&in_date_imag , &out_date_imag , 1, MPI_DOUBLE , MPI_SUM , POOL_WORLD); - result = std::complex(out_date_real,out_date_imag); + double in_date_real = result.real(); + double in_date_imag = result.imag(); + double out_date_real = 0.0; + double out_date_imag = 0.0; + MPI_Allreduce(&in_date_real, &out_date_real, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(&in_date_imag, &out_date_imag, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + result = std::complex(out_date_real, out_date_imag); #endif - - delete[] psi_up; - delete[] psi_down; + + delete[] psi_up; + delete[] psi_down; return result; } diff --git a/source/module_io/write_dos_lcao.cpp b/source/module_io/write_dos_lcao.cpp index e475c77459..015c5bc1c1 100644 --- a/source/module_io/write_dos_lcao.cpp +++ b/source/module_io/write_dos_lcao.cpp @@ -461,11 +461,17 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell, } psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + + psi::Psi> Dwfc(1, + psi->get_nbands(), + psi->get_nbasis(), + psi->get_nbasis(), + true); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { - p_dwfc[index] = conj(p_dwfc[index]); + p_dwfc[index] = conj(psi->get_pointer()[index]); } for (int i = 0; i < PARAM.inp.nbands; ++i) diff --git a/source/module_io/write_proj_band_lcao.cpp b/source/module_io/write_proj_band_lcao.cpp index ccd7a0d4b0..47d4907b5b 100644 --- a/source/module_io/write_proj_band_lcao.cpp +++ b/source/module_io/write_proj_band_lcao.cpp @@ -25,8 +25,9 @@ void ModuleIO::write_proj_band_lcao( const double* sk = dynamic_cast*>(p_ham)->getSk(); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -103,14 +104,16 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; out << "" << PARAM.inp.nspin << "" << std::endl; - if (PARAM.inp.nspin == 4) + if (PARAM.inp.nspin == 4) { out << "" << std::setw(2) << PARAM.globalv.nlocal / 2 << "" << std::endl; - else + } else { out << "" << std::setw(2) << PARAM.globalv.nlocal << "" << std::endl; +} out << "" << std::endl; - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; out << "" << std::endl; @@ -139,9 +142,9 @@ void ModuleIO::write_proj_band_lcao( out << "" << std::endl; for (int ib = 0; ib < PARAM.inp.nbands; ib++) { - if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) { out << std::setw(13) << weight(is, ib * PARAM.globalv.nlocal + w); - else if (PARAM.inp.nspin == 4) + } else if (PARAM.inp.nspin == 4) { int w0 = w - s0; out << std::setw(13) @@ -178,8 +181,9 @@ void ModuleIO::write_proj_band_lcao( ModuleBase::timer::tick("ModuleIO", "write_proj_band_lcao"); int nspin0 = 1; - if (PARAM.inp.nspin == 2) + if (PARAM.inp.nspin == 2) { nspin0 = 2; +} int nks = 0; if (nspin0 == 1) { @@ -221,11 +225,16 @@ void ModuleIO::write_proj_band_lcao( // calculate Mulk psi->fix_k(ik); - psi::Psi> Dwfc(psi[0], 1); + psi::Psi> Dwfc(1, + psi->get_nbands(), + psi->get_nbasis(), + psi->get_nbasis(), + true); + std::complex* p_dwfc = Dwfc.get_pointer(); for (int index = 0; index < Dwfc.size(); ++index) { - p_dwfc[index] = conj(p_dwfc[index]); + p_dwfc[index] = conj(psi->get_pointer()[index]); } for (int i = 0; i < PARAM.inp.nbands; ++i) @@ -301,8 +310,9 @@ void ModuleIO::write_proj_band_lcao( for (int ik = 0; ik < nks; ik++) { - for (int ib = 0; ib < PARAM.inp.nbands; ib++) + for (int ib = 0; ib < PARAM.inp.nbands; ib++) { out << " " << (pelec->ekb(ik + is * nks, ib)) * ModuleBase::Ry_to_eV; +} out << std::endl; } out << "" << std::endl; diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index 205fdbb057..d57c8f2ccd 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -59,24 +59,27 @@ namespace ModuleIO assert(nbands >= 0); #endif std::vector e(nbands, 0.0); - for (int i = 0; i < nbands; ++i) + for (int i = 0; i < nbands; ++i) { e[i] = get_real(mat_mo[i * nbands + i]); +} return e; } template FPTYPE all_band_energy(const int ik, const int nbands, const std::vector>& mat_mo, const ModuleBase::matrix& wg) { FPTYPE e = 0.0; - for (int i = 0; i < nbands; ++i) + for (int i = 0; i < nbands; ++i) { e += get_real(mat_mo[i * nbands + i]) * (FPTYPE)wg(ik, i); +} return e; } template FPTYPE all_band_energy(const int ik, const std::vector& orbital_energy, const ModuleBase::matrix& wg) { FPTYPE e = 0.0; - for (int i = 0; i < orbital_energy.size(); ++i) + for (int i = 0; i < orbital_energy.size(); ++i) { e += orbital_energy[i] * (FPTYPE)wg(ik, i); +} return e; } @@ -122,7 +125,7 @@ namespace ModuleIO // const ModuleBase::matrix vr_localxc = potxc->get_veff_smooth(); // 2. allocate xc operator - psi::Psi hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_ngk_pointer()); + psi::Psi hpsi_localxc(psi_pw.get_nk(), psi_pw.get_nbands(), psi_pw.get_nbasis(), kv.ngk, true); hpsi_localxc.zero_out(); // std::cout << "hpsi.nk=" << hpsi_localxc.get_nk() << std::endl; // std::cout << "hpsi.nbands=" << hpsi_localxc.get_nbands() << std::endl; @@ -170,9 +173,11 @@ namespace ModuleIO #if((defined __LCAO)&&(defined __EXX) && !(defined __CUDA)&& !(defined __ROCM)) if (GlobalC::exx_info.info_global.cal_exx) { - for (int n = 0; n < naos; ++n) - for (int m = 0; m < naos; ++m) + for (int n = 0; n < naos; ++n) { + for (int m = 0; m < naos; ++m) { vexx_k_ao[n * naos + m] += (T)GlobalC::exx_info.info_global.hybrid_alpha * exx_lip.get_exx_matrix()[ik][m][n]; +} +} std::vector vexx_k_mo = cVc(vexx_k_ao.data(), &(exx_lip.get_hvec()(ik, 0, 0)), naos, nbands); Parallel_Reduce::reduce_pool(vexx_k_mo.data(), nbands * nbands); e_orb_exx.emplace_back(orbital_energy(ik, nbands, vexx_k_mo)); diff --git a/source/module_lr/AX/test/AX_test.cpp b/source/module_lr/AX/test/AX_test.cpp index b137f78ea0..697c71bc7e 100644 --- a/source/module_lr/AX/test/AX_test.cpp +++ b/source/module_lr/AX/test/AX_test.cpp @@ -70,7 +70,8 @@ TEST_F(AXTest, DoubleSerial) int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data(), size_v); } @@ -91,7 +92,8 @@ TEST_F(AXTest, ComplexSerial) int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data>(), size_v); } @@ -113,7 +115,9 @@ TEST_F(AXTest, DoubleParallel) std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() })); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector ngk_temp(s.nks, pc.get_row_size()); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true); Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); @@ -139,7 +143,9 @@ TEST_F(AXTest, DoubleParallel) } // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector ngk_temp_1(s.nks, s.naos); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data(), V_full.at(isk).data(), false, s.naos, s.naos); @@ -165,7 +171,9 @@ TEST_F(AXTest, ComplexParallel) std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pV.get_col_size(), pV.get_row_size() })); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector ngk_temp_1(s.nks, pc.get_row_size()); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true); Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); @@ -187,7 +195,10 @@ TEST_F(AXTest, ComplexParallel) } // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos); + + + std::vector ngk_temp_2(s.nks, s.naos); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data>(), V_full.at(isk).data>(), false, s.naos, s.naos); diff --git a/source/module_lr/dm_trans/test/dm_trans_test.cpp b/source/module_lr/dm_trans/test/dm_trans_test.cpp index b78bbb50bc..06e412155f 100644 --- a/source/module_lr/dm_trans/test/dm_trans_test.cpp +++ b/source/module_lr/dm_trans/test/dm_trans_test.cpp @@ -66,7 +66,9 @@ TEST_F(DMTransTest, DoubleSerial) for (int istate = 0;istate < nstate;++istate) { int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector temp(s.nks, s.naos); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); set_rand(c.get_pointer(), size_c); X.fix_b(istate); const std::vector& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt); @@ -85,7 +87,9 @@ TEST_F(DMTransTest, ComplexSerial) for (int istate = 0;istate < nstate;++istate) { int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos); + + std::vector temp(s.nks, s.naos); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); set_rand(c.get_pointer(), size_c); X.fix_b(istate); const std::vector& dm_for = LR::cal_dm_trans_forloop_serial(X.get_pointer(), c, s.nocc, s.nvirt); @@ -105,10 +109,14 @@ TEST_F(DMTransTest, DoubleParallel) // X: nvirt*nocc in para2d, nocc*nvirt in psi (row-para and constructed: nvirt) Parallel_2D px; LR_Util::setup_2d_division(px, s.nb, s.nvirt, s.nocc); - psi::Psi X(s.nks, nstate, px.get_local_size(), nullptr, false); + + std::vector temp_1(s.nks, px.get_local_size()); + psi::Psi X(s.nks, nstate, px.get_local_size(), temp_1.data(), false); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector temp_2(s.nks, pc.get_row_size()); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt); @@ -147,7 +155,8 @@ TEST_F(DMTransTest, DoubleParallel) LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data(), dm_gather[isk].data(), false, s.naos, s.naos); // compare to global matrix - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos); + std::vector temp(s.nks, s.naos); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); @@ -173,7 +182,9 @@ TEST_F(DMTransTest, ComplexParallel) psi::Psi> X(s.nks, nstate, px.get_local_size(), nullptr, false); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px.blacs_ctxt); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size()); + + std::vector temp(s.nks, pc.get_row_size()); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px.blacs_ctxt); @@ -206,7 +217,8 @@ TEST_F(DMTransTest, ComplexParallel) LR_Util::gather_2d_to_full(pmat, dm_pblas_loc[isk].data>(), dm_gather[isk].data>(), false, s.naos, s.naos); // compare to global matrix - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos); + std::vector ngk_temp_2(s.nks, s.naos); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); diff --git a/source/module_lr/esolver_lrtd_lcao.cpp b/source/module_lr/esolver_lrtd_lcao.cpp index 97842897b9..5fdbca8f94 100644 --- a/source/module_lr/esolver_lrtd_lcao.cpp +++ b/source/module_lr/esolver_lrtd_lcao.cpp @@ -195,7 +195,11 @@ LR::ESolver_LR::ESolver_LR(ModuleESolver::ESolver_KS_LCAO&& ks_sol if (this->nbands == PARAM.inp.nbands) { move_gs(); } else // copy the part of ground state info according to paraC_ { - this->psi_ks = new psi::Psi(this->kv.get_nks(), this->paraC_.get_col_size(), this->paraC_.get_row_size()); + this->psi_ks = new psi::Psi(this->kv.get_nks(), + this->paraC_.get_col_size(), + this->paraC_.get_row_size(), + this->kv.ngk, + true); this->eig_ks.create(this->kv.get_nks(), this->nbands); const int start_band = this->nocc_max - *std::max_element(nocc.begin(), nocc.end()); for (int ik = 0;ik < this->kv.get_nks();++ik) @@ -309,8 +313,10 @@ LR::ESolver_LR::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu // now ModuleIO::read_wfc_nao needs `Parallel_Orbitals` and can only read all the bands // it need improvement to read only the bands needed this->psi_ks = new psi::Psi(this->kv.get_nks(), - this->paraMat_.ncol_bands, - this->paraMat_.get_row_size()); + this->paraMat_.ncol_bands, + this->paraMat_.get_row_size(), + this->kv.ngk, + true); this->read_ks_wfc(); if (nspin == 2) { diff --git a/source/module_lr/hamilt_casida.cpp b/source/module_lr/hamilt_casida.cpp index a29429d5be..5d7958295b 100644 --- a/source/module_lr/hamilt_casida.cpp +++ b/source/module_lr/hamilt_casida.cpp @@ -12,19 +12,23 @@ namespace LR const int ldim = nk * px.get_local_size(); int npairs = no * nv; std::vector Amat_full(this->nk * npairs * this->nk * npairs, 0.0); - for (int ik = 0;ik < this->nk;++ik) - for (int j = 0;j < no;++j) + for (int ik = 0;ik < this->nk;++ik) { + for (int j = 0;j < no;++j) { for (int b = 0;b < nv;++b) {//calculate A^{ai} for each bj int bj = j * nv + b; //global int kbj = ik * npairs + bj; //global - psi::Psi X_bj(1, 1, this->nk * px.get_local_size()); // k1-first, like in iterative solver + psi::Psi X_bj(1, 1, this->nk * px.get_local_size(), this->nk * px.get_local_size(), true); // k1-first, like in iterative solver X_bj.zero_out(); // X_bj(0, 0, lj * px.get_row_size() + lb) = this->one(); int lj = px.global2local_col(j); int lb = px.global2local_row(b); if (px.in_this_processor(b, j)) { X_bj(0, 0, ik * px.get_local_size() + lj * px.get_row_size() + lb) = this->one(); } - psi::Psi A_aibj(1, 1, this->nk * px.get_local_size()); // k1-first + psi::Psi A_aibj(1, + 1, + this->nk * px.get_local_size(), + this->nk * px.get_local_size(), + true); // k1-first A_aibj.zero_out(); this->cal_dm_trans(0, X_bj.get_pointer()); @@ -37,12 +41,15 @@ namespace LR // reduce ai for a fixed bj A_aibj.fix_kb(0, 0); #ifdef __MPI - for (int ik_ai = 0;ik_ai < this->nk;++ik_ai) + for (int ik_ai = 0;ik_ai < this->nk;++ik_ai) { LR_Util::gather_2d_to_full(px, &A_aibj.get_pointer()[ik_ai * px.get_local_size()], Amat_full.data() + kbj * this->nk * npairs /*col, bj*/ + ik_ai * npairs/*row, ai*/, false, nv, no); +} #endif } +} +} // output Amat std::cout << "Full A matrix: (elements < 1e-10 is set to 0)" << std::endl; LR_Util::print_value(Amat_full.data(), nk * npairs, nk * npairs); diff --git a/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp b/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp index 71b4fdcbd7..7ae95a8918 100644 --- a/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp +++ b/source/module_lr/ri_benchmark/test/ri_benchmark_test.cpp @@ -23,7 +23,7 @@ UnitCell::~UnitCell() { TEST(RI_Benchmark, SlicePsi) { const int nk = 1, nbands = 2, nbasis = 3; - psi::Psi psi(nk, nbands, nbasis); + psi::Psi psi(nk, nbands, nbasis, nbasis, true); for (int i = 0; i < nk * nbands * nbasis; i++) { psi.get_pointer()[i] = i; } @@ -50,7 +50,7 @@ TEST(RI_Benchmark, CalCsMO) for (int i = 0;i < nabf * nao * nao;++i) { Cs_ao[0][{0, { 0, 0, 0 }}].ptr()[i] = static_cast(i); } const UnitCell ucell; - psi::Psi psi_ks(1, 2, 2); + psi::Psi psi_ks(1, 2, 2, 2, true); for (int i = 0;i < 4;++i) { psi_ks.get_pointer()[i] = static_cast(i); } RI_Benchmark::TLRI Cs_a_mo = RI_Benchmark::cal_Cs_mo(ucell, Cs_ao, psi_ks, nocc, nvirt, false); std::vector Cs_a_mo_ref = { 11,31 }; diff --git a/source/module_lr/utils/lr_util.hpp b/source/module_lr/utils/lr_util.hpp index 8fdc1b9b96..5bbedf645f 100644 --- a/source/module_lr/utils/lr_util.hpp +++ b/source/module_lr/utils/lr_util.hpp @@ -99,7 +99,11 @@ namespace LR_Util template psi::Psi get_psi_spin(const psi::Psi& psi_in, const int& is, const int& nk) { - return psi::Psi(&psi_in(is * nk, 0, 0), psi_in, nk, psi_in.get_nbands()); + return psi::Psi(&psi_in(is * nk, 0, 0), + nk, + psi_in.get_nbands(), + psi_in.get_nbasis(), + true); } /// psi(nk=1, nbands=nb, nk * nbasis) -> psi(nb, nk, nbasis) without memory copy @@ -111,7 +115,12 @@ namespace LR_Util int ib_now = psi_kfirst.get_current_b(); psi_kfirst.fix_b(0); // for get_pointer() to get the head pointer - psi::Psi psi_bfirst(psi_kfirst.get_pointer(), nk_in, psi_kfirst.get_nbands(), nbasis_in, false); + psi::Psi psi_bfirst(psi_kfirst.get_pointer(), + nk_in, + psi_kfirst.get_nbands(), + nbasis_in, + nbasis_in, + false); psi_kfirst.fix_b(ib_now); return psi_bfirst; } @@ -124,7 +133,12 @@ namespace LR_Util int ik_now = psi_bfirst.get_current_k(); psi_bfirst.fix_kb(0, 0); // for get_pointer() to get the head pointer - psi::Psi psi_kfirst(psi_bfirst.get_pointer(), 1, psi_bfirst.get_nbands(), psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), true); + psi::Psi psi_kfirst(psi_bfirst.get_pointer(), + 1, + psi_bfirst.get_nbands(), + psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), + psi_bfirst.get_nk() * psi_bfirst.get_nbasis(), + true); psi_bfirst.fix_kb(ik_now, ib_now); return psi_kfirst; } diff --git a/source/module_lr/utils/test/lr_util_algorithms_test.cpp b/source/module_lr/utils/test/lr_util_algorithms_test.cpp index f33105d32b..3318e526cb 100644 --- a/source/module_lr/utils/test/lr_util_algorithms_test.cpp +++ b/source/module_lr/utils/test/lr_util_algorithms_test.cpp @@ -9,7 +9,7 @@ TEST(LR_Util, PsiWrapper) int nbands = 5; int nbasis = 6; - psi::Psi k1(1, nbands, nk * nbasis); + psi::Psi k1(1, nbands, nk * nbasis, nk * nbasis, true); for (int i = 0;i < nbands * nk * nbasis;++i)k1.get_pointer()[i] = i; k1.fix_b(2); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index fb8abc78cd..7942b412c9 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -28,11 +28,11 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_ range_2 = range_2_in; } +// Constructor 0: basic template Psi::Psi() { this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); } template @@ -44,16 +44,32 @@ Psi::~Psi() } } +// Constructor 1-1: template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { + assert(nk_in > 0); + assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU + assert(nbs_in > 0); + this->k_first = k_first_in; - this->ngk = ngk_in; + this->npol = PARAM.globalv.npol; + this->allocate_inside = true; + + this->ngk = ngk_in; // modify later + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); + + this->nk = nk_in; + this->nbands = nbd_in; + this->nbasis = nbs_in; + this->current_b = 0; this->current_k = 0; - this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); - this->resize(nk_in, nbd_in, nbs_in); + this->current_nbasis = nbs_in; + this->psi_current = this->psi; + this->psi_bias = 0; + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); base_device::information::record_device_memory(this->ctx, @@ -62,102 +78,119 @@ Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i sizeof(T) * nk_in * nbd_in * nbs_in); } -// Constructor 8-1: +// Constructor 1-2: template -Psi::Psi(T* psi_pointer, - const int nk_in, +Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, - const std::vector& ngk_vector_in, + const std::vector& ngk_in, const bool k_first_in) { + assert(nk_in > 0); + assert(nbd_in > 0); + assert(nbs_in > 0); + this->k_first = k_first_in; - this->ngk = ngk_vector_in.data(); - this->current_b = 0; - this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); + this->allocate_inside = true; + + this->ngk = ngk_in.data(); // modify later + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); + this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; + + this->current_b = 0; + this->current_k = 0; this->current_nbasis = nbs_in; - this->psi_current = this->psi = psi_pointer; - this->allocate_inside = false; + this->psi_current = this->psi; + this->psi_bias = 0; + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); + base_device::information::record_device_memory(this->ctx, + GlobalV::ofs_device, + "Psi->resize()", + sizeof(T) * nk_in * nbd_in * nbs_in); } -// Constructor 8-2: +// Constructor 3-1: 2D Psi version template -Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const bool k_first_in) +Psi::Psi(T* psi_pointer, + const int nk_in, + const int nbd_in, + const int nbs_in, + const int current_nbasis_in, + const bool k_first_in) { + // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. + // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func + this->k_first = k_first_in; - this->ngk = nullptr; - this->current_b = 0; - this->current_k = 0; this->npol = PARAM.globalv.npol; - this->device = base_device::get_device_type(this->ctx); + this->allocate_inside = false; + + this->ngk = nullptr; + this->psi = psi_pointer; + this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; - this->psi_current = this->psi = psi_pointer; - this->allocate_inside = false; + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = current_nbasis_in; + this->psi_current = psi_pointer; + this->psi_bias = 0; + // Currently only GPU's implementation is supported for device recording! base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); } +// Constructor 3-2: 2D Psi version template -Psi::Psi(const Psi& psi_in, const int nk_in, int nband_in) +Psi::Psi(const int nk_in, + const int nbd_in, + const int nbs_in, + const int current_nbasis_in, + const bool k_first_in) { - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } - this->k_first = psi_in.get_k_first(); - this->device = psi_in.device; - this->resize(nk_in, nband_in, psi_in.get_nbasis()); - this->ngk = psi_in.ngk; - this->npol = psi_in.npol; - if (nband_in <= psi_in.get_nbands()) - { - // copy from Psi from psi_in(current_k, 0, 0), - // if size of k is 1, current_k in new Psi is psi_in.current_k - if (nk_in == 1) - { - // current_k for this Psi only keep the spin index same as the copied Psi - this->current_k = psi_in.get_current_k(); - } - synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size()); - } -} + // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. + assert(nk_in == 1); + + this->k_first = k_first_in; + this->npol = PARAM.globalv.npol; + this->allocate_inside = true; + + this->ngk = nullptr; + assert(nk_in > 0 && nbd_in >= 0 && nbs_in > 0); + resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); -template -Psi::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) -{ - this->k_first = psi_in.get_k_first(); - this->device = base_device::get_device_type(this->ctx); - assert(this->device == psi_in.device); - assert(nk_in <= psi_in.get_nk()); - if (nband_in == 0) - { - nband_in = psi_in.get_nbands(); - } - this->ngk = psi_in.ngk; - this->npol = psi_in.npol; this->nk = nk_in; - this->nbands = nband_in; - this->nbasis = psi_in.nbasis; - this->psi_current = psi_pointer; - this->allocate_inside = false; - this->psi = psi_pointer; + this->nbands = nbd_in; + this->nbasis = nbs_in; + + this->current_k = 0; + this->current_b = 0; + this->current_nbasis = current_nbasis_in; + this->psi_current = this->psi; + this->psi_bias = 0; + + // Currently only GPU's implementation is supported for device recording! + base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); + base_device::information::record_device_memory(this->ctx, + GlobalV::ofs_device, + "Psi->resize()", + sizeof(T) * nk_in * nbd_in * nbs_in); } +// Constructor 2-1: template Psi::Psi(const Psi& psi_in) { - this->ngk = psi_in.get_ngk_pointer(); + this->ngk = psi_in.ngk; this->npol = psi_in.npol; this->nk = psi_in.get_nk(); this->nbands = psi_in.get_nbands(); @@ -166,7 +199,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); base_device::memory::synchronize_memory_op()(this->ctx, psi_in.get_device(), @@ -178,6 +211,8 @@ Psi::Psi(const Psi& psi_in) this->psi_current = this->psi + psi_in.get_psi_bias(); } + +// Constructor 2-2: template template Psi::Psi(const Psi& psi_in) @@ -191,7 +226,7 @@ Psi::Psi(const Psi& psi_in) this->current_b = psi_in.get_current_b(); this->k_first = psi_in.get_k_first(); // this function will copy psi_in.psi to this->psi no matter the device types of each other. - this->device = base_device::get_device_type(this->ctx); + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); // Specifically, if the Device_in type is CPU and the Device type is GPU. @@ -230,12 +265,23 @@ Psi::Psi(const Psi& psi_in) this->psi_current = this->psi + psi_in.get_psi_bias(); } +template +void Psi::set_all_psi(const T* another_pointer, const std::size_t size_in) +{ + assert(size_in == this->size()); + synchronize_memory_op()(this->ctx, this->ctx, this->psi, another_pointer, this->size()); +} + template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0); + // This function will delete the psi array first(if psi exist), then malloc a new memory for it. resize_memory_op()(this->ctx, this->psi, nks_in * static_cast(nbands_in) * nbasis_in, "no_record"); + + // this->zero_out(); + this->nk = nks_in; this->nbands = nbands_in; this->nbasis = nbasis_in; @@ -258,12 +304,6 @@ T* Psi::get_pointer(const int& ikb) const return this->psi_current + ikb * this->nbasis; } -template -const int* Psi::get_ngk_pointer() const -{ - return this->ngk; -} - template const bool& Psi::get_k_first() const { @@ -276,12 +316,31 @@ const Device* Psi::get_device() const return this->ctx; } +template +const int* Psi::get_ngk_pointer() const +{ + return this->ngk; +} + template const int& Psi::get_psi_bias() const { return this->psi_bias; } +template +const int& Psi::get_current_ngk() const +{ + if (this->npol == 1) + { + return this->current_nbasis; + } + else + { + return this->nbasis; + } +} + template const int& Psi::get_nk() const { @@ -315,7 +374,7 @@ void Psi::fix_k(const int ik) const { assert(ik >= 0); this->current_k = ik; - if (this->ngk != nullptr && this->npol != 2) + if (this->ngk != nullptr) { this->current_nbasis = this->ngk[ik]; } @@ -429,10 +488,7 @@ int Psi::get_current_nbas() const template const int& Psi::get_ngk(const int ik_in) const { - if (!this->ngk) - { - return this->nbasis; - } + assert(this->ngk != nullptr); return this->ngk[ik_in]; } diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 6b374c8a70..d8a994377a 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -36,53 +36,53 @@ template class Psi { public: - // Constructor 1: basic + // Constructor 0: basic Psi(); - // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true); + // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); - // Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in - Psi(const Psi& psi_in, const int nk_in, int nband_in = 0); + // Constructor 1-2: + Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); - // Constructor 5: a wrapper of a data pointer, used for Operator::hPsi() - // in this case, fix_k can not be used - Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0); - - // Constructor 6: initialize a new psi from the given psi_in + // Constructor 2-1: initialize a new psi from the given psi_in Psi(const Psi& psi_in); - // Constructor 7: initialize a new psi from the given psi_in with a different class template + // Constructor 2-2: initialize a new psi from the given psi_in with a different class template // in this case, psi_in may have a different device type. template Psi(const Psi& psi_in); - // Constructor 8-1: a pointer version of constructor 3 - // only used in hsolver-pw function pointer. + // Constructor 3-1: 2D Psi version + // used in hsolver-pw function pointer and somewhere. Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, - const std::vector& ngk_vector_in, + const int current_nbasis_in, const bool k_first_in = true); - // Constructor 8-2: a pointer version of constructor 3 - // only used in operator.cpp call_act func - Psi(T* psi_pointer, - const int nk_in, - const int nbd_in, - const int nbs_in, - const bool k_first_in); + // Constructor 3-2: 2D Psi version + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in); - // Destructor for deleting the psi array manually ~Psi(); + // set psi value func 1 + void set_all_psi(const T* another_pointer, const std::size_t size_in); + + // set psi value func 2 + void zero_out(); + + // size_t size() const {return this->psi.size();} + size_t size() const; + // allocate psi for three dimensions void resize(const int nks_in, const int nbands_in, const int nbasis_in); // get the pointer for the 1st index T* get_pointer() const; + // get the pointer for the 2nd index (iband for k_first = true, ik for k_first = false) T* get_pointer(const int& ikb) const; @@ -90,8 +90,6 @@ class Psi const int& get_nk() const; const int& get_nbands() const; const int& get_nbasis() const; - // size_t size() const {return this->psi.size();} - size_t size() const; /// if k_first=true: choose k-point index , then Psi(iband, ibasis) can reach Psi(ik, iband, ibasis) /// if k_first=false: choose k-point index, then Psi(ibasis) can reach Psi(iband, ik, ibasis) @@ -122,27 +120,29 @@ class Psi int get_current_nbas() const; const int& get_ngk(const int ik_in) const; - // return ngk array of psi + const int* get_ngk_pointer() const; + // return k_first const bool& get_k_first() const; + // return device type of psi const Device* get_device() const; + // return psi_bias const int& get_psi_bias() const; - // mark - void zero_out(); + const int& get_current_ngk() const; // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; + int npol = 1; private: T* psi = nullptr; // avoid using C++ STL - base_device::AbacusDevice_t device = {}; // track the device type (CPU, GPU and SYCL are supported currented) - Device* ctx = {}; // an context identifier for obtaining the device variable + Device* ctx = {}; // an context identifier for obtaining the device variable // dimensions int nk = 1; // number of k points diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index df22b5f885..0b42df63c7 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -14,9 +14,6 @@ class TestPsi : public ::testing::Test const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - - psi::Psi>* psi_object4 = new psi::Psi>(*psi_object31, ink, 0); - psi::Psi>* psi_object5 = new psi::Psi>(psi_object31->get_pointer(), *psi_object31, ink, 0); }; TEST_F(TestPsi, get_val) @@ -63,26 +60,6 @@ TEST_F(TestPsi, get_val) EXPECT_EQ(psi_object14->get_psi_bias(), 0); } -// TEST_F(TestPsi, get_ngk) -// { -// psi::Psi>* psi_object21 = new psi::Psi>(&ngk[0]); -// psi::Psi* psi_object22 = new psi::Psi(&ngk[0]); -// psi::Psi>* psi_object23 = new psi::Psi>(&ngk[0]); -// psi::Psi* psi_object24 = new psi::Psi(&ngk[0]); - -// EXPECT_EQ(psi_object21->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object21->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object22->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object22->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object23->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object23->get_ngk_pointer()[0], ngk[0]); - -// EXPECT_EQ(psi_object24->get_ngk(2), ngk[2]); -// EXPECT_EQ(psi_object24->get_ngk_pointer()[0], ngk[0]); -// } - TEST_F(TestPsi, get_pointer_op_zero_complex_double) { for (int i = 0; i < ink; i++) @@ -119,7 +96,9 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double) // cover all lines in fix_k func psi_object31->fix_k(2); EXPECT_EQ(psi_object31->get_psi_bias(), 0); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis); + + std::vector temp(ink, inbasis); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp.data(), true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; @@ -331,30 +310,6 @@ TEST_F(TestPsi, band_first) EXPECT_EQ(std::get<0>(psi_band_32->to_range(illegal_range1)), nullptr); EXPECT_EQ(std::get<1>(psi_band_32->to_range(illegal_range2)), 0); - // pointer constructor - // band-first to k-first - // psi::Psi psi_band_32_k(psi_band_32->get_pointer(), psi_band_32->get_nk(), psi_band_32->get_nbands(), psi_band_32->get_nbasis(), psi_band_32->get_ngk_pointer(), true); - // k-first to band-first - // psi::Psi psi_band_32_b(psi_band_32_k.get_pointer(), psi_band_32_k.get_nk(), psi_band_32_k.get_nbands(), psi_band_32_k.get_nbasis(), psi_band_32_k.get_ngk_pointer(), false); - // EXPECT_EQ(psi_band_32_k.get_nk(), ink); - // EXPECT_EQ(psi_band_32_k.get_nbands(), inbands); - // EXPECT_EQ(psi_band_32_k.get_nbasis(), inbasis); - // EXPECT_EQ(psi_band_32_b.get_nk(), ink); - // EXPECT_EQ(psi_band_32_b.get_nbands(), inbands); - // EXPECT_EQ(psi_band_32_b.get_nbasis(), inbasis); - // for (int ik = 0;ik < ink;++ik) - // { - // for (int ib = 0;ib < inbands;++ib) - // { - // psi_band_32->fix_kb(ik, ib); - // psi_band_32_k.fix_kb(ik, ib); - // psi_band_32_b.fix_kb(ik, ib); - // EXPECT_EQ(psi_band_32->get_psi_bias(), (ib * ink + ik) * inbasis); - // EXPECT_EQ(psi_band_32_k.get_psi_bias(), (ik * inbands + ib) * inbasis); - // EXPECT_EQ(psi_band_32_b.get_psi_bias(), (ib * ink + ik) * inbasis); - // } - // } - delete psi_band_c64; delete psi_band_64; delete psi_band_c32; diff --git a/source/module_ri/exx_lip.hpp b/source/module_ri/exx_lip.hpp index 06456b6f80..bd94316027 100644 --- a/source/module_ri/exx_lip.hpp +++ b/source/module_ri/exx_lip.hpp @@ -114,7 +114,7 @@ Exx_Lip::Exx_Lip(const Exx_Info::Exx_Info_Lip& info_in, #endif this->k_pack->wf_wg.create(this->k_pack->kv_ptr->get_nks(),PARAM.inp.nbands); - this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal); + this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk.data(), true); // this->k_pack->hvec_array = new ModuleBase::ComplexMatrix[this->k_pack->kv_ptr->get_nks()]; // for( int ik=0; ikk_pack->kv_ptr->get_nks(); ++ik) // {