Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
 into hotfix
  • Loading branch information
Qianruipku committed Jan 5, 2025
2 parents 85b1951 + e25db6e commit 6f45435
Show file tree
Hide file tree
Showing 102 changed files with 4,634 additions and 3,419 deletions.
10 changes: 5 additions & 5 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,21 @@ OBJS_CELL=atom_pseudo.o\

OBJS_DEEPKS=LCAO_deepks.o\
deepks_force.o\
deepks_descriptor.o\
deepks_orbital.o\
deepks_orbpre.o\
deepks_vdpre.o\
deepks_hmat.o\
LCAO_deepks_io.o\
LCAO_deepks_mpi.o\
LCAO_deepks_pdm.o\
LCAO_deepks_phialpha.o\
LCAO_deepks_torch.o\
LCAO_deepks_vdelta.o\
deepks_hmat.o\
LCAO_deepks_interface.o\
deepks_orbpre.o\
cal_gdmx.o\
cal_gdmepsl.o\
cal_gedm.o\
cal_gvx.o\
cal_descriptor.o\
v_delta_precalc.o\


OBJS_ELECSTAT=elecstate.o\
Expand Down
9 changes: 4 additions & 5 deletions source/module_cell/atom_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Atom
std::vector<bool> iw2_new;
int nw = 0; // number of local orbitals (l,n,m) of this type

void set_index(void);
void set_index();

int type = 0; // Index of atom type
int na = 0; // Number of atoms in this type.
Expand All @@ -34,8 +34,7 @@ class Atom

std::string label = "\0"; // atomic symbol
std::vector<ModuleBase::Vector3<double>> tau; // Cartesian coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>>
dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> taud; // Direct coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>> vel; // velocities of each atom in this type.
std::vector<ModuleBase::Vector3<double>> force; // force acting on each atom in this type.
Expand All @@ -54,8 +53,8 @@ class Atom
void print_Atom(std::ofstream& ofs);
void update_force(ModuleBase::matrix& fcs);
#ifdef __MPI
void bcast_atom(void);
void bcast_atom2(void);
void bcast_atom();
void bcast_atom2();
#endif
};

Expand Down
5 changes: 1 addition & 4 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
std::ofstream& ofs_warning) {
return true;
}
void UnitCell::update_pos_taud(double* posd_in) {}
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
void UnitCell::bcast_atoms_tau() {}

bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
void UnitCell::update_force(ModuleBase::matrix& fcs) {}
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test/unitcell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ TEST_F(UcellTest, UpdateVel)
{
vel_in[iat].set(iat * 0.1, iat * 0.1, iat * 0.1);
}
ucell->update_vel(vel_in);
unitcell::update_vel(vel_in,ucell->ntype,ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
EXPECT_DOUBLE_EQ(vel_in[iat].x, 0.1 * iat);
Expand Down
36 changes: 34 additions & 2 deletions source/module_cell/test/unitcell_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ TEST_F(UcellTest, UpdatePosTau)
}
delete[] pos_in;
}
TEST_F(UcellTest, UpdatePosTaud)
TEST_F(UcellTest, UpdatePosTaud_pointer)
{
double* pos_in = new double[ucell->nat * 3];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
Expand All @@ -167,7 +167,8 @@ TEST_F(UcellTest, UpdatePosTaud)
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
ucell->update_pos_taud(pos_in);
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
Expand All @@ -180,6 +181,37 @@ TEST_F(UcellTest, UpdatePosTaud)
delete[] pos_in;
}

//test update_pos_taud with ModuleBase::Vector3<double> version
TEST_F(UcellTest, UpdatePosTaud_Vector3)
{
ModuleBase::Vector3<double>* pos_in = new ModuleBase::Vector3<double>[ucell->nat];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
ucell->set_iat2itia();
for (int iat = 0; iat < ucell->nat; ++iat)
{
for (int ik = 0; ik < 3; ++ik)
{
pos_in[iat][ik] = 0.01;
}
int it=0;
int ia=0;
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
ucell->iat2iait(iat, &ia, &it);
for (int ik = 0; ik < 3; ++ik)
{
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia][ik], tmp[iat][ik] + 0.01);
}
}
delete[] tmp;
delete[] pos_in;
}
TEST_F(UcellTest, ReadPseudo)
{
PARAM.input.pseudo_dir = pp_dir;
Expand Down
59 changes: 0 additions & 59 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,65 +314,6 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
return constrain;
}



void UnitCell::update_pos_taud(double* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

// posd_in is atomic displacements here liuyu 2023-03-22
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
int iat = 0;
for (int it = 0; it < this->ntype; ++it) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ++ia) {
this->atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == this->nat);
}


void UnitCell::bcast_atoms_tau() {
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++) {
atoms[i].bcast_atom(); // bcast tau array
}
#endif
}

//==============================================================
// Calculate various lattice related quantities for given latvec
//==============================================================
Expand Down
4 changes: 0 additions & 4 deletions source/module_cell/unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,6 @@ class UnitCell {
void print_cell(std::ofstream& ofs) const;
void print_cell_xyz(const std::string& fn) const;

void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
void update_pos_taud(double* posd_in);
void update_vel(const ModuleBase::Vector3<double>* vel_in);
void bcast_atoms_tau();
bool judge_big_cell() const;

void update_stress(ModuleBase::matrix& scs); // updates stress
Expand Down
69 changes: 69 additions & 0 deletions source/module_cell/update_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,75 @@ void update_pos_tau(const Lattice& lat,
bcast_atoms_tau(atoms, ntype);
}

void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

// posd_in is atomic displacements here liuyu 2023-03-22
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; ++it)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ++ia)
{
atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == nat);
}

void periodic_boundary_adjustment(Atom* atoms,
const ModuleBase::Matrix3& latvec,
const int ntype)
Expand Down
42 changes: 42 additions & 0 deletions source/module_cell/update_cell.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ namespace unitcell
const int ntype,
const int nat,
Atom* atoms);

/**
* @brief update the position and tau of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system
* in ModuleBase::Vector3 version [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param vel_in: the velocity of the atoms [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms);
}
//
#endif // UPDATE_CELL_H
13 changes: 10 additions & 3 deletions source/module_elecstate/cal_dm.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
//dm.fix_k(ik);
dm[ik].create(ParaV->ncol, ParaV->nrow);
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
psi::Psi<double> wg_wfc(wfc, 1);
psi::Psi<double> wg_wfc(1,
wfc.get_nbands(),
wfc.get_nbasis(),
wfc.get_nbasis(),
true);
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());

int ib_global = 0;
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
Expand All @@ -41,7 +46,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down Expand Up @@ -99,7 +105,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) continue;
if (ib_global >= wg.nc) { continue;
}
const double wg_local = wg(ik, ib_global);
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)

this->init_rho_data();
int ik = psi.get_current_k();
int npw = psi.get_current_nbas();
int npw = psi.get_current_ngk();
int current_spin = 0;
if (PARAM.inp.nspin == 2)
{
Expand Down Expand Up @@ -287,7 +287,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
psi.fix_k(ik);
const T* psi_now = psi.get_pointer();
const int currect_spin = this->klist->isk[ik];
const int npw = psi.get_current_nbas();
const int npw = psi.get_current_ngk();

// get |beta>
if (this->ppcell->nkb > 0)
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_pw_cal_tau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
int npw = psi.get_current_nbas();
int npw = psi.get_current_ngk();
int current_spin = 0;
if (PARAM.inp.nspin == 2)
{
Expand Down
Loading

0 comments on commit 6f45435

Please sign in to comment.