diff --git a/ramannoodle/exceptions.py b/ramannoodle/exceptions.py index b900b4d..866d929 100644 --- a/ramannoodle/exceptions.py +++ b/ramannoodle/exceptions.py @@ -133,3 +133,9 @@ def verify_positions(name: str, array: NDArray) -> None: verify_ndarray_shape(name, array, (None, 3)) if (0 > array).any() or (array > 1.0).any(): raise ValueError(f"{name} has coordinates that are not between 0 and 1") + + +def get_torch_missing_error() -> UsageError: + """Get error indicating that torch is not installed.""" + required_modules = "'torch', 'torch-scatter', and 'torch-sparse' modules" + return UsageError(f"torch functionality requires {required_modules}") diff --git a/ramannoodle/io/generic.py b/ramannoodle/io/generic.py index c9c0d3a..3558c1b 100644 --- a/ramannoodle/io/generic.py +++ b/ramannoodle/io/generic.py @@ -12,17 +12,18 @@ import numpy as np from numpy.typing import NDArray -from ramannoodle.dynamics.phonon import Phonons +from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory - from ramannoodle.structure.reference import ReferenceStructure +from ramannoodle.exceptions import UsageError, get_torch_missing_error import ramannoodle.io.vasp as vasp_io +TORCH_PRESENT = True try: - from ramannoodle.polarizability.torch import dataset -except ModuleNotFoundError: - import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore + from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset +except UsageError: + TORCH_PRESENT = False # These map between file formats and appropriate IO functions. _PHONON_READERS = { @@ -193,7 +194,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], file_format: str, -) -> dataset.PolarizabilityDataset: +) -> "PolarizabilityDataset": """Read polarizability dataset from files. Parameters @@ -214,6 +215,8 @@ def read_polarizability_dataset( IncompatibleFileException File is incompatible with the dataset. """ + if not TORCH_PRESENT: + raise get_torch_missing_error() try: return _POLARIZABILITY_DATASET_READERS[file_format](filepaths) except KeyError as exc: diff --git a/ramannoodle/io/io_utils.py b/ramannoodle/io/io_utils.py index f2c4f72..93c03bd 100644 --- a/ramannoodle/io/io_utils.py +++ b/ramannoodle/io/io_utils.py @@ -13,13 +13,16 @@ verify_positions, verify_list_len, IncompatibleStructureException, + get_torch_missing_error, + UsageError, ) from ramannoodle.globals import ATOM_SYMBOLS +TORCH_PRESENT = True try: - from ramannoodle.polarizability.torch import dataset -except ModuleNotFoundError: - import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore + from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset +except UsageError: + TORCH_PRESENT = False def _skip_file_until_line_contains(file: TextIO, content: str) -> str: @@ -99,7 +102,7 @@ def _read_polarizability_dataset( [str | Path], tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]], ], -) -> dataset.PolarizabilityDataset: +) -> "PolarizabilityDataset": """Read polarizability dataset from OUTCAR files. Parameters @@ -118,11 +121,9 @@ def _read_polarizability_dataset( File has an unexpected format. IncompatibleFileException File is incompatible with the dataset. - ModuleNotFoundError - Torch installation could not be found. """ - if not dataset.TORCH_PRESENT: - raise ModuleNotFoundError("torch installation not found") + if not TORCH_PRESENT: + raise get_torch_missing_error() filepaths = pathify_as_list(filepaths) lattices: list[NDArray[np.float64]] = [] @@ -151,7 +152,7 @@ def _read_polarizability_dataset( positions_list.append(positions) polarizabilities.append(polarizability) - return dataset.PolarizabilityDataset( + return PolarizabilityDataset( np.array(lattices), atomic_numbers_list, np.array(positions_list), diff --git a/ramannoodle/io/vasp/outcar.py b/ramannoodle/io/vasp/outcar.py index c4fb045..7baae6e 100644 --- a/ramannoodle/io/vasp/outcar.py +++ b/ramannoodle/io/vasp/outcar.py @@ -12,15 +12,15 @@ ) from ramannoodle.exceptions import InvalidFileException, NoMatchingLineFoundException from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS -from ramannoodle.exceptions import get_type_error +from ramannoodle.exceptions import get_type_error, UsageError from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure try: - from ramannoodle.polarizability.torch import dataset -except ModuleNotFoundError: - import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore + from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset +except UsageError: + pass # Utilities for OUTCAR. Warning: some of these functions partially read files. @@ -404,7 +404,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], -) -> dataset.PolarizabilityDataset: +) -> "PolarizabilityDataset": """Read polarizability dataset from OUTCAR files. Parameters diff --git a/ramannoodle/io/vasp/vasprun.py b/ramannoodle/io/vasp/vasprun.py index a028a14..c2b0287 100644 --- a/ramannoodle/io/vasp/vasprun.py +++ b/ramannoodle/io/vasp/vasprun.py @@ -9,16 +9,16 @@ from numpy.typing import NDArray from ramannoodle.io.io_utils import pathify, _read_polarizability_dataset -from ramannoodle.exceptions import InvalidFileException +from ramannoodle.exceptions import InvalidFileException, UsageError from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure try: - from ramannoodle.polarizability.torch import dataset -except ModuleNotFoundError: - import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore + from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset +except UsageError: + pass def _get_root_element(file: TextIO) -> Element: @@ -199,7 +199,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], -) -> dataset.PolarizabilityDataset: +) -> "PolarizabilityDataset": """Read polarizability dataset from OUTCAR files. Parameters diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py index 64cf3ba..402a62a 100644 --- a/ramannoodle/polarizability/torch/dataset.py +++ b/ramannoodle/polarizability/torch/dataset.py @@ -5,12 +5,21 @@ import numpy as np from numpy.typing import NDArray -import torch -from torch import Tensor -from torch.utils.data import Dataset +from ramannoodle.exceptions import ( + verify_ndarray_shape, + verify_list_len, + get_type_error, + get_torch_missing_error, +) + +try: + import torch + from torch import Tensor + from torch.utils.data import Dataset + import ramannoodle.polarizability.torch.utils as rn_torch_utils +except ModuleNotFoundError as exc: + raise get_torch_missing_error() from exc -from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error -import ramannoodle.polarizability.torch.utils as rn_torch_utils TORCH_PRESENT = True diff --git a/ramannoodle/polarizability/torch/dummy_dataset.py b/ramannoodle/polarizability/torch/dummy_dataset.py deleted file mode 100644 index f22b67f..0000000 --- a/ramannoodle/polarizability/torch/dummy_dataset.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Dummy polarizability PyTorch dataset. - -Used when torch installation cannot be found. - -:meta private: -""" - -import numpy as np -from numpy.typing import NDArray - -TORCH_PRESENT = False - - -class PolarizabilityDataset: # pylint: disable=too-few-public-methods - """PyTorch dataset of atomic structures and polarizabilities. - - Polarizabilities are scaled and flattened into vectors containing the six - independent tensor components. - - Parameters - ---------- - lattices - | (Å) 3D array with shape (S,3,3) where S is the number of samples. - atomic_numbers - | List of length S containing lists of length N, where N is the number of atoms. - positions - | (fractional) 3D array with shape (S,N,3). - polarizabilities - | 3D array with shape (S,3,3). - scale_mode - | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by - | standard deviation), and ``"none"`` (no scaling). - - """ - - def __init__( # pylint: disable=too-many-arguments - self, - lattices: NDArray[np.float64], - atomic_numbers: list[list[int]], - positions: NDArray[np.float64], - polarizabilities: NDArray[np.float64], - scale_mode: str = "standard", - ): - raise ModuleNotFoundError("torch installation not found") diff --git a/ramannoodle/polarizability/torch/gnn.py b/ramannoodle/polarizability/torch/gnn.py index f8feb8d..04f4e8b 100644 --- a/ramannoodle/polarizability/torch/gnn.py +++ b/ramannoodle/polarizability/torch/gnn.py @@ -4,25 +4,30 @@ import typing -import torch -from torch import Tensor -from torch.nn import ( - BatchNorm1d, - Embedding, - Linear, - ModuleList, - Sequential, - Module, - LayerNorm, -) - -from torch_geometric.nn.inits import reset -from torch_geometric.nn.models.dimenet import triplets -from torch_geometric.nn.models.schnet import ShiftedSoftplus -from torch_geometric.utils import scatter - from ramannoodle.structure.reference import ReferenceStructure -import ramannoodle.polarizability.torch.utils as rn_torch_utils +from ramannoodle.exceptions import get_torch_missing_error + +try: + import torch + from torch import Tensor + from torch.nn import ( + BatchNorm1d, + Embedding, + Linear, + ModuleList, + Sequential, + Module, + LayerNorm, + ) + + from torch_geometric.nn.inits import reset + from torch_geometric.nn.models.dimenet import triplets + from torch_geometric.nn.models.schnet import ShiftedSoftplus + from torch_geometric.utils import scatter + + import ramannoodle.polarizability.torch.utils as rn_torch_utils +except ModuleNotFoundError as exc: + raise get_torch_missing_error() from exc # pylint: disable=not-callable diff --git a/ramannoodle/polarizability/torch/utils.py b/ramannoodle/polarizability/torch/utils.py index f284df0..790a780 100644 --- a/ramannoodle/polarizability/torch/utils.py +++ b/ramannoodle/polarizability/torch/utils.py @@ -2,10 +2,13 @@ from typing import Sequence -import torch -from torch import Tensor +from ramannoodle.exceptions import get_type_error, get_torch_missing_error -from ramannoodle.exceptions import get_type_error +try: + import torch + from torch import Tensor +except ModuleNotFoundError as exc: + raise get_torch_missing_error() from exc # pylint complains about torch.norm # pylint: disable=not-callable