Skip to content

Commit

Permalink
run pre-commit
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed May 17, 2024
1 parent 04e584b commit 1d1b95f
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
on:
- push
- pull_request

name: Type checker
jobs:
pyright:
Expand Down
2 changes: 2 additions & 0 deletions docs/nb/try_dpdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import dpdata"
]
},
Expand Down
1 change: 1 addition & 0 deletions dpdata/amber/mask.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Amber mask."""

from __future__ import annotations

try:
Expand Down
2 changes: 1 addition & 1 deletion dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs):
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
self.from_rdkit_mol(mol)
if hasattr(fmtobj.from_bond_order_system, "post_func"):
for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore
for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore
self.post_funcs.get_plugin(post_f)(self)
return self

Expand Down
1 change: 1 addition & 0 deletions dpdata/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Command line interface for dpdata."""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions dpdata/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Driver plugin system."""

from __future__ import annotations

from abc import ABC, abstractmethod
Expand Down
1 change: 1 addition & 0 deletions dpdata/format.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implement the format plugin system."""

from __future__ import annotations

import os
Expand Down
1 change: 1 addition & 0 deletions dpdata/gaussian/gjf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# https://github.com/deepmodeling/dpgen/blob/0767dce7cad29367edb2e4a55fd0d8724dbda642/dpgen/generator/lib/gaussian.py#L1-L190
# under LGPL 3.0 license
"""Generate Gaussian input file."""

from __future__ import annotations

import itertools
Expand Down
1 change: 1 addition & 0 deletions dpdata/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base of plugin systems."""

from __future__ import annotations


Expand Down
116 changes: 74 additions & 42 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ class System:
def __init__(
self,
# some formats do not use string as input
file_name: Any=None,
fmt: str="auto",
type_map:list[str] | None=None,
begin:int=0,
step:int=1,
data: dict[str, Any] | None=None,
convergence_check: bool=True,
file_name: Any = None,
fmt: str = "auto",
type_map: list[str] | None = None,
begin: int = 0,
step: int = 1,
data: dict[str, Any] | None = None,
convergence_check: bool = True,
**kwargs,
):
"""Constructor.
Expand Down Expand Up @@ -231,7 +231,7 @@ def check_data(self):

post_funcs = Plugin()

def from_fmt(self, file_name: Any, fmt: str="auto", **kwargs: Any):
def from_fmt(self, file_name: Any, fmt: str = "auto", **kwargs: Any):
fmt = fmt.lower()
if fmt == "auto":
fmt = os.path.basename(file_name).split(".")[-1].lower()
Expand All @@ -247,7 +247,7 @@ def from_fmt_obj(self, fmtobj: Format, file_name: Any, **kwargs: Any):
self.data = {**self.data, **data}
self.check_data()
if hasattr(fmtobj.from_system, "post_func"):
for post_f in fmtobj.from_system.post_func: # type: ignore
for post_f in fmtobj.from_system.post_func: # type: ignore
self.post_funcs.get_plugin(post_f)(self)
return self

Expand Down Expand Up @@ -289,20 +289,19 @@ def __str__(self):
return ret

@overload
def __getitem__(self, key: int | slice | list | np.ndarray) -> System:
...
def __getitem__(self, key: int | slice | list | np.ndarray) -> System: ...
@overload
def __getitem__(self, key: Literal["atom_names", "real_atom_names"]) -> list[str]:
...
def __getitem__(
self, key: Literal["atom_names", "real_atom_names"]
) -> list[str]: ...
@overload
def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]:
...
def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]: ...
@overload
def __getitem__(self, key: Literal["nopbc"]) -> bool:
...
def __getitem__(self, key: Literal["nopbc"]) -> bool: ...
@overload
def __getitem__(self, key: Literal["orig", "coords", "energies", "forces", "virials"]) -> np.ndarray:
...
def __getitem__(
self, key: Literal["orig", "coords", "energies", "forces", "virials"]
) -> np.ndarray: ...
@overload
def __getitem__(self, key: str) -> Any:
# other cases, for example customized data
Expand Down Expand Up @@ -333,13 +332,15 @@ def __add__(self, others):
raise RuntimeError("Unspported data structure")
return self.__class__.from_dict({"data": self_copy.data})

def dump(self, filename: str, indent: int=4):
def dump(self, filename: str, indent: int = 4):
"""Dump .json or .yaml file."""
from monty.serialization import dumpfn

dumpfn(self.as_dict(), filename, indent=indent)

def map_atom_types(self, type_map: dict[str, int] | list[str] | None=None) -> np.ndarray:
def map_atom_types(
self, type_map: dict[str, int] | list[str] | None = None
) -> np.ndarray:
"""Map the atom types of the system.
Parameters
Expand Down Expand Up @@ -456,7 +457,9 @@ def sub_system(self, f_idx: numbers.Integral) -> System:
continue
if tt.shape is not None and Axis.NFRAMES in tt.shape:
axis_nframes = tt.shape.index(Axis.NFRAMES)
new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape]
new_shape: list[slice | np.ndarray] = [
slice(None) for _ in self.data[tt.name].shape
]
new_shape[axis_nframes] = f_idx

Check failure on line 463 in dpdata/system.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__setitem__" match the provided arguments (reportCallIssue)
tmp.data[tt.name] = self.data[tt.name][tuple(new_shape)]
else:
Expand Down Expand Up @@ -520,7 +523,7 @@ def append(self, system: System) -> bool:
self.data["nopbc"] = False
return True

def convert_to_mixed_type(self, type_map:list[str] | None=None):
def convert_to_mixed_type(self, type_map: list[str] | None = None):
"""Convert the data dict to mixed type format structure, in order to append systems
with different formula but the same number of atoms. Change the 'atom_names' to
one placeholder type 'MIXED_TOKEN' and add 'real_atom_types' to store the real type
Expand All @@ -546,7 +549,7 @@ def convert_to_mixed_type(self, type_map:list[str] | None=None):
self.data["atom_numbs"] = [natoms]
self.data["atom_names"] = ["MIXED_TOKEN"]

def sort_atom_names(self, type_map:list[str] | None=None):
def sort_atom_names(self, type_map: list[str] | None = None):
"""Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
to atom_names. If type_map is not given, atom_names will be sorted by
alphabetical order. If type_map is given, atom_names will be type_map.
Expand All @@ -558,7 +561,7 @@ def sort_atom_names(self, type_map:list[str] | None=None):
"""
self.data = sort_atom_names(self.data, type_map=type_map)

def check_type_map(self, type_map:list[str] | None):
def check_type_map(self, type_map: list[str] | None):
"""Assign atom_names to type_map if type_map is given and different from
atom_names.
Expand Down Expand Up @@ -600,7 +603,9 @@ def sort_atom_types(self) -> np.ndarray:
continue
if tt.shape is not None and Axis.NATOMS in tt.shape:
axis_natoms = tt.shape.index(Axis.NATOMS)
new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape]
new_shape: list[slice | np.ndarray] = [
slice(None) for _ in self.data[tt.name].shape
]
new_shape[axis_natoms] = idx
self.data[tt.name] = self.data[tt.name][tuple(new_shape)]
return idx
Expand Down Expand Up @@ -686,7 +691,7 @@ def apply_pbc(self):
self.data["coords"] = np.matmul(ncoord, self.data["cells"])

@post_funcs.register("remove_pbc")
def remove_pbc(self, protect_layer: int=9):
def remove_pbc(self, protect_layer: int = 9):
"""This method does NOT delete the definition of the cells, it
(1) revises the cell to a cubic cell and ensures that the cell
boundary to any atom in the system is no less than `protect_layer`
Expand All @@ -701,7 +706,7 @@ def remove_pbc(self, protect_layer: int=9):
assert protect_layer >= 0, "the protect_layer should be no less than 0"
remove_pbc(self.data, protect_layer)

def affine_map(self, trans, f_idx: numbers.Integral=0):
def affine_map(self, trans, f_idx: numbers.Integral = 0):
assert np.linalg.det(trans) != 0
self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans)
self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans)
Expand All @@ -719,7 +724,7 @@ def rot_lower_triangular(self):
for ii in range(self.get_nframes()):
self.rot_frame_lower_triangular(ii)

def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0):
def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0):
qq, rr = np.linalg.qr(self.data["cells"][f_idx].T)
if np.linalg.det(qq) < 0:
qq = -qq
Expand Down Expand Up @@ -837,7 +842,11 @@ def replace(self, initial_atom_type: str, end_atom_type: str, replace_num: int):
self.sort_atom_types()

def perturb(
self, pert_num: int, cell_pert_fraction: float, atom_pert_distance: float, atom_pert_style: str="normal"
self,
pert_num: int,
cell_pert_fraction: float,
atom_pert_distance: float,
atom_pert_style: str = "normal",
):
"""Perturb each frame in the system randomly.
The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction.
Expand Down Expand Up @@ -914,7 +923,9 @@ def shuffle(self):
self.data = self.sub_system(idx).data
return idx

def predict(self, *args: Any, driver: str | Driver = "dp", **kwargs: Any) -> LabeledSystem:
def predict(
self, *args: Any, driver: str | Driver = "dp", **kwargs: Any
) -> LabeledSystem:
"""Predict energies and forces by a driver.
Parameters
Expand Down Expand Up @@ -966,7 +977,7 @@ def minimize(
data = minimizer.minimize(self.data.copy())
return LabeledSystem(data=data)

def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None):
def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
"""Pick atom index.
Parameters
Expand All @@ -990,7 +1001,9 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None):
continue
if tt.shape is not None and Axis.NATOMS in tt.shape:
axis_natoms = tt.shape.index(Axis.NATOMS)
new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape]
new_shape: list[slice | np.ndarray] = [
slice(None) for _ in self.data[tt.name].shape
]
new_shape[axis_natoms] = idx
new_sys.data[tt.name] = self.data[tt.name][tuple(new_shape)]
# recalculate atom_numbs according to atom_types
Expand Down Expand Up @@ -1028,7 +1041,13 @@ def remove_atom_names(self, atom_names: str | Iterable[str]):
new_sys.data["atom_numbs"] = new_sys.data["atom_numbs"][: len(new_atom_names)]
return new_sys

def pick_by_amber_mask(self, param: str | parmed.Structure, maskstr: str, pass_coords: bool=False, nopbc: bool | None=None):
def pick_by_amber_mask(
self,
param: str | parmed.Structure,
maskstr: str,
pass_coords: bool = False,
nopbc: bool | None = None,
):
"""Pick atoms by amber mask.
Parameters
Expand Down Expand Up @@ -1093,7 +1112,10 @@ def get_cell_perturb_matrix(cell_pert_fraction: float):
return cell_pert_matrix


def get_atom_perturb_vector(atom_pert_distance: float, atom_pert_style: Literal["normal", "uniform", "const"]="normal"):
def get_atom_perturb_vector(
atom_pert_distance: float,
atom_pert_style: Literal["normal", "uniform", "const"] = "normal",
):
random_vector = None
if atom_pert_distance < 0:
raise RuntimeError("atom_pert_distance can not be negative")
Expand Down Expand Up @@ -1230,7 +1252,7 @@ def affine_map_fv(self, trans, f_idx: numbers.Integral):
trans.T, np.matmul(self.data["virials"][f_idx], trans)
)

def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0):
def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0):
trans = System.rot_frame_lower_triangular(self, f_idx=f_idx)
self.affine_map_fv(trans, f_idx=f_idx)
return trans
Expand Down Expand Up @@ -1322,7 +1344,9 @@ def __init__(self, *systems, type_map=None):
self.atom_names: list[str] = []
self.append(*systems)

def from_fmt_obj(self, fmtobj: Format, directory, labeled:bool=True, **kwargs: Any):
def from_fmt_obj(
self, fmtobj: Format, directory, labeled: bool = True, **kwargs: Any
):
if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):
for dd in fmtobj.from_multi_systems(directory, **kwargs):
if labeled:
Expand Down Expand Up @@ -1415,7 +1439,13 @@ def from_file(cls, file_name, fmt: str, **kwargs: Any):
return multi_systems

@classmethod
def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: list[str] | None=None):
def from_dir(
cls,
dir_name: str,
file_name: str,
fmt: str = "auto",
type_map: list[str] | None = None,
):
multi_systems = cls()
target_file_list = sorted(
glob.glob(f"./{dir_name}/**/{file_name}", recursive=True)
Expand All @@ -1426,7 +1456,7 @@ def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: list
)
return multi_systems

def load_systems_from_file(self, file_name=None, fmt: str | None=None, **kwargs):
def load_systems_from_file(self, file_name=None, fmt: str | None = None, **kwargs):
assert fmt is not None
fmt = fmt.lower()
return self.from_fmt_obj(load_format(fmt), file_name, **kwargs)
Expand Down Expand Up @@ -1485,7 +1515,9 @@ def check_atom_names(self, system: System):
system.add_atom_names(new_in_self)
system.sort_atom_names(type_map=self.atom_names)

def predict(self, *args: Any, driver: str | Driver="dp", **kwargs: Any) -> MultiSystems:
def predict(
self, *args: Any, driver: str | Driver = "dp", **kwargs: Any
) -> MultiSystems:
"""Predict energies and forces by a driver.
Parameters
Expand Down Expand Up @@ -1544,7 +1576,7 @@ def minimize(
new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs))
return new_multisystems

def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None):
def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
"""Pick atom index.
Parameters
Expand Down Expand Up @@ -1599,7 +1631,7 @@ def correction(self, hl_sys: MultiSystems) -> MultiSystems:
for nn in self.systems.keys():
ll_ss = self[nn]
hl_ss = hl_sys[nn]
assert isinstance(ll_ss, LabeledSystem)
assert isinstance(ll_ss, LabeledSystem)
corrected_sys.append(ll_ss.correction(hl_ss))
return corrected_sys

Expand Down
Loading

0 comments on commit 1d1b95f

Please sign in to comment.