Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pref: lazy import modules #658

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions dpdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,12 @@
# monty needs lzma
# See https://github.com/pandas-dev/pandas/pull/27882
try:
import lzma # noqa: F401
except ImportError:

class fakemodule:
pass

import sys

sys.modules["lzma"] = fakemodule

from . import lammps, md, vasp
from .bond_order_system import BondOrderSystem
from .system import LabeledSystem, MultiSystems, System

try:
from ._version import version as __version__
except ImportError:
from .__about__ import __version__

# BondOrder System has dependency on rdkit
try:
# prevent conflict with dpdata.rdkit
import rdkit as _ # noqa: F401

USE_RDKIT = True
except ModuleNotFoundError:
USE_RDKIT = False

if USE_RDKIT:
from .bond_order_system import BondOrderSystem

__all__ = [
"__version__",
"lammps",
Expand Down
3 changes: 2 additions & 1 deletion dpdata/amber/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re

import numpy as np
from scipy.io import netcdf_file

from dpdata.amber.mask import pick_by_amber_mask
from dpdata.unit import EnergyConversion
Expand Down Expand Up @@ -44,6 +43,8 @@ def read_amber_traj(
labeled : bool
Whether to return labeled data
"""
from scipy.io import netcdf_file

flag_atom_type = False
flag_atom_numb = False
amber_types = []
Expand Down
2 changes: 1 addition & 1 deletion dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional

from ase.calculators.calculator import (
from ase.calculators.calculator import ( # noqa: TID253
Calculator,
PropertyNotImplementedError,
all_changes,
Expand Down
3 changes: 2 additions & 1 deletion dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from copy import deepcopy

import numpy as np
from rdkit.Chem import Conformer

import dpdata.rdkit.utils
from dpdata.rdkit.sanitize import Sanitizer
Expand Down Expand Up @@ -102,6 +101,8 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs):
return self

def to_fmt_obj(self, fmtobj, *args, **kwargs):
from rdkit.Chem import Conformer

self.rdkit_mol.RemoveAllConformers()
for ii in range(self.get_nframes()):
conf = Conformer()
Expand Down
11 changes: 6 additions & 5 deletions dpdata/deepmd/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

try:
import h5py
except ImportError:
pass
import numpy as np
from wcmatch.glob import globfilter

import dpdata

if TYPE_CHECKING:
import h5py

Check warning on line 13 in dpdata/deepmd/hdf5.py

View check run for this annotation

Codecov / codecov/patch

dpdata/deepmd/hdf5.py#L13

Added line #L13 was not covered by tests

__all__ = ["to_system_data", "dump"]


Expand All @@ -35,6 +34,8 @@
labels : bool
labels
"""
from wcmatch.glob import globfilter

g = f[folder] if folder else f

data = {}
Expand Down
20 changes: 7 additions & 13 deletions dpdata/gaussian/gjf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

try:
from openbabel import openbabel
except ImportError:
try:
import openbabel
except ImportError:
openbabel = None
from dpdata.periodic_table import Element


Expand Down Expand Up @@ -53,10 +44,13 @@
ImportError
if Open Babel is not installed
"""
if openbabel is None:
raise ImportError(
"Open Babel (Python interface) should be installed to detect fragmentation!"
)
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

try:
from openbabel import openbabel
except ImportError:
import openbabel

Check warning on line 53 in dpdata/gaussian/gjf.py

View check run for this annotation

Codecov / codecov/patch

dpdata/gaussian/gjf.py#L52-L53

Added lines #L52 - L53 were not covered by tests
atomnumber = len(symbols)
# Use openbabel to connect atoms
mol = openbabel.OBMol()
Expand Down
8 changes: 4 additions & 4 deletions dpdata/periodic_table.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from pathlib import Path

from monty.serialization import loadfn

fpdt = str(Path(__file__).absolute().parent / "periodic_table.json")
_pdt = loadfn(fpdt)
fpdt = Path(__file__).absolute().parent / "periodic_table.json"
with fpdt.open("r") as fpdt:
_pdt = json.load(fpdt)
ELEMENTS = [
"H",
"He",
Expand Down
20 changes: 11 additions & 9 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@
from dpdata.driver import Driver, Minimizer
from dpdata.format import Format

try:
import ase.io
from ase.calculators.calculator import PropertyNotImplementedError
from ase.io import Trajectory

if TYPE_CHECKING:
from ase.optimize.optimize import Optimizer
except ImportError:
pass
if TYPE_CHECKING:
import ase
from ase.optimize.optimize import Optimizer

Check warning on line 11 in dpdata/plugins/ase.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/ase.py#L10-L11

Added lines #L10 - L11 were not covered by tests


@Format.register("ase/structure")
Expand Down Expand Up @@ -84,6 +78,8 @@
ASE will raise RuntimeError if the atoms does not
have a calculator
"""
from ase.calculators.calculator import PropertyNotImplementedError

info_dict = self.from_system(atoms)
try:
energies = atoms.get_potential_energy(force_consistent=True)
Expand Down Expand Up @@ -137,6 +133,8 @@
ase.Atoms
ASE atoms in the file
"""
import ase.io

frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
yield from frames

Expand Down Expand Up @@ -222,6 +220,8 @@
dict_frames: dict
a dictionary containing data of multiple frames
"""
from ase.io import Trajectory

traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
Expand Down Expand Up @@ -264,6 +264,8 @@
dict_frames: dict
a dictionary containing data of multiple frames
"""
from ase.io import Trajectory

traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]

Expand Down
16 changes: 12 additions & 4 deletions dpdata/plugins/deepmd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

try:
import h5py
except ImportError:
pass
import numpy as np

import dpdata
Expand All @@ -16,6 +13,9 @@
from dpdata.driver import Driver
from dpdata.format import Format

if TYPE_CHECKING:
import h5py

Check warning on line 17 in dpdata/plugins/deepmd.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/deepmd.py#L17

Added line #L17 was not covered by tests


@Format.register("deepmd")
@Format.register("deepmd/raw")
Expand Down Expand Up @@ -202,6 +202,8 @@
TypeError
file_name is not str or h5py.Group or h5py.File
"""
import h5py

if isinstance(file_name, (h5py.Group, h5py.File)):
return dpdata.deepmd.hdf5.to_system_data(
file_name, "", type_map=type_map, labels=labels
Expand Down Expand Up @@ -300,6 +302,8 @@
**kwargs : dict
other parameters
"""
import h5py

if isinstance(file_name, (h5py.Group, h5py.File)):
dpdata.deepmd.hdf5.dump(
file_name, "", data, set_size=set_size, comp_prec=comp_prec
Expand Down Expand Up @@ -330,6 +334,8 @@
h5py.Group
a HDF5 group in the HDF5 file
"""
import h5py

with h5py.File(directory, "r") as f:
for ff in f.keys():
yield f[ff]
Expand All @@ -353,6 +359,8 @@
h5py.Group
a HDF5 group with the name of formula
"""
import h5py

with h5py.File(directory, "w") as f:
for ff in formulas:
yield f.create_group(ff)
Expand Down
16 changes: 9 additions & 7 deletions dpdata/plugins/rdkit.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import dpdata.rdkit.utils
from dpdata.format import Format

try:
import rdkit.Chem

import dpdata.rdkit.utils
except ModuleNotFoundError:
pass


@Format.register("mol")
@Format.register("mol_file")
class MolFormat(Format):
def from_bond_order_system(self, file_name, **kwargs):
import rdkit.Chem

return rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False)

def to_bond_order_system(self, data, mol, file_name, frame_idx=0, **kwargs):
import rdkit.Chem

Check warning on line 14 in dpdata/plugins/rdkit.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/rdkit.py#L14

Added line #L14 was not covered by tests

assert frame_idx < mol.GetNumConformers()
rdkit.Chem.MolToMolFile(mol, file_name, confId=frame_idx)

Expand All @@ -24,6 +22,8 @@
class SdfFormat(Format):
def from_bond_order_system(self, file_name, **kwargs):
"""Note that it requires all molecules in .sdf file must be of the same topology."""
import rdkit.Chem

mols = [
m
for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False)
Expand All @@ -35,6 +35,8 @@
return mol

def to_bond_order_system(self, data, mol, file_name, frame_idx=-1, **kwargs):
import rdkit.Chem

sdf_writer = rdkit.Chem.SDWriter(file_name)
if frame_idx == -1:
for ii in range(mol.GetNumConformers()):
Expand Down
10 changes: 4 additions & 6 deletions dpdata/pymatgen/molecule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import numpy as np

try:
from pymatgen.core import Molecule
except ImportError:
pass
from collections import Counter

import numpy as np


def to_system_data(file_name, protect_layer=9):
from pymatgen.core import Molecule

mol = Molecule.from_file(file_name)
elem_mol = list(str(site.species.elements[0]) for site in mol.sites)
elem_counter = Counter(elem_mol)
Expand Down
5 changes: 0 additions & 5 deletions dpdata/pymatgen/structure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import numpy as np

try:
from pymatgen.core import Structure # noqa: F401
except ImportError:
pass


def from_system_data(structure) -> dict:
symbols = [site.species_string for site in structure]
Expand Down
Loading