Skip to content

Commit

Permalink
pref: lazy import modules (#658)
Browse files Browse the repository at this point in the history
Fix #526.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **Refactor**
- Removed unnecessary import statements and restructured import handling
for improved code organization and readability.
- Reorganized imports within functions to localize dependencies and
enhance code modularity.

- **New Features**
- Introduced conditional imports based on `TYPE_CHECKING` for better
resource management and efficiency.
- Added a new method `from_dict` to the `System` class for constructing
instances from a data dictionary.

- **Chores**
- Updated linting rules in `pyproject.toml` to include `TID253` for
banned module-level imports.
- Modified import statements in test files to comply with the new
linting rules for better code quality.

- **Style**
- Added `# noqa: TID253` comments to specific import statements to
adhere to new linting rules and ensure clean code styling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored May 16, 2024
1 parent 02309f7 commit a7bf93d
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 107 deletions.
26 changes: 1 addition & 25 deletions dpdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,12 @@
# monty needs lzma
# See https://github.com/pandas-dev/pandas/pull/27882
try:
import lzma # noqa: F401
except ImportError:

class fakemodule:
pass

import sys

sys.modules["lzma"] = fakemodule

from . import lammps, md, vasp
from .bond_order_system import BondOrderSystem
from .system import LabeledSystem, MultiSystems, System

try:
from ._version import version as __version__
except ImportError:
from .__about__ import __version__

# BondOrder System has dependency on rdkit
try:
# prevent conflict with dpdata.rdkit
import rdkit as _ # noqa: F401

USE_RDKIT = True
except ModuleNotFoundError:
USE_RDKIT = False

if USE_RDKIT:
from .bond_order_system import BondOrderSystem

__all__ = [
"__version__",
"lammps",
Expand Down
3 changes: 2 additions & 1 deletion dpdata/amber/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re

import numpy as np
from scipy.io import netcdf_file

from dpdata.amber.mask import pick_by_amber_mask
from dpdata.unit import EnergyConversion
Expand Down Expand Up @@ -44,6 +43,8 @@ def read_amber_traj(
labeled : bool
Whether to return labeled data
"""
from scipy.io import netcdf_file

flag_atom_type = False
flag_atom_numb = False
amber_types = []
Expand Down
2 changes: 1 addition & 1 deletion dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional

from ase.calculators.calculator import (
from ase.calculators.calculator import ( # noqa: TID253
Calculator,
PropertyNotImplementedError,
all_changes,
Expand Down
3 changes: 2 additions & 1 deletion dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from copy import deepcopy

import numpy as np
from rdkit.Chem import Conformer

import dpdata.rdkit.utils
from dpdata.rdkit.sanitize import Sanitizer
Expand Down Expand Up @@ -102,6 +101,8 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs):
return self

def to_fmt_obj(self, fmtobj, *args, **kwargs):
from rdkit.Chem import Conformer

self.rdkit_mol.RemoveAllConformers()
for ii in range(self.get_nframes()):
conf = Conformer()
Expand Down
11 changes: 6 additions & 5 deletions dpdata/deepmd/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

try:
import h5py
except ImportError:
pass
import numpy as np
from wcmatch.glob import globfilter

import dpdata

if TYPE_CHECKING:
import h5py

__all__ = ["to_system_data", "dump"]


Expand All @@ -35,6 +34,8 @@ def to_system_data(
labels : bool
labels
"""
from wcmatch.glob import globfilter

g = f[folder] if folder else f

data = {}
Expand Down
20 changes: 7 additions & 13 deletions dpdata/gaussian/gjf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

try:
from openbabel import openbabel
except ImportError:
try:
import openbabel
except ImportError:
openbabel = None
from dpdata.periodic_table import Element


Expand Down Expand Up @@ -53,10 +44,13 @@ def _crd2frag(symbols: List[str], crds: np.ndarray) -> Tuple[int, List[int]]:
ImportError
if Open Babel is not installed
"""
if openbabel is None:
raise ImportError(
"Open Babel (Python interface) should be installed to detect fragmentation!"
)
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

try:
from openbabel import openbabel
except ImportError:
import openbabel
atomnumber = len(symbols)
# Use openbabel to connect atoms
mol = openbabel.OBMol()
Expand Down
8 changes: 4 additions & 4 deletions dpdata/periodic_table.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from pathlib import Path

from monty.serialization import loadfn

fpdt = str(Path(__file__).absolute().parent / "periodic_table.json")
_pdt = loadfn(fpdt)
fpdt = Path(__file__).absolute().parent / "periodic_table.json"
with fpdt.open("r") as fpdt:
_pdt = json.load(fpdt)
ELEMENTS = [
"H",
"He",
Expand Down
20 changes: 11 additions & 9 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@
from dpdata.driver import Driver, Minimizer
from dpdata.format import Format

try:
import ase.io
from ase.calculators.calculator import PropertyNotImplementedError
from ase.io import Trajectory

if TYPE_CHECKING:
from ase.optimize.optimize import Optimizer
except ImportError:
pass
if TYPE_CHECKING:
import ase
from ase.optimize.optimize import Optimizer


@Format.register("ase/structure")
Expand Down Expand Up @@ -84,6 +78,8 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
ASE will raise RuntimeError if the atoms does not
have a calculator
"""
from ase.calculators.calculator import PropertyNotImplementedError

info_dict = self.from_system(atoms)
try:
energies = atoms.get_potential_energy(force_consistent=True)
Expand Down Expand Up @@ -137,6 +133,8 @@ def from_multi_systems(
ase.Atoms
ASE atoms in the file
"""
import ase.io

frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
yield from frames

Expand Down Expand Up @@ -222,6 +220,8 @@ def from_system(
dict_frames: dict
a dictionary containing data of multiple frames
"""
from ase.io import Trajectory

traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
Expand Down Expand Up @@ -264,6 +264,8 @@ def from_labeled_system(
dict_frames: dict
a dictionary containing data of multiple frames
"""
from ase.io import Trajectory

traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]

Expand Down
16 changes: 12 additions & 4 deletions dpdata/plugins/deepmd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

try:
import h5py
except ImportError:
pass
import numpy as np

import dpdata
Expand All @@ -16,6 +13,9 @@
from dpdata.driver import Driver
from dpdata.format import Format

if TYPE_CHECKING:
import h5py


@Format.register("deepmd")
@Format.register("deepmd/raw")
Expand Down Expand Up @@ -202,6 +202,8 @@ def _from_system(
TypeError
file_name is not str or h5py.Group or h5py.File
"""
import h5py

if isinstance(file_name, (h5py.Group, h5py.File)):
return dpdata.deepmd.hdf5.to_system_data(
file_name, "", type_map=type_map, labels=labels
Expand Down Expand Up @@ -300,6 +302,8 @@ def to_system(
**kwargs : dict
other parameters
"""
import h5py

if isinstance(file_name, (h5py.Group, h5py.File)):
dpdata.deepmd.hdf5.dump(
file_name, "", data, set_size=set_size, comp_prec=comp_prec
Expand Down Expand Up @@ -330,6 +334,8 @@ def from_multi_systems(self, directory: str, **kwargs) -> h5py.Group:
h5py.Group
a HDF5 group in the HDF5 file
"""
import h5py

with h5py.File(directory, "r") as f:
for ff in f.keys():
yield f[ff]
Expand All @@ -353,6 +359,8 @@ def to_multi_systems(
h5py.Group
a HDF5 group with the name of formula
"""
import h5py

with h5py.File(directory, "w") as f:
for ff in formulas:
yield f.create_group(ff)
Expand Down
16 changes: 9 additions & 7 deletions dpdata/plugins/rdkit.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import dpdata.rdkit.utils
from dpdata.format import Format

try:
import rdkit.Chem

import dpdata.rdkit.utils
except ModuleNotFoundError:
pass


@Format.register("mol")
@Format.register("mol_file")
class MolFormat(Format):
def from_bond_order_system(self, file_name, **kwargs):
import rdkit.Chem

return rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False)

def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs):
import rdkit.Chem

assert frame_idx < mol.GetNumConformers()
rdkit.Chem.MolToMolFile(mol, file_name, confId=frame_idx)

Expand All @@ -24,6 +22,8 @@ def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs):
class SdfFormat(Format):
def from_bond_order_system(self, file_name, **kwargs):
"""Note that it requires all molecules in .sdf file must be of the same topology."""
import rdkit.Chem

mols = [
m
for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False)
Expand All @@ -35,6 +35,8 @@ def from_bond_order_system(self, file_name, **kwargs):
return mol

def to_bond_order_system(self, data, mol, file_name, frame_idx=-1, **kwargs):
import rdkit.Chem

sdf_writer = rdkit.Chem.SDWriter(file_name)
if frame_idx == -1:
for ii in range(mol.GetNumConformers()):
Expand Down
10 changes: 4 additions & 6 deletions dpdata/pymatgen/molecule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import numpy as np

try:
from pymatgen.core import Molecule
except ImportError:
pass
from collections import Counter

import numpy as np


def to_system_data(file_name, protect_layer=9):
from pymatgen.core import Molecule

mol = Molecule.from_file(file_name)
elem_mol = list(str(site.species.elements[0]) for site in mol.sites)
elem_counter = Counter(elem_mol)
Expand Down
5 changes: 0 additions & 5 deletions dpdata/pymatgen/structure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import numpy as np

try:
from pymatgen.core import Structure # noqa: F401
except ImportError:
pass


def from_system_data(structure) -> dict:
symbols = [site.species_string for site in structure]
Expand Down
Loading

0 comments on commit a7bf93d

Please sign in to comment.