Skip to content

Commit

Permalink
Merge branch 'develop' into refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ErjieWu authored Jan 5, 2025
2 parents d3d5b7c + 4ddec65 commit 43a3b8b
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 435 deletions.
9 changes: 4 additions & 5 deletions source/module_cell/atom_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Atom
std::vector<bool> iw2_new;
int nw = 0; // number of local orbitals (l,n,m) of this type

void set_index(void);
void set_index();

int type = 0; // Index of atom type
int na = 0; // Number of atoms in this type.
Expand All @@ -34,8 +34,7 @@ class Atom

std::string label = "\0"; // atomic symbol
std::vector<ModuleBase::Vector3<double>> tau; // Cartesian coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>>
dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
std::vector<ModuleBase::Vector3<double>> taud; // Direct coordinates of each atom in this type.
std::vector<ModuleBase::Vector3<double>> vel; // velocities of each atom in this type.
std::vector<ModuleBase::Vector3<double>> force; // force acting on each atom in this type.
Expand All @@ -54,8 +53,8 @@ class Atom
void print_Atom(std::ofstream& ofs);
void update_force(ModuleBase::matrix& fcs);
#ifdef __MPI
void bcast_atom(void);
void bcast_atom2(void);
void bcast_atom();
void bcast_atom2();
#endif
};

Expand Down
5 changes: 1 addition & 4 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
std::ofstream& ofs_warning) {
return true;
}
void UnitCell::update_pos_taud(double* posd_in) {}
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
void UnitCell::bcast_atoms_tau() {}

bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
void UnitCell::update_force(ModuleBase::matrix& fcs) {}
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test/unitcell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ TEST_F(UcellTest, UpdateVel)
{
vel_in[iat].set(iat * 0.1, iat * 0.1, iat * 0.1);
}
ucell->update_vel(vel_in);
unitcell::update_vel(vel_in,ucell->ntype,ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
EXPECT_DOUBLE_EQ(vel_in[iat].x, 0.1 * iat);
Expand Down
36 changes: 34 additions & 2 deletions source/module_cell/test/unitcell_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ TEST_F(UcellTest, UpdatePosTau)
}
delete[] pos_in;
}
TEST_F(UcellTest, UpdatePosTaud)
TEST_F(UcellTest, UpdatePosTaud_pointer)
{
double* pos_in = new double[ucell->nat * 3];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
Expand All @@ -167,7 +167,8 @@ TEST_F(UcellTest, UpdatePosTaud)
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
ucell->update_pos_taud(pos_in);
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
Expand All @@ -180,6 +181,37 @@ TEST_F(UcellTest, UpdatePosTaud)
delete[] pos_in;
}

//test update_pos_taud with ModuleBase::Vector3<double> version
TEST_F(UcellTest, UpdatePosTaud_Vector3)
{
ModuleBase::Vector3<double>* pos_in = new ModuleBase::Vector3<double>[ucell->nat];
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
ucell->set_iat2itia();
for (int iat = 0; iat < ucell->nat; ++iat)
{
for (int ik = 0; ik < 3; ++ik)
{
pos_in[iat][ik] = 0.01;
}
int it=0;
int ia=0;
ucell->iat2iait(iat, &ia, &it);
tmp[iat] = ucell->atoms[it].taud[ia];
}
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
ucell->nat,ucell->atoms);
for (int iat = 0; iat < ucell->nat; ++iat)
{
int it, ia;
ucell->iat2iait(iat, &ia, &it);
for (int ik = 0; ik < 3; ++ik)
{
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia][ik], tmp[iat][ik] + 0.01);
}
}
delete[] tmp;
delete[] pos_in;
}
TEST_F(UcellTest, ReadPseudo)
{
PARAM.input.pseudo_dir = pp_dir;
Expand Down
59 changes: 0 additions & 59 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,65 +314,6 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
return constrain;
}



void UnitCell::update_pos_taud(double* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

// posd_in is atomic displacements here liuyu 2023-03-22
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
int iat = 0;
for (int it = 0; it < this->ntype; it++) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ia++) {
for (int ik = 0; ik < 3; ++ik) {
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == this->nat);
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
this->bcast_atoms_tau();
}

void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
int iat = 0;
for (int it = 0; it < this->ntype; ++it) {
Atom* atom = &this->atoms[it];
for (int ia = 0; ia < atom->na; ++ia) {
this->atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == this->nat);
}


void UnitCell::bcast_atoms_tau() {
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++) {
atoms[i].bcast_atom(); // bcast tau array
}
#endif
}

//==============================================================
// Calculate various lattice related quantities for given latvec
//==============================================================
Expand Down
4 changes: 0 additions & 4 deletions source/module_cell/unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,6 @@ class UnitCell {
void print_cell(std::ofstream& ofs) const;
void print_cell_xyz(const std::string& fn) const;

void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
void update_pos_taud(double* posd_in);
void update_vel(const ModuleBase::Vector3<double>* vel_in);
void bcast_atoms_tau();
bool judge_big_cell() const;

void update_stress(ModuleBase::matrix& scs); // updates stress
Expand Down
69 changes: 69 additions & 0 deletions source/module_cell/update_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,75 @@ void update_pos_tau(const Lattice& lat,
bcast_atoms_tau(atoms, ntype);
}

void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[3 * iat + ik];
atom->dis[ia][ik] = posd_in[3 * iat + ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

// posd_in is atomic displacements here liuyu 2023-03-22
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; it++)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ia++)
{
for (int ik = 0; ik < 3; ++ik)
{
atom->taud[ia][ik] += posd_in[iat][ik];
atom->dis[ia][ik] = posd_in[iat][ik];
}
iat++;
}
}
assert(iat == nat);
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
bcast_atoms_tau(atoms, ntype);
}

void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms)
{
int iat = 0;
for (int it = 0; it < ntype; ++it)
{
Atom* atom = &atoms[it];
for (int ia = 0; ia < atom->na; ++ia)
{
atoms[it].vel[ia] = vel_in[iat];
++iat;
}
}
assert(iat == nat);
}

void periodic_boundary_adjustment(Atom* atoms,
const ModuleBase::Matrix3& latvec,
const int ntype)
Expand Down
42 changes: 42 additions & 0 deletions source/module_cell/update_cell.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ namespace unitcell
const int ntype,
const int nat,
Atom* atoms);

/**
* @brief update the position and tau of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const double* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param lat: the lattice of the atoms [in]
* @param pos_in: the position of the atoms in direct coordinate system
* in ModuleBase::Vector3 version [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_pos_taud(const Lattice& lat,
const ModuleBase::Vector3<double>* posd_in,
const int ntype,
const int nat,
Atom* atoms);
/**
* @brief update the velocity of the atoms
*
* @param vel_in: the velocity of the atoms [in]
* @param ntype: the number of types of the atoms [in]
* @param nat: the number of atoms [in]
* @param atoms: the atoms to be updated [out]
*/
void update_vel(const ModuleBase::Vector3<double>* vel_in,
const int ntype,
const int nat,
Atom* atoms);
}
//
#endif // UPDATE_CELL_H
5 changes: 2 additions & 3 deletions source/module_md/md_base.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#include "md_base.h"

#include "md_func.h"
#ifdef __MPI
#include "mpi.h"
#endif
#include "module_io/print_info.h"

#include "module_cell/update_cell.h"
MD_base::MD_base(const Parameter& param_in, UnitCell& unit_in) : mdp(param_in.mdp), ucell(unit_in)
{
my_rank = param_in.globalv.myrank;
Expand Down Expand Up @@ -112,7 +111,7 @@ void MD_base::update_pos()
MPI_Bcast(pos, ucell.nat * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
#endif

ucell.update_pos_taud(pos);
unitcell::update_pos_taud(ucell.lat,pos,ucell.ntype,ucell.nat,ucell.atoms);

return;
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_md/run_md.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "msst.h"
#include "nhchain.h"
#include "verlet.h"

#include "module_cell/update_cell.h"
namespace Run_MD
{

Expand Down Expand Up @@ -97,7 +97,7 @@ void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Paramet

if ((mdrun->step_ + mdrun->step_rst_) % param_in.mdp.md_restartfreq == 0)
{
unit_in.update_vel(mdrun->vel);
unitcell::update_vel(mdrun->vel,unit_in.ntype,unit_in.nat,unit_in.atoms);
std::stringstream file;
file << PARAM.globalv.global_stru_dir << "STRU_MD_" << mdrun->step_ + mdrun->step_rst_;
// changelog 20240509
Expand Down
2 changes: 1 addition & 1 deletion source/module_relax/relax_new/relax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ void Relax::move_cell_ions(UnitCell& ucell, const bool is_new_dir)
ucell.symm.symmetrize_vec3_nat(move_ion);
}

ucell.update_pos_taud(move_ion);
unitcell::update_pos_taud(ucell.lat,move_ion,ucell.ntype,ucell.nat,ucell.atoms);

// Print the structure file.
ucell.print_tau();
Expand Down
Loading

0 comments on commit 43a3b8b

Please sign in to comment.