Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for converting move flags between Abacus and VASP #744

Merged
merged 14 commits into from
Oct 29, 2024
Merged
4 changes: 3 additions & 1 deletion dpdata/abacus/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_frame(fname):
with open_file(geometry_path_in) as fp:
geometry_inlines = fp.read().split("\n")
celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coords = get_coords(
atom_names, natoms, types, coords, move = get_coords(
celldm, cell, geometry_inlines, inlines
)
# This coords is not to be used.
Expand Down Expand Up @@ -221,5 +221,7 @@ def get_frame(fname):
data["spins"] = magmom
if len(magforce) > 0:
data["mag_forces"] = magforce
if len(move) > 0:
data["move"] = move

return data
4 changes: 3 additions & 1 deletion dpdata/abacus/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_frame(fname):
with open_file(geometry_path_in) as fp:
geometry_inlines = fp.read().split("\n")
celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coord_tmp = get_coords(
atom_names, natoms, types, coord_tmp, move = get_coords(
celldm, cell, geometry_inlines, inlines
)

Expand Down Expand Up @@ -218,5 +218,7 @@ def get_frame(fname):
data["spins"] = magmom
if len(magforce) > 0:
data["mag_forces"] = magforce
if len(move) > 0:
data["move"] = move

return data
21 changes: 15 additions & 6 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def get_coords(celldm, cell, geometry_inlines, inlines=None):
coords.append(xyz)
atom_types.append(it)

move.append(imove)
if imove is not None:
move.append(imove)
velocity.append(ivelocity)
mag.append(imagmom)
angle1.append(iangle1)
Expand All @@ -310,7 +311,8 @@ def get_coords(celldm, cell, geometry_inlines, inlines=None):
line_idx += 1
coords = np.array(coords) # need transformation!!!
atom_types = np.array(atom_types)
return atom_names, atom_numbs, atom_types, coords
move = np.array(move, dtype=bool)
return atom_names, atom_numbs, atom_types, coords, move


def get_energy(outlines):
Expand Down Expand Up @@ -477,7 +479,7 @@ def get_frame(fname):
outlines = fp.read().split("\n")

celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coords = get_coords(
atom_names, natoms, types, coords, move = get_coords(
celldm, cell, geometry_inlines, inlines
)
magmom, magforce = get_mag_force(outlines)
Expand Down Expand Up @@ -510,6 +512,8 @@ def get_frame(fname):
data["spins"] = magmom
if len(magforce) > 0:
data["mag_forces"] = magforce
if len(move) > 0:
data["move"] = move[np.newaxis, :, :]
# print("atom_names = ", data['atom_names'])
# print("natoms = ", data['atom_numbs'])
# print("types = ", data['atom_types'])
Expand Down Expand Up @@ -561,7 +565,7 @@ def get_frame_from_stru(fname):
nele = get_nele_from_stru(geometry_inlines)
inlines = [f"ntype {nele}"]
celldm, cell = get_cell(geometry_inlines)
atom_names, natoms, types, coords = get_coords(
atom_names, natoms, types, coords, move = get_coords(
celldm, cell, geometry_inlines, inlines
)
data = {}
Expand All @@ -571,6 +575,8 @@ def get_frame_from_stru(fname):
data["cells"] = cell[np.newaxis, :, :]
data["coords"] = coords[np.newaxis, :, :]
data["orig"] = np.zeros(3)
if len(move) > 0:
data["move"] = move[np.newaxis, :, :]

return data

Expand Down Expand Up @@ -609,8 +615,8 @@ def make_unlabeled_stru(
numerical descriptor file
mass : list of float, optional
List of atomic masses
move : list of list of bool, optional
List of the move flag of each xyz direction of each atom
move : list of (list of list of bool), optional
List of the move flag of each xyz direction of each atom for each frame
velocity : list of list of float, optional
List of the velocity of each xyz direction of each atom
mag : list of (list of float or float), optional
Expand Down Expand Up @@ -684,6 +690,9 @@ def process_file_input(file_input, atom_names, input_name):
if mag is None and data.get("spins") is not None and len(data["spins"]) > 0:
mag = data["spins"][frame_idx]

if move is None and data.get("move", None) is not None and len(data["move"]) > 0:
move = data["move"][frame_idx]

Comment on lines +693 to +695
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider simplifying the move flag check.

The current implementation checks both the existence and length of move data. This could be simplified.

Consider this more concise implementation:

-if move is None and data.get("move", None) is not None and len(data["move"]) > 0:
+if move is None and data.get("move", []):
     move = data["move"][frame_idx]

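This is more concise and handles plain empty lists naturally. Note, however, that data["move"] is stored as a NumPy array in this PR, and the truth value of an array with more than one element is ambiguous (NumPy raises ValueError), so the explicit `is not None` / `len(...) > 0` check is the safer form here.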

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if move is None and data.get("move", None) is not None and len(data["move"]) > 0:
move = data["move"][frame_idx]
if move is None and data.get("move", []):
move = data["move"][frame_idx]

atom_numbs = sum(data["atom_numbs"])
for key in [move, velocity, mag, angle1, angle2, sc, lambda_]:
if key is not None:
Expand Down
15 changes: 15 additions & 0 deletions dpdata/plugins/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def register_mag_data(data):
dpdata.LabeledSystem.register_data_type(dt)


def register_move_data(data):
    """Register the optional ``move`` data type when a frame carries move flags.

    Parameters
    ----------
    data : dict
        Frame data; only the presence of the ``"move"`` key is inspected.

    Notes
    -----
    The data type is declared as an ``np.ndarray`` with axes
    ``(Axis.NFRAMES, Axis.NATOMS, 3)`` and ``required=False``.
    Nothing is registered when ``data`` has no ``"move"`` entry.
    """
    if "move" not in data:
        return
    dt = DataType(
        "move",
        np.ndarray,
        (Axis.NFRAMES, Axis.NATOMS, 3),
        required=False,
        deepmd_name="move",
    )
    dpdata.System.register_data_type(dt)
    # This helper is only called from ``from_labeled_system`` readers, so the
    # labeled class must know the data type as well; register_mag_data in this
    # module registers on LabeledSystem for the same reason.
    dpdata.LabeledSystem.register_data_type(dt)
Angel-Jia marked this conversation as resolved.
Show resolved Hide resolved


@Format.register("abacus/scf")
@Format.register("abacus/pw/scf")
@Format.register("abacus/lcao/scf")
Expand All @@ -75,6 +87,7 @@ class AbacusSCFFormat(Format):
def from_labeled_system(self, file_name, **kwargs):
# Build the frame dict via dpdata.abacus.scf.get_frame; **kwargs is accepted but not used here.
data = dpdata.abacus.scf.get_frame(file_name)
# Register optional data types before returning the frame.
# NOTE(review): register_mag_data presumably covers the magnetic entries - confirm against its definition.
register_mag_data(data)
# register_move_data only acts when data contains a "move" key.
register_move_data(data)
return data


Expand All @@ -86,6 +99,7 @@ class AbacusMDFormat(Format):
def from_labeled_system(self, file_name, **kwargs):
# Build the frame dict via dpdata.abacus.md.get_frame; **kwargs is accepted but not used here.
data = dpdata.abacus.md.get_frame(file_name)
# Register optional data types before returning the frame.
# NOTE(review): register_mag_data presumably covers the magnetic entries - confirm against its definition.
register_mag_data(data)
# register_move_data only acts when data contains a "move" key.
register_move_data(data)
return data


Expand All @@ -97,4 +111,5 @@ class AbacusRelaxFormat(Format):
def from_labeled_system(self, file_name, **kwargs):
# Build the frame dict via dpdata.abacus.relax.get_frame; **kwargs is accepted but not used here.
data = dpdata.abacus.relax.get_frame(file_name)
# Register optional data types before returning the frame.
# NOTE(review): register_mag_data presumably covers the magnetic entries - confirm against its definition.
register_mag_data(data)
# register_move_data only acts when data contains a "move" key.
register_move_data(data)
return data
16 changes: 16 additions & 0 deletions dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
import dpdata.vasp.outcar
import dpdata.vasp.poscar
import dpdata.vasp.xml
from dpdata.data_type import Axis, DataType
from dpdata.format import Format
from dpdata.utils import open_file, uniq_atom_names

if TYPE_CHECKING:
from dpdata.utils import FileType


def register_move_data(data):
    """Make ``dpdata.System`` aware of the optional ``move`` data type.

    Parameters
    ----------
    data : dict
        Frame data; the data type is registered only when it has a
        ``"move"`` key.
    """
    if "move" not in data:
        return
    # One entry per frame, per atom, per Cartesian direction.
    move_shape = (Axis.NFRAMES, Axis.NATOMS, 3)
    move_dtype = DataType(
        "move", np.ndarray, move_shape, required=False, deepmd_name="move"
    )
    dpdata.System.register_data_type(move_dtype)


@Format.register("poscar")
@Format.register("contcar")
@Format.register("vasp/poscar")
Expand All @@ -25,6 +38,7 @@ def from_system(self, file_name: FileType, **kwargs):
lines = [line.rstrip("\n") for line in fp]
data = dpdata.vasp.poscar.to_system_data(lines)
data = uniq_atom_names(data)
register_move_data(data)
return data

def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs):
Expand Down Expand Up @@ -99,6 +113,7 @@ def from_labeled_system(
vol = np.linalg.det(np.reshape(data["cells"][ii], [3, 3]))
data["virials"][ii] *= v_pref * vol
data = uniq_atom_names(data)
register_move_data(data)
return data


Expand Down Expand Up @@ -135,4 +150,5 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
vol = np.linalg.det(np.reshape(data["cells"][ii], [3, 3]))
data["virials"][ii] *= v_pref * vol
data = uniq_atom_names(data)
register_move_data(data)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Move flags are not supported in XML format

Based on the examination of the XML parsing implementation in dpdata/vasp/xml.py and the sample XML file, I can confirm that the XML format does not currently support move flags (selective dynamics). The analyze_vasprun function only extracts positions, cell parameters, energies, forces, and stresses from the XML file, with no handling of selective dynamics flags.

  • The registration of move data in from_labeled_system should be removed as it's not applicable for XML format
  • If selective dynamics support is needed for XML format, it would require implementing the extraction of these flags from the XML structure
🔗 Analysis chain

Verify move flag support in XML format.

While the registration is placed consistently, we should verify that the XML format supports move flags and that they are being properly extracted in dpdata.vasp.xml.analyze.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if XML parsing handles move flags
# Test: Search for move flag handling in xml.py
rg -A 5 "move|selective" "dpdata/vasp/xml.py"

Length of output: 244


Script:

#!/bin/bash
# Let's check for move flag handling in both xml.py and the test files
rg -A 5 "move|selective" "dpdata/vasp/"

# Also check test files specifically for XML move flag tests
rg -A 5 "xml.*move|move.*xml" "tests/"

# Check the actual implementation of from_labeled_system in vasp.py
ast-grep --pattern 'def from_labeled_system'

Length of output: 4755


Script:

#!/bin/bash
# Let's check the XML parsing implementation
ast-grep --pattern 'class VASPXMLFormat'

# Check if there are any XML test files that might contain move flags
fd -e xml . tests/

# Look for any move-related functionality in the XML format class
rg -A 5 "from_labeled_system|analyze_vasprun" "dpdata/vasp/xml.py"

Length of output: 403


Script:

#!/bin/bash
# Let's examine one of the test XML files to see if it contains selective dynamics
head -n 50 "tests/poscars/vasprun.h2o.md.xml"

# Check if there are any test cases related to selective dynamics in XML
rg -A 5 "selective.*xml|move.*xml" "tests/"

# Look at the implementation of analyze_vasprun in xml.py
cat "dpdata/vasp/xml.py"

Length of output: 6104

return data
55 changes: 50 additions & 5 deletions dpdata/vasp/poscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
import numpy as np


def _to_system_data_lower(lines, cartesian=True):
def _to_system_data_lower(lines, cartesian=True, selective_dynamics=False):
def move_flag_mapper(flag):
    """Translate one selective-dynamics flag token into a bool.

    Parameters
    ----------
    flag : str
        Flag token from an atom line; ``T`` or ``F`` in either case.

    Returns
    -------
    bool
        ``True`` for ``T``/``t``, ``False`` for ``F``/``f``.

    Raises
    ------
    RuntimeError
        If the token is anything other than T or F.
    """
    # Compare case-insensitively so lowercase flags are accepted too.
    normalized = flag.upper()
    if normalized == "T":
        return True
    if normalized == "F":
        return False
    # Report the token exactly as it appeared in the file.
    raise RuntimeError(f"Invalid move flag: {flag}")
Comment on lines +8 to +14
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Make move flag parsing case-insensitive

The function should handle both uppercase and lowercase flags as both are commonly used in VASP files.

Apply this diff:

 def move_flag_mapper(flag):
-    if flag == "T":
+    if flag.upper() == "T":
         return True
-    elif flag == "F":
+    elif flag.upper() == "F":
         return False
     else:
         raise RuntimeError(f"Invalid move flag: {flag}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def move_flag_mapper(flag):
if flag == "T":
return True
elif flag == "F":
return False
else:
raise RuntimeError(f"Invalid move flag: {flag}")
def move_flag_mapper(flag):
if flag.upper() == "T":
return True
elif flag.upper() == "F":
return False
else:
raise RuntimeError(f"Invalid move flag: {flag}")


"""Treat as cartesian poscar."""
system = {}
system["atom_names"] = [str(ii) for ii in lines[5].split()]
system["atom_numbs"] = [int(ii) for ii in lines[6].split()]
scale = float(lines[1])
cell = []
move_flags = []
for ii in range(2, 5):
boxv = [float(jj) for jj in lines[ii].split()]
boxv = np.array(boxv) * scale
Expand All @@ -19,12 +28,21 @@ def _to_system_data_lower(lines, cartesian=True):
natoms = sum(system["atom_numbs"])
coord = []
for ii in range(8, 8 + natoms):
tmpv = [float(jj) for jj in lines[ii].split()[:3]]
tmp = lines[ii].split()
tmpv = [float(jj) for jj in tmp[:3]]
if cartesian:
tmpv = np.array(tmpv) * scale
else:
tmpv = np.matmul(np.array(tmpv), system["cells"][0])
coord.append(tmpv)
if selective_dynamics:
if len(tmp) == 6:
move_flags.append(list(map(move_flag_mapper, tmp[3:])))
else:
raise RuntimeError(
f"Invalid move flags, should be 6 columns, got {tmp}"
)

system["coords"] = [np.array(coord)]
system["orig"] = np.zeros(3)
atom_types = []
Expand All @@ -34,20 +52,26 @@ def _to_system_data_lower(lines, cartesian=True):
system["atom_types"] = np.array(atom_types, dtype=int)
system["cells"] = np.array(system["cells"])
system["coords"] = np.array(system["coords"])
if move_flags:
move_flags = np.array(move_flags, dtype=bool)
move_flags = move_flags.reshape((1, natoms, 3))
system["move"] = np.array(move_flags, dtype=bool)
return system


def to_system_data(lines):
    """Convert the lines of a VASP 5.x POSCAR into a system dict.

    Parameters
    ----------
    lines : list of str
        POSCAR content, one entry per line. The list is not modified.

    Returns
    -------
    dict
        Whatever ``_to_system_data_lower`` builds from the lines.

    Raises
    ------
    RuntimeError
        If the coordinate-mode line starts with none of C, c, K, k, D or d.
    """
    # ``[:1]`` instead of ``[0]`` so an empty line reaches the RuntimeError
    # below rather than failing with an IndexError.
    selective_dynamics = lines[7][:1] in ("S", "s")
    if selective_dynamics:
        # Drop the 'selective dynamics' line on a copy so the coordinate-mode
        # line sits at index 7 without mutating the caller's list.
        lines = lines[:7] + lines[8:]
    mode = lines[7][:1]
    is_cartesian = mode in ("C", "c", "K", "k")
    if not is_cartesian and mode not in ("d", "D"):
        raise RuntimeError(
            "seem not to be a valid POSCAR of vasp 5.x, may be a POSCAR of vasp 4.x?"
        )
    # Always forward the selective-dynamics switch so move flags are parsed.
    return _to_system_data_lower(lines, is_cartesian, selective_dynamics)


def from_system_data(system, f_idx=0, skip_zeros=True):
Expand All @@ -72,6 +96,10 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
continue
ret += "%d " % ii
ret += "\n"
move = system.get("move", None)
if move is not None and len(move) > 0:
ret += "Selective Dynamics\n"

# should use Cartesian for VESTA software
ret += "Cartesian\n"
atype = system["atom_types"]
Expand All @@ -81,9 +109,26 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
sort_idx = np.lexsort((np.arange(len(atype)), atype))
atype = atype[sort_idx]
posis = posis[sort_idx]
if move is not None and len(move) > 0:
move = move[f_idx][sort_idx]

if isinstance(move, np.ndarray):
move = move.tolist()

Comment on lines +115 to +117
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Add test coverage for invalid move flags error handling

The review comment is correct. Based on the test files inspection:

  • There are only basic happy path tests for move flags in test_move_flags() and test_dump_move_flags()
  • No tests cover the error handling case when invalid move flags are provided
  • The error path raise RuntimeError("Invalid move flags...") is untested

Add test cases to cover:

  • Invalid move flags type (e.g., string, dict)
  • List with incorrect length (not 3)
  • None value
🔗 Analysis chain

Improve test coverage and simplify move flags formatting

  1. The error handling for invalid move flags (lines 124-126) is not covered by tests.
  2. The type checking and formatting logic could be simplified.

Let's verify the test coverage for invalid move flags:

Consider simplifying the move flags formatting:

-            if isinstance(move_flags, list) and len(move_flags) == 3:
-                line += " " + " ".join(["T" if flag else "F" for flag in move_flags])
-            elif isinstance(move_flags, (int, float, bool)):
-                line += " " + " ".join(["T" if move_flags else "F"] * 3)
-            else:
-                raise RuntimeError(
-                    f"Invalid move flags: {move_flags}, should be a list or a bool"
-                )
+            flags = move_flags if isinstance(move_flags, list) else [bool(move_flags)] * 3
+            if not isinstance(flags, list) or len(flags) != 3:
+                raise RuntimeError(f"Invalid move flags: {move_flags}, should be a list or a bool")
+            line += " " + " ".join("T" if flag else "F" for flag in flags)

Also applies to: 117-126

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check test coverage for invalid move flags error handling
rg -l "Invalid move flags.*should be a list or a bool" tests/

Length of output: 63


Script:

#!/bin/bash
# Let's check for test files related to POSCAR and move flags
fd "test.*poscar.*\.py" tests/

# Then look for any test cases involving move flags
rg -l "move.*flags?" tests/

Length of output: 196


Script:

#!/bin/bash
# Let's also check the actual implementation and error handling
rg -A 3 "Invalid move flags" dpdata/

Length of output: 260


Script:

#!/bin/bash
# Check test files content for move flags testing
ast-grep --pattern 'def test_$_($$$) {
  $$$
  move_$$$
  $$$
}'

Length of output: 66


Script:

#!/bin/bash
# Let's check the content of the test files for move flags testing
rg -B2 -A5 "move.*flags?" tests/test_vasp_poscar_dump.py tests/test_vasp_poscar_to_system.py

Length of output: 1286


Script:

#!/bin/bash
# Let's also check for any test cases that might be testing error conditions
rg "def test.*" tests/test_vasp_poscar_dump.py tests/test_vasp_poscar_to_system.py

Length of output: 409

posi_list = []
for ii in posis:
posi_list.append(f"{ii[0]:15.10f} {ii[1]:15.10f} {ii[2]:15.10f}")
for idx in range(len(posis)):
ii_posi = posis[idx]
line = f"{ii_posi[0]:15.10f} {ii_posi[1]:15.10f} {ii_posi[2]:15.10f}"
if move is not None and len(move) > 0:
move_flags = move[idx]
if not isinstance(move_flags, list) or len(move_flags) != 3:
raise RuntimeError(
f"Invalid move flags: {move_flags}, should be a list of 3 bools"
)
line += " " + " ".join("T" if flag else "F" for flag in move_flags)

posi_list.append(line)

posi_list.append("")
ret += "\n".join(posi_list)
return ret
8 changes: 4 additions & 4 deletions tests/abacus.scf/STRU.ch4
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ Cartesian #Cartesian(Unit is LATTICE_CONSTANT)
C #Name of element
0.0 #Magnetic for this element.
1 #Number of atoms
0.981274803 0.861285385 0.838442496 0 0 0
0.981274803 0.861285385 0.838442496 1 1 1
H
0.0
4
1.023557202 0.758025625 0.66351336 0 0 0
0.78075702 0.889445935 0.837363468 0 0 0
1.064091613 1.043438905 0.840995502 0 0 0
1.039321214 0.756530859 1.009609207 0 0 0
0.78075702 0.889445935 0.837363468 1 0 1
1.064091613 1.043438905 0.840995502 1 0 1
1.039321214 0.756530859 1.009609207 0 1 1
8 changes: 4 additions & 4 deletions tests/abacus.scf/stru_test
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ C
H
0.0
4
5.416431453540 4.011298860305 3.511161492417 1 1 1
4.131588222365 4.706745191323 4.431136645083 1 1 1
5.630930319126 5.521640894956 4.450356541303 1 1 1
5.499851012568 4.003388899277 5.342621842622 1 1 1
5.416431453540 4.011298860305 3.511161492417 0 0 0
4.131588222365 4.706745191323 4.431136645083 1 0 1
5.630930319126 5.521640894956 4.450356541303 1 0 1
5.499851012568 4.003388899277 5.342621842622 0 1 1
11 changes: 11 additions & 0 deletions tests/poscars/POSCAR.oh.err1
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Cubic BN
3.57
0.00 0.50 0.50
0.45 0.00 0.50
0.55 0.51 0.00
O H
1 1
Selective dynamics
Cartesian
0.00 0.00 0.00 T T F
0.25 0.25 0.25 F F
11 changes: 11 additions & 0 deletions tests/poscars/POSCAR.oh.err2
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Cubic BN
3.57
0.00 0.50 0.50
0.45 0.00 0.50
0.55 0.51 0.00
O H
1 1
Selective dynamics
Cartesian
0.00 0.00 0.00 T T F
0.25 0.25 0.25 a T F
Loading