Skip to content

Commit

Permalink
Refactor: Remove global dependence of descriptor, orbital_precalc, v_…
Browse files Browse the repository at this point in the history
…delta_precalc in DeePKS. (#5812)

* Remove functions related to v_delta in LCAO_Deepks; Remove some redundent variables.

* Remove some temporary variables for force/stress calculation in DeePKS and separate force&stress calculations. Remove global dependence of descriptor.

* Use accessor to accelerate the manipulation of torch::Tensor variables in DeePKS.

* Remove LCAO_deepks_mpi.cpp.

* Update Unittest for DeePKS.

* Clang-format change.

* Update cal_gdmx and cal_gdmepsl.

* Fix check_gvx() bug when using mpirun.

* Move functions for calculating descriptor from LCAO_deepks to DeePKS_domain.

* Add UT for cal_gdmepsl and modify the ref files to suit the new data structure.
  • Loading branch information
ErjieWu authored Jan 5, 2025
1 parent 8905ddf commit e25db6e
Show file tree
Hide file tree
Showing 50 changed files with 3,767 additions and 2,500 deletions.
10 changes: 5 additions & 5 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,21 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
deepks_vdpre.o\
deepks_hmat.o\
LCAO_deepks_io.o\
LCAO_deepks_mpi.o\
LCAO_deepks_pdm.o\
LCAO_deepks_phialpha.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
deepks_hmat.o\
LCAO_deepks_interface.o\
deepks_orbpre.o\
cal_gdmx.o\
cal_gdmepsl.o\
cal_gedm.o\
cal_gvx.o\
cal_descriptor.o\
v_delta_precalc.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
20 changes: 10 additions & 10 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ESolver_KS : public ESolver_FP
virtual void after_scf(UnitCell& ucell, const int istep) override;

//! <Temporary> It should be replaced by a function in Hamilt Class
virtual void update_pot(UnitCell& ucell, const int istep, const int iter) {};
virtual void update_pot(UnitCell& ucell, const int istep, const int iter){};

//! Hamiltonian
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
Expand All @@ -72,7 +72,7 @@ class ESolver_KS : public ESolver_FP
//! Electronic wavefunctions
psi::Psi<T>* psi = nullptr;

//! plane wave or LCAO
//! plane wave or LCAO
std::string basisname;

//! number of electrons
Expand All @@ -83,18 +83,18 @@ class ESolver_KS : public ESolver_FP

//! the start time of scf iteration
#ifdef __MPI
double iter_time;
double iter_time;
#else
std::chrono::system_clock::time_point iter_time;
#endif

double diag_ethr; //! the threshold for diagonalization
double scf_thr; //! scf density threshold
double scf_ene_thr; //! scf energy threshold
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
double hsolver_error; //! the error of HSolver
int maxniter; //! maximum iter steps for scf
int niter; //! iter steps actually used in scf
double diag_ethr; //! the threshold for diagonalization
double scf_thr; //! scf density threshold
double scf_ene_thr; //! scf energy threshold
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
double hsolver_error; //! the error of HSolver
int maxniter; //! maximum iter steps for scf
int niter; //! iter steps actually used in scf
};
} // namespace ModuleESolver
#endif
96 changes: 56 additions & 40 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,14 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,

if (!PARAM.inp.deepks_equiv) // training with force label not supported by equivariant version now
{
torch::Tensor gdmx;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
GlobalC::ld.cal_gdmx(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
isstress);

GlobalC::ld
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
else
{
Expand All @@ -533,25 +529,25 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, isstress);
GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmx(ucell.nat);
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
}
std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
GlobalC::ld.cal_gvx(ucell.nat, gevdm);
torch::Tensor gvx;
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);

if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gvx(ucell.nat);
GlobalC::ld.check_gvx(ucell.nat, gvx);
}

LCAO_deepks_io::save_npy_gvx(ucell.nat,
GlobalC::ld.des_per_atom,
GlobalC::ld.gvx_tensor,
gvx,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK);
}
Expand Down Expand Up @@ -715,6 +711,12 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
{
scs(i, j) += stress_exx(i, j);
}
#endif
#ifdef __DEEPKS
if (PARAM.inp.deepks_scf)
{
scs(i, j) += svnl_dalpha(i, j);
}
#endif
}
}
Expand All @@ -726,47 +728,61 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
#ifdef __DEEPKS
if (PARAM.inp.deepks_out_labels) // not parallelized yet
{
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
const std::string file_stot = PARAM.globalv.global_out_dir + "deepks_stot.npy";
LCAO_deepks_io::save_npy_s(scs,
file_s,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
}
if (PARAM.inp.deepks_scf)
{
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
symm->symmetrize_mat3(svnl_dalpha, ucell.lat);
} // end symmetry
for (int i = 0; i < 3; i++)
{
for (int j = 0; j < 3; j++)
{
scs(i, j) += svnl_dalpha(i, j);
}
}
}
if (PARAM.inp.deepks_out_labels) // not parallelized yet
{
const std::string file_s = PARAM.globalv.global_out_dir + "deepks_stot.npy";
LCAO_deepks_io::save_npy_s(scs,
file_s,
file_stot,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model

// wenfei add 2021/11/2
if (PARAM.inp.deepks_scf)
{
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
file_sbase,
ucell.omega,
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;

if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
{
torch::Tensor gdmepsl;
if (PARAM.globalv.gamma_only_local)
{
const std::vector<std::vector<double>>& dm_gamma
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();

GlobalC::ld.cal_gdmepsl(dm_gamma,
ucell,
orb,
gd,
kv.get_nks(),
kv.kvec_d,
GlobalC::ld.phialpha,
gdmepsl);
}
else
{
const std::vector<std::vector<std::complex<double>>>& dm_k
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
->get_DM()
->get_DMK_vector();

GlobalC::ld
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
}
if (PARAM.inp.deepks_out_unittest)
{
GlobalC::ld.check_gdmepsl(gdmepsl);
}

std::vector<torch::Tensor> gevdm;
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm);
torch::Tensor gvepsl;
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);

LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
GlobalC::ld.des_per_atom,
GlobalC::ld.gvepsl_tensor,
gvepsl,
PARAM.globalv.global_out_dir,
GlobalV::MY_RANK); // unitless, grad_vepsl
}
Expand Down
17 changes: 14 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,19 @@ void Force_LCAO<double>::ftable(const bool isforce,

#ifdef __DEEPKS
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
std::vector<torch::Tensor> descriptor;
if (PARAM.inp.deepks_scf)
{
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);
GlobalC::ld.cal_gedm(ucell.nat);
DeePKS_domain::cal_descriptor(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(ucell.nat, descriptor);

const int nks = 1;
DeePKS_domain::cal_f_delta<double>(dm_gamma,
Expand Down Expand Up @@ -305,7 +311,12 @@ void Force_LCAO<double>::ftable(const bool isforce,

GlobalC::ld.check_projected_dm();

GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
GlobalC::ld.des_per_atom,
GlobalC::ld.inl_l,
ucell,
PARAM.globalv.global_out_dir,
descriptor);

GlobalC::ld.check_gedm();

Expand Down
11 changes: 8 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,14 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);

GlobalC::ld.cal_descriptor(ucell.nat);

GlobalC::ld.cal_gedm(ucell.nat);
std::vector<torch::Tensor> descriptor;
DeePKS_domain::cal_descriptor(ucell.nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(ucell.nat, descriptor);

DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
ucell,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
this->H_V_delta = new HContainer<TR>(paraV);
if (std::is_same<TK, double>::value)
{
//this->H_V_delta = new HContainer<TR>(paraV);
// this->H_V_delta = new HContainer<TR>(paraV);
this->H_V_delta->fix_gamma();
}

Expand Down Expand Up @@ -138,8 +138,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
// if (std::is_same<TK, double>::value)
// {
this->H_V_delta->allocate(nullptr, true);
// expand hR with H_V_delta
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
// expand hR with H_V_delta
// update : for computational rigor, gamma-only and multi-k cases both have full size of Hamiltonian of DeePKS now
this->hR->add(*this->H_V_delta);
this->hR->allocate(nullptr, false);
// }
Expand All @@ -161,8 +161,15 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
ModuleBase::timer::tick("DeePKS", "contributeHR");

GlobalC::ld.cal_projected_DM<TK>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
GlobalC::ld.cal_descriptor(this->ucell->nat);
GlobalC::ld.cal_gedm(this->ucell->nat);

std::vector<torch::Tensor> descriptor;
DeePKS_domain::cal_descriptor(this->ucell->nat,
GlobalC::ld.inlmax,
GlobalC::ld.inl_l,
GlobalC::ld.pdm,
descriptor,
GlobalC::ld.des_per_atom);
GlobalC::ld.cal_gedm(this->ucell->nat, descriptor);

// // recalculate the H_V_delta
// if (this->H_V_delta == nullptr)
Expand Down
10 changes: 5 additions & 5 deletions source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
if(ENABLE_DEEPKS)
list(APPEND objects
LCAO_deepks.cpp
deepks_descriptor.cpp
deepks_force.cpp
deepks_orbital.cpp
deepks_orbpre.cpp
deepks_vdpre.cpp
deepks_hmat.cpp
LCAO_deepks_io.cpp
LCAO_deepks_mpi.cpp
LCAO_deepks_pdm.cpp
LCAO_deepks_phialpha.cpp
LCAO_deepks_torch.cpp
LCAO_deepks_vdelta.cpp
deepks_hmat.cpp
LCAO_deepks_interface.cpp
deepks_orbpre.cpp
cal_gdmx.cpp
cal_gdmepsl.cpp
cal_gedm.cpp
cal_gvx.cpp
cal_descriptor.cpp
v_delta_precalc.cpp
)

add_library(
Expand Down
Loading

0 comments on commit e25db6e

Please sign in to comment.