Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Remove global dependence of descriptor, orbital_precalc, v_delta_precalc in DeePKS. #5812

Merged
merged 13 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,21 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
deepks_vdpre.o\
deepks_hmat.o\
LCAO_deepks_io.o\
LCAO_deepks_mpi.o\
LCAO_deepks_pdm.o\
LCAO_deepks_phialpha.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
deepks_hmat.o\
LCAO_deepks_interface.o\
deepks_orbpre.o\
cal_gdmx.o\
cal_gdmepsl.o\
cal_gedm.o\
cal_gvx.o\
cal_descriptor.o\
v_delta_precalc.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
20 changes: 10 additions & 10 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ESolver_KS : public ESolver_FP
virtual void after_scf(UnitCell& ucell, const int istep) override;

//! <Temporary> It should be replaced by a function in Hamilt Class
virtual void update_pot(UnitCell& ucell, const int istep, const int iter) {};
virtual void update_pot(UnitCell& ucell, const int istep, const int iter){};

//! Hamiltonian
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
Expand All @@ -72,7 +72,7 @@ class ESolver_KS : public ESolver_FP
//! Electronic wavefunctions
psi::Psi<T>* psi = nullptr;

//! plane wave or LCAO
//! plane wave or LCAO
std::string basisname;

//! number of electrons
Expand All @@ -83,18 +83,18 @@ class ESolver_KS : public ESolver_FP

//! the start time of scf iteration
#ifdef __MPI
double iter_time;
double iter_time;
#else
std::chrono::system_clock::time_point iter_time;
#endif

double diag_ethr; //! the threshold for diagonalization
double scf_thr; //! scf density threshold
double scf_ene_thr; //! scf energy threshold
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
double hsolver_error; //! the error of HSolver
int maxniter; //! maximum iter steps for scf
int niter; //! iter steps actually used in scf
double diag_ethr; //! the threshold for diagonalization
double scf_thr; //! scf density threshold
double scf_ene_thr; //! scf energy threshold
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
double hsolver_error; //! the error of HSolver
int maxniter; //! maximum iter steps for scf
int niter; //! iter steps actually used in scf
};
} // namespace ModuleESolver
#endif
96 changes: 56 additions & 40 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,14 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,

if (!PARAM.inp.deepks_equiv) // training with force label not supported by equivariant version now
{
torch::Tensor gdmx;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
GlobalC::ld.cal_gdmx(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
isstress);

GlobalC::ld
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
else
{
Expand All @@ -533,25 +529,25 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, isstress);
GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmx(ucell.nat);
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
}
std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
GlobalC::ld.cal_gvx(ucell.nat, gevdm);
torch::Tensor gvx;
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gvx(ucell.nat);
GlobalC::ld.check_gvx(ucell.nat, gvx);
}

LCAO_deepks_io::save_npy_gvx(ucell.nat,
GlobalC::ld.des_per_atom,
GlobalC::ld.gvx_tensor,
gvx,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK);
}
Expand Down Expand Up @@ -715,6 +711,12 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
{
scs(i, j) += stress_exx(i, j);
}
#endif
#ifdef __DEEPKS
if (PARAM.inp.deepks_scf)
{
scs(i, j) += svnl_dalpha(i, j);
}
#endif
}
}
Expand All @@ -726,47 +728,61 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
#ifdef __DEEPKS
if (PARAM.inp.deepks_out_labels) // not parallelized yet
{
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
const std::string file_stot = PARAM.globalv.global_out_dir + "deepks_stot.npy";
LCAO_deepks_io::save_npy_s(scs,
file_s,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
}
if (PARAM.inp.deepks_scf)
{
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
symm->symmetrize_mat3(svnl_dalpha, ucell.lat);
} // end symmetry
for (int i = 0; i < 3; i++)
{
for (int j = 0; j < 3; j++)
{
scs(i, j) += svnl_dalpha(i, j);
}
}
}
if (PARAM.inp.deepks_out_labels) // not parallelized yet
{
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_stot.npy";
LCAO_deepks_io::save_npy_s(scs,
file_s,
file_stot,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model

// wenfei add 2021/11/2
if (PARAM.inp.deepks_scf)
{
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
file_sbase,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;

if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
{
torch::Tensor gdmepsl;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld.cal_gdmepsl(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
gdmepsl);
}
else
{
const std::vector<std::vector<std::complex<double>>>& dm_k
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmepsl(gdmepsl);
}

std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm);
torch::Tensor gvepsl;
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);

LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
GlobalC::ld.des_per_atom,
GlobalC::ld.gvepsl_tensor,
gvepsl,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK); // unitless, grad_vepsl
}
Expand Down
17 changes: 14 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,19 @@ void Force_LCAO<double>::ftable(const bool isforce,

#ifdef __DEEPKS
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
std::vector<torch::Tensor> descriptor;
if (PARAM.inp.deepks_scf)
{
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);
GlobalC::ld.cal_gedm(ucell.nat);
DeePKS_domain::cal_descriptor(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(ucell.nat, descriptor);

const int nks = 1;
DeePKS_domain::cal_f_delta<double>(dm_gamma,
Expand Down Expand Up @@ -305,7 +311,12 @@ void Force_LCAO<double>::ftable(const bool isforce,

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
ucell,
PARAM.globalv.global_out_dir,
descriptor);

GlobalC::ld.check_gedm();

Expand Down
11 changes: 8 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,14 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);

GlobalC::ld.cal_gedm(ucell.nat);
std::vector<torch::Tensor> descriptor;
DeePKS_domain::cal_descriptor(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(ucell.nat, descriptor);

DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
ucell,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
this->H_V_delta = new HContainer<TR>(paraV);
if (std::is_same<TK, double>::value)
{
//this->H_V_delta = new HContainer<TR>(paraV);
// this->H_V_delta = new HContainer<TR>(paraV);
this->H_V_delta->fix_gamma();
}

Expand Down Expand Up @@ -138,8 +138,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
// if (std::is_same<TK, double>::value)
// {
this->H_V_delta->allocate(nullptr, true);
// expand hR with H_V_delta
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
// expand hR with H_V_delta
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
this->hR->add(*this->H_V_delta);
this->hR->allocate(nullptr, false);
// }
Expand All @@ -161,8 +161,15 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
ModuleBase::timer::tick("DeePKS", "contributeHR");

GlobalC::ld.cal_projected_DM<TK>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
GlobalC::ld.cal_descriptor(this->ucell->nat);
GlobalC::ld.cal_gedm(this->ucell->nat);

std::vector<torch::Tensor> descriptor;
DeePKS_domain::cal_descriptor(this->ucell->nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(this->ucell->nat, descriptor);

// // recalculate the H_V_delta
// if (this->H_V_delta == nullptr)
Expand Down
10 changes: 5 additions & 5 deletions source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
if(ENABLE_DEEPKS)
list(APPEND objects
LCAO_deepks.cpp
deepks_descriptor.cpp
deepks_force.cpp
deepks_orbital.cpp
deepks_orbpre.cpp
deepks_vdpre.cpp
deepks_hmat.cpp
LCAO_deepks_io.cpp
LCAO_deepks_mpi.cpp
LCAO_deepks_pdm.cpp
LCAO_deepks_phialpha.cpp
LCAO_deepks_torch.cpp
LCAO_deepks_vdelta.cpp
deepks_hmat.cpp
LCAO_deepks_interface.cpp
deepks_orbpre.cpp
cal_gdmx.cpp
cal_gdmepsl.cpp
cal_gedm.cpp
cal_gvx.cpp
cal_descriptor.cpp
v_delta_precalc.cpp
)

add_library(
Expand Down
Loading
Loading