Skip to content

Commit

Permalink
Merge pull request #5 from thangckt/devel
Browse files Browse the repository at this point in the history
Update ase.py
  • Loading branch information
thangckt authored Apr 5, 2024
2 parents 6e2df4d + 188f945 commit 5d27076
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, Generator, List

import numpy as np

Expand Down Expand Up @@ -114,7 +114,7 @@ def from_multi_systems(
step: Optional[int] = None,
ase_fmt: Optional[str] = None,
**kwargs,
) -> object: # generator of "ase.Atoms"
) -> Generator["ase.Atoms"]:
"""Convert a ASE supported file to ASE Atoms.
It will finally be converted to MultiSystems.
Expand Down Expand Up @@ -142,7 +142,7 @@ def from_multi_systems(
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
yield from frames

def to_system(self, data, **kwargs):
def to_system(self, data, **kwargs) -> List["ase.Atoms"]:
"""Convert System to ASE Atom obj."""
from ase import Atoms

Expand All @@ -160,7 +160,7 @@ def to_system(self, data, **kwargs):

return structures

def to_labeled_system(self, data, *args, **kwargs):
def to_labeled_system(self, data, *args, **kwargs) -> List["ase.Atoms"]:
"""Convert System to ASE Atoms object."""
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
Expand Down Expand Up @@ -298,20 +298,35 @@ def from_labeled_system(

return dict_frames

def to_system(self, data, **kwargs):
"""Convert System to ASE Atoms object."""
def to_system(self,
data,
file_name: str = "confs.traj",
**kwargs) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
list_atoms = ASEStructureFormat().to_system(data, **kwargs)
file_name = kwargs.get("file_name", "conf.traj")
traj = Trajectory(file_name, "a")
traj = Trajectory(file_name, 'a')
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return

def to_labeled_system(self, data, *args, **kwargs):
"""Convert System to ASE Atoms object."""
def to_labeled_system(self,
data,
file_name: str = "labeled_confs.traj",
*args, **kwargs) -> None:
"""Convert System to ASE Atoms object.
Parameters
----------
file_name : str
path to file
"""
list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs)
file_name = kwargs.get("file_name", "labeled_conf.traj")
traj = Trajectory(file_name, "a")
traj = Trajectory(file_name, 'a')
_ = [traj.write(atom) for atom in list_atoms]
traj.close()
return
Expand Down

0 comments on commit 5d27076

Please sign in to comment.