Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: refactor the constructors of Psi class #5761

Merged
merged 52 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
84e63b6
remove Psi(const Psi& psi_in, const int nk_in, int nband_in);
haozhihan Dec 25, 2024
b15cd5c
fix bug
haozhihan Dec 25, 2024
9900bb7
fix bug
haozhihan Dec 25, 2024
a02b5d8
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 25, 2024
8e3a58f
remove device value in psi
haozhihan Dec 25, 2024
c716bb7
update Psi(const Psi& psi_in, const int nk_in, int nband_in)
haozhihan Dec 25, 2024
1c2f523
update get_ngk usage
haozhihan Dec 25, 2024
1fb8851
fix bug about ngk
haozhihan Dec 25, 2024
1a9aea9
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 25, 2024
9a3a9f0
fix bug
haozhihan Dec 26, 2024
093c3f2
format operator
haozhihan Dec 26, 2024
af1b7bc
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 26, 2024
16687c3
fix bug
haozhihan Dec 26, 2024
35d26d6
fix bug
haozhihan Dec 26, 2024
3096085
fix bug
haozhihan Dec 26, 2024
a3817e4
fix bug
haozhihan Dec 26, 2024
67fda40
add get_cur_effective_basis func
haozhihan Dec 27, 2024
0b0604c
fix bug
haozhihan Dec 27, 2024
d5634b3
update get_cur_effective_basis
haozhihan Dec 27, 2024
0339ba3
check bugs
haozhihan Dec 27, 2024
190e74a
update Constructor 8-1
haozhihan Dec 27, 2024
cf0cad7
fix bug
haozhihan Dec 28, 2024
30b1aa4
fix bug
haozhihan Dec 28, 2024
fcc167f
fix bug
haozhihan Dec 28, 2024
8cf1363
fix bug maybe
haozhihan Dec 29, 2024
7e09971
fix bug
haozhihan Dec 29, 2024
6ff1b3a
check correct
haozhihan Dec 29, 2024
7e44e9f
check 1
haozhihan Dec 29, 2024
f4f958c
fix unit test
haozhihan Dec 29, 2024
444f21e
fix unit bug
haozhihan Dec 29, 2024
588a335
update get_ngk func
haozhihan Dec 29, 2024
76893ee
remove get-ngk in velocity-pw
haozhihan Dec 30, 2024
462857f
fix bug
haozhihan Dec 30, 2024
7f66c7d
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 30, 2024
ed7387e
fix 186_PW_SKG_ALL bug
haozhihan Dec 30, 2024
0906e22
format source/module_io/unk_overlap_pw.cpp
haozhihan Dec 30, 2024
c2cb0df
update Constructor in psi
haozhihan Dec 30, 2024
5a86f45
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 30, 2024
8935299
debug unit test
haozhihan Dec 30, 2024
bfdfc92
fix ri test bug
haozhihan Dec 30, 2024
c55e20d
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Dec 30, 2024
49987f8
fix psi-ut bug
haozhihan Dec 31, 2024
2def09e
remove Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const i…
haozhihan Dec 31, 2024
d23dfe5
remove useless code
haozhihan Dec 31, 2024
24adafe
update Psi(const Psi& psi_in, const int nk_in, const int nband_in);
haozhihan Dec 31, 2024
a920924
remove Psi(const Psi& psi_in, const int nk_in, const int nband_in);
haozhihan Dec 31, 2024
b9d0160
refactor psi code
haozhihan Dec 31, 2024
5373bdc
fix sdft bug
haozhihan Dec 31, 2024
2807490
Merge branch 'develop' into psi-ngk
haozhihan Jan 2, 2025
cf09808
Merge branch 'develop' into psi-ngk
haozhihan Jan 3, 2025
0145476
change to get_current_ngk
haozhihan Jan 5, 2025
943ce71
Merge branch 'develop' into psi-ngk
haozhihan Jan 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions source/module_elecstate/cal_dm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
//dm.fix_k(ik);
dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<double> wg_wfc(wfc, 1);
psi::Psi<double> wg_wfc(1,
wfc.get_nbands(),
wfc.get_nbasis(),
wfc.get_nbasis(),
true);
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());

int ib_global = 0;
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
Expand All @@ -41,7 +46,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down Expand Up @@ -99,7 +105,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)

this->init_rho_data();
int ik = psi.get_current_k();
int npw = psi.get_current_nbas();
int npw = psi.get_cur_effective_basis();
haozhihan marked this conversation as resolved.
Show resolved Hide resolved
int current_spin = 0;
if (PARAM.inp.nspin == 2)
{
Expand Down Expand Up @@ -287,7 +287,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
psi.fix_k(ik);
const T* psi_now = psi.get_pointer();
const int currect_spin = this->klist->isk[ik];
const int npw = psi.get_current_nbas();
const int npw = psi.get_cur_effective_basis();

// get |beta>
if (this->ppcell->nkb > 0)
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_pw_cal_tau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
int npw = psi.get_current_nbas();
int npw = psi.get_cur_effective_basis();
int current_spin = 0;
if (PARAM.inp.nspin == 2)
{
Expand Down
16 changes: 14 additions & 2 deletions source/module_elecstate/module_dm/cal_dm_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
// dm.fix_k(ik);
// dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<double> wg_wfc(wfc, 1);

psi::Psi<double> wg_wfc(1,
wfc.get_nbands(),
wfc.get_nbasis(),
wfc.get_nbasis(),
true);
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());


int ib_global = 0;
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
Expand Down Expand Up @@ -89,7 +96,12 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
// dm.fix_k(ik);
//dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr);
psi::Psi<std::complex<double>> wg_wfc(1,
wfc.get_nbands(),
wfc.get_nbasis(),
wfc.get_nbasis(),
true);

const std::complex<double>* pwfc = wfc.get_pointer();
std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
#ifdef _OPENMP
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
//! initialize the gradients of Etotal with respect to occupation numbers and wfc,
//! and set all elements to 0.
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis(), this->kv.ngk, true);
dE_dWfc.zero_out();

double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc);
Expand Down
6 changes: 5 additions & 1 deletion source/module_esolver/esolver_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell)

// Refresh the arrays
delete this->psi_;
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
this->psi_ = new psi::Psi<double>(1,
PARAM.inp.nspin,
this->pw_rho->nrxx,
this->pw_rho->nrxx,
true);
for (int is = 0; is < PARAM.inp.nspin; ++is)
{
this->pphi_[is] = this->psi_->get_pointer(is);
Expand Down
6 changes: 5 additions & 1 deletion source/module_esolver/esolver_of_tool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ void ESolver_OF::init_elecstate(UnitCell& ucell)
void ESolver_OF::allocate_array()
{
// Initialize the "wavefunction", which is sqrt(rho)
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
this->psi_ = new psi::Psi<double>(1,
PARAM.inp.nspin,
this->pw_rho->nrxx,
this->pw_rho->nrxx,
true);
ModuleBase::Memory::record("OFDFT::Psi", sizeof(double) * PARAM.inp.nspin * this->pw_rho->nrxx);
this->pphi_ = new double*[PARAM.inp.nspin];
for (int is = 0; is < PARAM.inp.nspin; ++is)
Expand Down
129 changes: 81 additions & 48 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,31 @@

using namespace hamilt;


template<typename T, typename Device>
Operator<T, Device>::Operator(){}

template<typename T, typename Device>
Operator<T, Device>::~Operator()
template <typename T, typename Device>
Operator<T, Device>::Operator()
{
if(this->hpsi != nullptr) { delete this->hpsi;
}

template <typename T, typename Device>
Operator<T, Device>::~Operator()
{
if (this->hpsi != nullptr)
{
delete this->hpsi;
}
Operator* last = this->next_op;
Operator* last_sub = this->next_sub_op;
while(last != nullptr || last_sub != nullptr)
while (last != nullptr || last_sub != nullptr)
{
if(last_sub != nullptr)
{//delete sub_chain first
if (last_sub != nullptr)
{ // delete sub_chain first
Operator* node_delete = last_sub;
last_sub = last_sub->next_sub_op;
node_delete->next_sub_op = nullptr;
delete node_delete;
}
else
{//delete main chain if sub_chain is deleted
{ // delete main chain if sub_chain is deleted
Operator* node_delete = last;
last_sub = last->next_sub_op;
node_delete->next_sub_op = nullptr;
Expand All @@ -36,7 +39,7 @@ Operator<T, Device>::~Operator()
}
}

template<typename T, typename Device>
template <typename T, typename Device>
typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
{
using syncmem_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
Expand All @@ -46,37 +49,51 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp

T* tmhpsi = this->get_hpsi(input);
const T* tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
// if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if (tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
}
//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
// if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
T* hpsi_pointer = std::get<2>(input);
if (this->in_place)
{
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
1,
nbands / psi_input->npol,
psi_input->get_nbasis(),
psi_input->get_nbasis(),
true);
}

auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {

// a "psi" with the bands of needed range
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);


psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in),
1,
nbands,
psi_input->get_nbasis(),
psi_input->get_nbasis(),
true);

switch (op->get_act_type())
{
case 2:
op->act(psi_wrapper, *this->hpsi, nbands);
break;
default:
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
op->act(nbands,
psi_input->get_nbasis(),
psi_input->npol,
tmpsi_in,
this->hpsi->get_pointer(),
psi_input->get_current_nbas(),
is_first_node);
break;
}
};
};

ModuleBase::timer::tick("Operator", "hPsi");
call_act(this, true); // first node
Expand All @@ -91,39 +108,43 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}


template<typename T, typename Device>
void Operator<T, Device>::init(const int ik_in)
template <typename T, typename Device>
void Operator<T, Device>::init(const int ik_in)
{
this->ik = ik_in;
if(this->next_op != nullptr) {
if (this->next_op != nullptr)
{
this->next_op->init(ik_in);
}
}

template<typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
template <typename T, typename Device>
void Operator<T, Device>::add(Operator* next)
{
if(next==nullptr) { return;
}
if (next == nullptr)
{
return;
}
next->is_first_node = false;
if(next->next_op != nullptr) { this->add(next->next_op);
}
if (next->next_op != nullptr)
{
this->add(next->next_op);
}
Operator* last = this;
//loop to end of the chain
while(last->next_op != nullptr)
// loop to end of the chain
while (last->next_op != nullptr)
{
if(next->cal_type==last->cal_type)
if (next->cal_type == last->cal_type)
{
break;
}
last = last->next_op;
}
if(next->cal_type == last->cal_type)
if (next->cal_type == last->cal_type)
{
//insert next to sub chain of current node
// insert next to sub chain of current node
Operator* sub_last = last;
while(sub_last->next_sub_op != nullptr)
while (sub_last->next_sub_op != nullptr)
{
sub_last = sub_last->next_sub_op;
}
Expand All @@ -136,34 +157,45 @@ void Operator<T, Device>::add(Operator* next)
}
}

template<typename T, typename Device>
template <typename T, typename Device>
T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
{
const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1);
//in_place call of hPsi, hpsi inputs as new psi,
//create a new hpsi and delete old hpsi later
// in_place call of hPsi, hpsi inputs as new psi,
// create a new hpsi and delete old hpsi later
T* hpsi_pointer = std::get<2>(info);
const T* psi_pointer = std::get<0>(info)->get_pointer();
if(this->hpsi != nullptr)
if (this->hpsi != nullptr)
{
delete this->hpsi;
this->hpsi = nullptr;
}
if(!hpsi_pointer)
if (!hpsi_pointer)
{
ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr");
}
else if(hpsi_pointer == psi_pointer)
else if (hpsi_pointer == psi_pointer)
{
this->in_place = true;
this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
// this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
this->hpsi = new psi::Psi<T, Device>(1,
nbands_range,
std::get<0>(info)->get_nbasis(),
std::get<0>(info)->get_nbasis(),
true);
}
else
{
this->in_place = false;
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);

this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
1,
nbands_range,
std::get<0>(info)->get_nbasis(),
std::get<0>(info)->get_nbasis(),
true);
}

hpsi_pointer = this->hpsi->get_pointer();
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
Expand All @@ -172,7 +204,8 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
return hpsi_pointer;
}

namespace hamilt {
namespace hamilt
{
template class Operator<float, base_device::DEVICE_CPU>;
template class Operator<std::complex<float>, base_device::DEVICE_CPU>;
template class Operator<double, base_device::DEVICE_CPU>;
Expand All @@ -183,4 +216,4 @@ template class Operator<std::complex<float>, base_device::DEVICE_GPU>;
template class Operator<double, base_device::DEVICE_GPU>;
template class Operator<std::complex<double>, base_device::DEVICE_GPU>;
#endif
}
} // namespace hamilt
Loading
Loading