Skip to content

Commit

Permalink
register the move shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Angel-Jia committed Oct 28, 2024
1 parent be2bb38 commit 80ec521
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 11 deletions.
9 changes: 5 additions & 4 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def get_coords(celldm, cell, geometry_inlines, inlines=None):
coords.append(xyz)
atom_types.append(it)

move.append(imove)
if imove is not None:
move.append(imove)
velocity.append(ivelocity)
mag.append(imagmom)
angle1.append(iangle1)
Expand Down Expand Up @@ -512,7 +513,7 @@ def get_frame(fname):
if len(magforce) > 0:
data["mag_forces"] = magforce
if len(move) > 0:
data["move"] = move
data["move"] = move[np.newaxis, :, :]
# print("atom_names = ", data['atom_names'])
# print("natoms = ", data['atom_numbs'])
# print("types = ", data['atom_types'])
Expand Down Expand Up @@ -575,7 +576,7 @@ def get_frame_from_stru(fname):
data["coords"] = coords[np.newaxis, :, :]
data["orig"] = np.zeros(3)
if len(move) > 0:
data["move"] = move
data["move"] = move[np.newaxis, :, :]

return data

Expand Down Expand Up @@ -690,7 +691,7 @@ def process_file_input(file_input, atom_names, input_name):
mag = data["spins"][frame_idx]

if move is None and data.get("move", None) is not None and len(data["move"]) > 0:
move = data["move"]
move = data["move"][frame_idx]

atom_numbs = sum(data["atom_numbs"])
for key in [move, velocity, mag, angle1, angle2, sc, lambda_]:
Expand Down
9 changes: 9 additions & 0 deletions dpdata/plugins/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def register_mag_data(data):
deepmd_name="force_mag",
)
dpdata.LabeledSystem.register_data_type(dt)
if "move" in data:
dt = DataType(
"move",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="move",
)
dpdata.System.register_data_type(dt)


@Format.register("abacus/scf")
Expand Down
16 changes: 16 additions & 0 deletions dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,24 @@
import dpdata.vasp.xml
from dpdata.format import Format
from dpdata.utils import open_file, uniq_atom_names
from dpdata.data_type import Axis, DataType

if TYPE_CHECKING:
from dpdata.utils import FileType


def register_move_data(data):
if "move" in data:
dt = DataType(
"move",
np.ndarray,
(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
deepmd_name="move",
)
dpdata.System.register_data_type(dt)


@Format.register("poscar")
@Format.register("contcar")
@Format.register("vasp/poscar")
Expand All @@ -25,6 +38,7 @@ def from_system(self, file_name: FileType, **kwargs):
lines = [line.rstrip("\n") for line in fp]
data = dpdata.vasp.poscar.to_system_data(lines)
data = uniq_atom_names(data)
register_move_data(data)
return data

def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs):
Expand Down Expand Up @@ -99,6 +113,7 @@ def from_labeled_system(
vol = np.linalg.det(np.reshape(data["cells"][ii], [3, 3]))
data["virials"][ii] *= v_pref * vol
data = uniq_atom_names(data)
register_move_data(data)
return data


Expand Down Expand Up @@ -135,4 +150,5 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
vol = np.linalg.det(np.reshape(data["cells"][ii], [3, 3]))
data["virials"][ii] *= v_pref * vol
data = uniq_atom_names(data)
register_move_data(data)
return data
15 changes: 9 additions & 6 deletions dpdata/vasp/poscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def move_flag_mapper(flag):
system["atom_types"] = np.array(atom_types, dtype=int)
system["cells"] = np.array(system["cells"])
system["coords"] = np.array(system["coords"])
system["move"] = np.array(move_flags, dtype=bool)
if move_flags:
move_flags = np.array(move_flags, dtype=bool)
move_flags = move_flags.reshape((1, natoms, 3))
system["move"] = np.array(move_flags, dtype=bool)
return system


Expand Down Expand Up @@ -88,8 +91,8 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
continue
ret += "%d " % ii
ret += "\n"
move = system.get("move", np.array([]))
if len(move) > 0:
move = system.get("move", None)
if move is not None and len(move) > 0:
ret += "Selective Dynamics\n"

# should use Cartesian for VESTA software
Expand All @@ -101,8 +104,8 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
sort_idx = np.lexsort((np.arange(len(atype)), atype))
atype = atype[sort_idx]
posis = posis[sort_idx]
if len(move) > 0:
move = move[sort_idx]
if move is not None and len(move) > 0:
move = move[f_idx][sort_idx]

if isinstance(move, np.ndarray):
move = move.tolist()
Expand All @@ -111,7 +114,7 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
for idx in range(len(posis)):
ii_posi = posis[idx]
line = f"{ii_posi[0]:15.10f} {ii_posi[1]:15.10f} {ii_posi[2]:15.10f}"
if len(move) > 0:
if move is not None and len(move) > 0:
move_flags = move[idx]
if isinstance(move_flags, list) and len(move_flags) == 3:
line += " " + " ".join(["T" if flag else "F" for flag in move_flags])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vasp_poscar_to_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def setUp(self):
self.system.from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.c"))

def test_move_flags(self):
expected = np.array([[True, True, False], [False, False, False]])
expected = np.array([[[True, True, False], [False, False, False]]])
self.assertTrue(np.array_equal(self.system["move"], expected))


Expand Down

0 comments on commit 80ec521

Please sign in to comment.