Skip to content

Commit

Permalink
add ut for move flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Angel-Jia committed Oct 28, 2024
1 parent 2c3dd9e commit 0222d6d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,8 @@ def make_unlabeled_stru(
numerical descriptor file
mass : list of float, optional
List of atomic masses
move : list of list of bool, optional
List of the move flag of each xyz direction of each atom
move : list of (list of list of bool), optional
List of the move flag of each xyz direction of each atom for each frame
velocity : list of list of float, optional
List of the velocity of each xyz direction of each atom
mag : list of (list of float or float), optional
Expand Down
15 changes: 9 additions & 6 deletions dpdata/vasp/poscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def move_flag_mapper(flag):
elif flag == "F":
return False
else:
raise RuntimeError(f"Invalid selective dynamics flag: {flag}")
raise RuntimeError(f"Invalid move flag: {flag}")

"""Treat as cartesian poscar."""
system = {}
Expand All @@ -35,8 +35,13 @@ def move_flag_mapper(flag):
else:
tmpv = np.matmul(np.array(tmpv), system["cells"][0])
coord.append(tmpv)
if selective_dynamics and len(tmp) == 6:
move_flags.append(list(map(move_flag_mapper, tmp[3:])))
if selective_dynamics:
if len(tmp) == 6:
move_flags.append(list(map(move_flag_mapper, tmp[3:])))
else:
raise RuntimeError(
f"Invalid move flags, should be 6 columns, got {tmp}"
)

system["coords"] = [np.array(coord)]
system["orig"] = np.zeros(3)
Expand Down Expand Up @@ -118,11 +123,9 @@ def from_system_data(system, f_idx=0, skip_zeros=True):
move_flags = move[idx]
if isinstance(move_flags, list) and len(move_flags) == 3:
line += " " + " ".join(["T" if flag else "F" for flag in move_flags])
elif isinstance(move_flags, (int, float, bool)):
line += " " + " ".join(["T" if move_flags else "F"] * 3)
else:
raise RuntimeError(
f"Invalid move flags: {move_flags}, should be a list or a bool"
f"Invalid move flags: {move_flags}, should be a list of 3 bools"
)

posi_list.append(line)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_vasp_poscar_to_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ def test_move_flags(self):
self.assertTrue(np.array_equal(self.system["move"], expected))


class TestPOSCARCart(unittest.TestCase):
def test_move_flags_error1(self):
with self.assertRaisesRegex(RuntimeError, "Invalid move flags.*?"):
dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.err1"))

def test_move_flags_error2(self):
with self.assertRaisesRegex(RuntimeError, "Invalid move flag: a"):
dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.err2"))

def test_move_flags_error3(self):
system = dpdata.System().from_vasp_poscar(os.path.join("poscars", "POSCAR.oh.c"))
system.data["move"] = np.array([[[True, True], [False, False]]])
with self.assertRaisesRegex(RuntimeError, "Invalid move flags:.*?should be a list of 3 bools"):
system.to_vasp_poscar("POSCAR.tmp.1")

class TestPOSCARDirect(unittest.TestCase, TestPOSCARoh):
def setUp(self):
self.system = dpdata.System()
Expand Down

0 comments on commit 0222d6d

Please sign in to comment.