From 1a322ce5f22f442de47d1f6847d6a0cd6adcf5d1 Mon Sep 17 00:00:00 2001 From: Thang Nguyen <46436648+thangckt@users.noreply.github.com> Date: Fri, 5 Apr 2024 12:41:14 +0900 Subject: [PATCH] Update ase.py --- dpdata/plugins/ase.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 6f784f7b..ece79971 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional, Type, Generator, List import numpy as np @@ -112,7 +112,7 @@ def from_multi_systems( step: Optional[int] = None, ase_fmt: Optional[str] = None, **kwargs, - ) -> object: # generator of "ase.Atoms" + ) -> Generator["ase.Atoms"]: """Convert a ASE supported file to ASE Atoms. It will finally be converted to MultiSystems. @@ -140,7 +140,7 @@ def from_multi_systems( frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step)) yield from frames - def to_system(self, data, **kwargs): + def to_system(self, data, **kwargs) -> List["ase.Atoms"]: """Convert System to ASE Atom obj.""" from ase import Atoms @@ -158,7 +158,7 @@ def to_system(self, data, **kwargs): return structures - def to_labeled_system(self, data, *args, **kwargs): + def to_labeled_system(self, data, *args, **kwargs) -> List["ase.Atoms"]: """Convert System to ASE Atoms object.""" from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator @@ -296,19 +296,34 @@ def from_labeled_system( return dict_frames - def to_system(self, data, **kwargs): - """Convert System to ASE Atoms object.""" + def to_system(self, + data, + file_name: str = "confs.traj", + **kwargs) -> None: + """Convert System to ASE Atoms object. + + Parameters + ---------- + file_name : str + path to file + """ list_atoms = ASEStructureFormat().to_system(data, **kwargs) - file_name = kwargs.get("file_name", "conf.traj") traj = Trajectory(file_name, 'a') _ = [traj.write(atom) for atom in list_atoms] traj.close() return - def to_labeled_system(self, data, *args, **kwargs): - """Convert System to ASE Atoms object.""" + def to_labeled_system(self, + data, + file_name: str = "labeled_confs.traj", + *args, **kwargs) -> None: + """Convert System to ASE Atoms object. + Parameters + ---------- + file_name : str + path to file + """ list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs) - file_name = kwargs.get("file_name", "labeled_conf.traj") traj = Trajectory(file_name, 'a') _ = [traj.write(atom) for atom in list_atoms] traj.close()