From a7bf93d79c6acdf7ac2a8b6f7a76bb09d6dd3a61 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 16 May 2024 02:34:48 -0400
Subject: [PATCH] perf: lazy import modules (#658)

Fix #526.

## Summary by CodeRabbit

- **Refactor**
  - Removed unnecessary import statements and restructured import handling for improved code organization and readability.
  - Reorganized imports within functions to localize dependencies and enhance code modularity.

- **New Features**
  - Introduced conditional imports based on `TYPE_CHECKING` for better resource management and efficiency.
  - Added a new method `from_dict` to the `System` class for constructing instances from a data dictionary.

- **Chores**
  - Updated linting rules in `pyproject.toml` to include `TID253` for banned module-level imports.
  - Modified import statements in test files to comply with the new linting rules for better code quality.

- **Style**
  - Added `# noqa: TID253` comments to specific import statements to adhere to new linting rules and ensure clean code styling.
--------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpdata/__init__.py | 26 +------------------ dpdata/amber/md.py | 3 ++- dpdata/ase_calculator.py | 2 +- dpdata/bond_order_system.py | 3 ++- dpdata/deepmd/hdf5.py | 11 ++++---- dpdata/gaussian/gjf.py | 20 +++++--------- dpdata/periodic_table.py | 8 +++--- dpdata/plugins/ase.py | 20 +++++++------- dpdata/plugins/deepmd.py | 16 +++++++++--- dpdata/plugins/rdkit.py | 16 +++++++----- dpdata/pymatgen/molecule.py | 10 +++---- dpdata/pymatgen/structure.py | 5 ---- dpdata/rdkit/sanitize.py | 46 +++++++++++++++++++++------------ dpdata/rdkit/utils.py | 8 +++--- dpdata/system.py | 20 +++++++++++--- dpdata/unit.py | 2 +- pyproject.toml | 15 +++++++++++ tests/test_custom_data_type.py | 2 +- tests/test_to_pymatgen_entry.py | 2 +- 19 files changed, 128 insertions(+), 107 deletions(-) diff --git a/dpdata/__init__.py b/dpdata/__init__.py index dd853697..847554d3 100644 --- a/dpdata/__init__.py +++ b/dpdata/__init__.py @@ -1,17 +1,5 @@ -# monty needs lzma -# See https://github.com/pandas-dev/pandas/pull/27882 -try: - import lzma # noqa: F401 -except ImportError: - - class fakemodule: - pass - - import sys - - sys.modules["lzma"] = fakemodule - from . 
import lammps, md, vasp +from .bond_order_system import BondOrderSystem from .system import LabeledSystem, MultiSystems, System try: @@ -19,18 +7,6 @@ class fakemodule: except ImportError: from .__about__ import __version__ -# BondOrder System has dependency on rdkit -try: - # prevent conflict with dpdata.rdkit - import rdkit as _ # noqa: F401 - - USE_RDKIT = True -except ModuleNotFoundError: - USE_RDKIT = False - -if USE_RDKIT: - from .bond_order_system import BondOrderSystem - __all__ = [ "__version__", "lammps", diff --git a/dpdata/amber/md.py b/dpdata/amber/md.py index 279ce55e..91240121 100644 --- a/dpdata/amber/md.py +++ b/dpdata/amber/md.py @@ -2,7 +2,6 @@ import re import numpy as np -from scipy.io import netcdf_file from dpdata.amber.mask import pick_by_amber_mask from dpdata.unit import EnergyConversion @@ -44,6 +43,8 @@ def read_amber_traj( labeled : bool Whether to return labeled data """ + from scipy.io import netcdf_file + flag_atom_type = False flag_atom_numb = False amber_types = [] diff --git a/dpdata/ase_calculator.py b/dpdata/ase_calculator.py index 65a462a5..c0579978 100644 --- a/dpdata/ase_calculator.py +++ b/dpdata/ase_calculator.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Optional -from ase.calculators.calculator import ( +from ase.calculators.calculator import ( # noqa: TID253 Calculator, PropertyNotImplementedError, all_changes, diff --git a/dpdata/bond_order_system.py b/dpdata/bond_order_system.py index e2449ee6..1b6f903d 100644 --- a/dpdata/bond_order_system.py +++ b/dpdata/bond_order_system.py @@ -3,7 +3,6 @@ from copy import deepcopy import numpy as np -from rdkit.Chem import Conformer import dpdata.rdkit.utils from dpdata.rdkit.sanitize import Sanitizer @@ -102,6 +101,8 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): return self def to_fmt_obj(self, fmtobj, *args, **kwargs): + from rdkit.Chem import Conformer + self.rdkit_mol.RemoveAllConformers() for ii in range(self.get_nframes()): conf = Conformer() diff --git 
a/dpdata/deepmd/hdf5.py b/dpdata/deepmd/hdf5.py index 992a1344..34ae9dbe 100644 --- a/dpdata/deepmd/hdf5.py +++ b/dpdata/deepmd/hdf5.py @@ -3,16 +3,15 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING -try: - import h5py -except ImportError: - pass import numpy as np -from wcmatch.glob import globfilter import dpdata +if TYPE_CHECKING: + import h5py + __all__ = ["to_system_data", "dump"] @@ -35,6 +34,8 @@ def to_system_data( labels : bool labels """ + from wcmatch.glob import globfilter + g = f[folder] if folder else f data = {} diff --git a/dpdata/gaussian/gjf.py b/dpdata/gaussian/gjf.py index 21300a60..90aaf2f0 100644 --- a/dpdata/gaussian/gjf.py +++ b/dpdata/gaussian/gjf.py @@ -10,16 +10,7 @@ from typing import List, Optional, Tuple, Union import numpy as np -from scipy.sparse import csr_matrix -from scipy.sparse.csgraph import connected_components -try: - from openbabel import openbabel -except ImportError: - try: - import openbabel - except ImportError: - openbabel = None from dpdata.periodic_table import Element @@ -53,10 +44,13 @@ def _crd2frag(symbols: List[str], crds: np.ndarray) -> Tuple[int, List[int]]: ImportError if Open Babel is not installed """ - if openbabel is None: - raise ImportError( - "Open Babel (Python interface) should be installed to detect fragmentation!" 
- ) + from scipy.sparse import csr_matrix + from scipy.sparse.csgraph import connected_components + + try: + from openbabel import openbabel + except ImportError: + import openbabel atomnumber = len(symbols) # Use openbabel to connect atoms mol = openbabel.OBMol() diff --git a/dpdata/periodic_table.py b/dpdata/periodic_table.py index b05a2cfb..6df1fd41 100644 --- a/dpdata/periodic_table.py +++ b/dpdata/periodic_table.py @@ -1,9 +1,9 @@ +import json from pathlib import Path -from monty.serialization import loadfn - -fpdt = str(Path(__file__).absolute().parent / "periodic_table.json") -_pdt = loadfn(fpdt) +fpdt = Path(__file__).absolute().parent / "periodic_table.json" +with fpdt.open("r") as fpdt: + _pdt = json.load(fpdt) ELEMENTS = [ "H", "He", diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 127611f6..f3347c99 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -6,15 +6,9 @@ from dpdata.driver import Driver, Minimizer from dpdata.format import Format -try: - import ase.io - from ase.calculators.calculator import PropertyNotImplementedError - from ase.io import Trajectory - - if TYPE_CHECKING: - from ase.optimize.optimize import Optimizer -except ImportError: - pass +if TYPE_CHECKING: + import ase + from ase.optimize.optimize import Optimizer @Format.register("ase/structure") @@ -84,6 +78,8 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: ASE will raise RuntimeError if the atoms does not have a calculator """ + from ase.calculators.calculator import PropertyNotImplementedError + info_dict = self.from_system(atoms) try: energies = atoms.get_potential_energy(force_consistent=True) @@ -137,6 +133,8 @@ def from_multi_systems( ase.Atoms ASE atoms in the file """ + import ase.io + frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step)) yield from frames @@ -222,6 +220,8 @@ def from_system( dict_frames: dict a dictionary containing data of multiple frames """ + from ase.io import Trajectory + traj 
= Trajectory(file_name) sub_traj = traj[begin:end:step] dict_frames = ASEStructureFormat().from_system(sub_traj[0]) @@ -264,6 +264,8 @@ def from_labeled_system( dict_frames: dict a dictionary containing data of multiple frames """ + from ase.io import Trajectory + traj = Trajectory(file_name) sub_traj = traj[begin:end:step] diff --git a/dpdata/plugins/deepmd.py b/dpdata/plugins/deepmd.py index 9daadcf3..1ed79f72 100644 --- a/dpdata/plugins/deepmd.py +++ b/dpdata/plugins/deepmd.py @@ -1,11 +1,8 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING -try: - import h5py -except ImportError: - pass import numpy as np import dpdata @@ -16,6 +13,9 @@ from dpdata.driver import Driver from dpdata.format import Format +if TYPE_CHECKING: + import h5py + @Format.register("deepmd") @Format.register("deepmd/raw") @@ -202,6 +202,8 @@ def _from_system( TypeError file_name is not str or h5py.Group or h5py.File """ + import h5py + if isinstance(file_name, (h5py.Group, h5py.File)): return dpdata.deepmd.hdf5.to_system_data( file_name, "", type_map=type_map, labels=labels @@ -300,6 +302,8 @@ def to_system( **kwargs : dict other parameters """ + import h5py + if isinstance(file_name, (h5py.Group, h5py.File)): dpdata.deepmd.hdf5.dump( file_name, "", data, set_size=set_size, comp_prec=comp_prec @@ -330,6 +334,8 @@ def from_multi_systems(self, directory: str, **kwargs) -> h5py.Group: h5py.Group a HDF5 group in the HDF5 file """ + import h5py + with h5py.File(directory, "r") as f: for ff in f.keys(): yield f[ff] @@ -353,6 +359,8 @@ def to_multi_systems( h5py.Group a HDF5 group with the name of formula """ + import h5py + with h5py.File(directory, "w") as f: for ff in formulas: yield f.create_group(ff) diff --git a/dpdata/plugins/rdkit.py b/dpdata/plugins/rdkit.py index ff7638cb..c7cef07f 100644 --- a/dpdata/plugins/rdkit.py +++ b/dpdata/plugins/rdkit.py @@ -1,20 +1,18 @@ +import dpdata.rdkit.utils from dpdata.format import Format -try: - import rdkit.Chem - - 
import dpdata.rdkit.utils -except ModuleNotFoundError: - pass - @Format.register("mol") @Format.register("mol_file") class MolFormat(Format): def from_bond_order_system(self, file_name, **kwargs): + import rdkit.Chem + return rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False) def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs): + import rdkit.Chem + assert frame_idx < mol.GetNumConformers() rdkit.Chem.MolToMolFile(mol, file_name, confId=frame_idx) @@ -24,6 +22,8 @@ def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs): class SdfFormat(Format): def from_bond_order_system(self, file_name, **kwargs): """Note that it requires all molecules in .sdf file must be of the same topology.""" + import rdkit.Chem + mols = [ m for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False) @@ -35,6 +35,8 @@ def from_bond_order_system(self, file_name, **kwargs): return mol def to_bond_order_system(self, data, mol, file_name, frame_idx=-1, **kwargs): + import rdkit.Chem + sdf_writer = rdkit.Chem.SDWriter(file_name) if frame_idx == -1: for ii in range(mol.GetNumConformers()): diff --git a/dpdata/pymatgen/molecule.py b/dpdata/pymatgen/molecule.py index 13d4046c..fc05b07a 100644 --- a/dpdata/pymatgen/molecule.py +++ b/dpdata/pymatgen/molecule.py @@ -1,13 +1,11 @@ -import numpy as np - -try: - from pymatgen.core import Molecule -except ImportError: - pass from collections import Counter +import numpy as np + def to_system_data(file_name, protect_layer=9): + from pymatgen.core import Molecule + mol = Molecule.from_file(file_name) elem_mol = list(str(site.species.elements[0]) for site in mol.sites) elem_counter = Counter(elem_mol) diff --git a/dpdata/pymatgen/structure.py b/dpdata/pymatgen/structure.py index b6c148da..9f47baee 100644 --- a/dpdata/pymatgen/structure.py +++ b/dpdata/pymatgen/structure.py @@ -1,10 +1,5 @@ import numpy as np -try: - from pymatgen.core import Structure # noqa: F401 -except 
ImportError: - pass - def from_system_data(structure) -> dict: symbols = [site.species_string for site in structure] diff --git a/dpdata/rdkit/sanitize.py b/dpdata/rdkit/sanitize.py index 061de3d9..45060abc 100644 --- a/dpdata/rdkit/sanitize.py +++ b/dpdata/rdkit/sanitize.py @@ -2,17 +2,6 @@ import time from copy import deepcopy -from rdkit import Chem -from rdkit.Chem.rdchem import BondType - -# openbabel -try: - from openbabel import openbabel - - USE_OBABEL = True -except ModuleNotFoundError as e: - USE_OBABEL = False - def get_explicit_valence(atom, verbose=False): exp_val_calculated_from_bonds = int( @@ -32,6 +21,8 @@ def get_explicit_valence(atom, verbose=False): def regularize_formal_charges(mol, sanitize=True, verbose=False): """Regularize formal charges of atoms.""" + from rdkit import Chem + assert isinstance(mol, Chem.rdchem.Mol) for atom in mol.GetAtoms(): assign_formal_charge_for_atom(atom, verbose) @@ -47,6 +38,8 @@ def regularize_formal_charges(mol, sanitize=True, verbose=False): def assign_formal_charge_for_atom(atom, verbose=False): """Assigen formal charge according to 8-electron rule for element B,C,N,O,S,P,As.""" + from rdkit import Chem + assert isinstance(atom, Chem.rdchem.Atom) valence = get_explicit_valence(atom, verbose) if atom.GetSymbol() == "B": @@ -135,6 +128,8 @@ def get_terminal_NR2s(atom): def sanitize_phosphate_Patom(P_atom, verbose=True): + from rdkit import Chem + if P_atom.GetSymbol() == "P": terminal_oxygens = get_terminal_oxygens(P_atom) mol = P_atom.GetOwningMol() @@ -161,6 +156,8 @@ def sanitize_phosphate(mol): def sanitize_sulfate_Satom(S_atom, verbose=True): + from rdkit import Chem + if S_atom.GetSymbol() == "S": terminal_oxygens = get_terminal_oxygens(S_atom) mol = S_atom.GetOwningMol() @@ -187,6 +184,8 @@ def sanitize_sulfate(mol): def sanitize_carboxyl_Catom(C_atom, verbose=True): + from rdkit import Chem + if C_atom.GetSymbol() == "C": terminal_oxygens = get_terminal_oxygens(C_atom) mol = C_atom.GetOwningMol() @@ 
-214,6 +213,8 @@ def sanitize_carboxyl(mol): def sanitize_guanidine_Catom(C_atom, verbose=True): + from rdkit import Chem + if C_atom.GetSymbol() == "C": terminal_NR2s = get_terminal_NR2s(C_atom) mol = C_atom.GetOwningMol() @@ -241,6 +242,8 @@ def sanitize_guanidine(mol): def sanitize_nitro_Natom(N_atom, verbose=True): + from rdkit import Chem + if N_atom.GetSymbol() == "N": terminal_oxygens = get_terminal_oxygens(N_atom) mol = N_atom.GetOwningMol() @@ -275,6 +278,8 @@ def is_terminal_nitrogen(N_atom): def sanitize_nitrine_Natom(atom, verbose=True): + from rdkit import Chem + if atom.GetSymbol() == "N" and len(atom.GetNeighbors()) == 2: mol = atom.GetOwningMol() nei1, nei2 = atom.GetNeighbors()[0], atom.GetNeighbors()[1] @@ -312,6 +317,8 @@ def contain_hetero_aromatic(mol): # for carbon with explicit valence > 4 def regularize_carbon_bond_order(atom, verbose=True): + from rdkit import Chem + if atom.GetSymbol() == "C" and get_explicit_valence(atom) > 4: if verbose: print("Detecting carbon with explicit valence > 4, fixing it...") @@ -330,6 +337,8 @@ def regularize_carbon_bond_order(atom, verbose=True): # for nitrogen with explicit valence > 4 def regularize_nitrogen_bond_order(atom, verbose=True): + from rdkit import Chem + mol = atom.GetOwningMol() if atom.GetSymbol() == "N" and get_explicit_valence(atom) > 4: O_atoms = get_terminal_oxygens(atom) @@ -363,6 +372,9 @@ def mol_edit_log(mol, i, j): def kekulize_aromatic_heterocycles(mol_in, assign_formal_charge=True, sanitize=True): + from rdkit import Chem + from rdkit.Chem.rdchem import BondType + mol = Chem.RWMol(mol_in) rings = Chem.rdmolops.GetSymmSSSR(mol) rings = [list(i) for i in list(rings)] @@ -566,6 +578,9 @@ def hetero_priority(idx, mol): def convert_by_obabel( mol, cache_dir=os.path.join(os.getcwd(), ".cache"), obabel_path="obabel" ): + from openbabel import openbabel + from rdkit import Chem + if not os.path.exists(cache_dir): os.mkdir(cache_dir) if mol.HasProp("_Name"): @@ -585,6 +600,8 @@ def 
convert_by_obabel( def super_sanitize_mol(mol, name=None, verbose=True): + from rdkit import Chem + if name is None: if mol.HasProp("_Name"): name = mol.GetProp("_Name") @@ -655,11 +672,6 @@ def _check_level(self, level): raise ValueError( f"Invalid level '{level}', please set to 'low', 'medium' or 'high'" ) - else: - if level == "high" and not USE_OBABEL: - raise ModuleNotFoundError( - "obabel not installed, high level sanitizer cannot work" - ) def _handle_exception(self, error_info): if self.raise_errors: @@ -669,6 +681,8 @@ def _handle_exception(self, error_info): def sanitize(self, mol): """Sanitize mol according to `self.level`. If failed, return None.""" + from rdkit import Chem + if self.level == "low": try: Chem.SanitizeMol(mol) diff --git a/dpdata/rdkit/utils.py b/dpdata/rdkit/utils.py index 25cf97cd..9c7e50af 100644 --- a/dpdata/rdkit/utils.py +++ b/dpdata/rdkit/utils.py @@ -1,11 +1,9 @@ -try: - from rdkit import Chem -except ModuleNotFoundError: - pass import numpy as np def mol_to_system_data(mol): + from rdkit import Chem + if not isinstance(mol, Chem.rdchem.Mol): raise TypeError(f"rdkit.Chem.Mol required, not {type(mol)}") @@ -52,6 +50,8 @@ def mol_to_system_data(mol): def system_data_to_mol(data): + from rdkit import Chem + mol_ed = Chem.RWMol() atom_symbols = [data["atom_names"][i] for i in data["atom_types"]] # add atoms diff --git a/dpdata/system.py b/dpdata/system.py index a848066e..33b7e7cf 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -7,8 +7,6 @@ from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from monty.json import MSONable -from monty.serialization import dumpfn, loadfn import dpdata import dpdata.md.pbc @@ -41,7 +39,7 @@ def load_format(fmt): ) -class System(MSONable): +class System: """The data System. 
A data System (a concept used by `deepmd-kit `_) @@ -297,6 +295,8 @@ def __add__(self, others): def dump(self, filename, indent=4): """Dump .json or .yaml file.""" + from monty.serialization import dumpfn + dumpfn(self.as_dict(), filename, indent=indent) def map_atom_types(self, type_map=None) -> np.ndarray: @@ -340,8 +340,22 @@ def map_atom_types(self, type_map=None) -> np.ndarray: @staticmethod def load(filename): """Rebuild System obj. from .json or .yaml file.""" + from monty.serialization import loadfn + return loadfn(filename) + @classmethod + def from_dict(cls, data: dict): + """Construct a System instance from a data dict.""" + from monty.serialization import MontyDecoder + + decoded = { + k: MontyDecoder().process_decoded(v) + for k, v in data.items() + if not k.startswith("@") + } + return cls(**decoded) + def as_dict(self): """Returns data dict of System instance.""" d = { diff --git a/dpdata/unit.py b/dpdata/unit.py index eba07b41..5fc8fe1e 100644 --- a/dpdata/unit.py +++ b/dpdata/unit.py @@ -1,6 +1,6 @@ from abc import ABC -from scipy import constants +from scipy import constants # noqa: TID253 AVOGADRO = constants.Avogadro # Avagadro constant ELE_CHG = constants.elementary_charge # Elementary Charge, in C diff --git a/pyproject.toml b/pyproject.toml index 5292ba9c..1be79442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ select = [ "D", # pydocstyle "UP", # pyupgrade "I", # isort + "TID253", # banned-module-level-imports ] ignore = [ "E501", # line too long @@ -107,3 +108,17 @@ ignore-init-module-imports = true [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.ruff.lint.flake8-tidy-imports] +banned-module-level-imports = [ + "pymatgen", + "ase", + "openbabel", + "rdkit", + "parmed", + "deepmd", + "h5py", + "wcmatch", + "monty", + "scipy", +] diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 006d6b01..7e3278ea 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ 
-1,6 +1,6 @@ import unittest -import h5py +import h5py # noqa: TID253 import numpy as np import dpdata diff --git a/tests/test_to_pymatgen_entry.py b/tests/test_to_pymatgen_entry.py index fd8f40fc..7111dcdc 100644 --- a/tests/test_to_pymatgen_entry.py +++ b/tests/test_to_pymatgen_entry.py @@ -2,7 +2,7 @@ import unittest from context import dpdata -from monty.serialization import loadfn +from monty.serialization import loadfn # noqa: TID253 try: from pymatgen.entries.computed_entries import ComputedStructureEntry # noqa: F401