Skip to content

Commit

Permalink
annotation-based type checking for PolarizabilityDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 18, 2024
1 parent 1148a0e commit 8aa0287
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 95 deletions.
6 changes: 6 additions & 0 deletions ramannoodle/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,9 @@ def verify_positions(name: str, array: NDArray) -> None:
verify_ndarray_shape(name, array, (None, 3))
if (0 > array).any() or (array > 1.0).any():
raise ValueError(f"{name} has coordinates that are not between 0 and 1")


def get_torch_missing_error() -> UsageError:
"""Get error indicating that torch is not installed."""
required_modules = "'torch', 'torch-scatter', and 'torch-sparse' modules"
return UsageError(f"torch functionality requires {required_modules}")
15 changes: 9 additions & 6 deletions ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@

import numpy as np
from numpy.typing import NDArray
from ramannoodle.dynamics.phonon import Phonons

from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory

from ramannoodle.structure.reference import ReferenceStructure
from ramannoodle.exceptions import UsageError, get_torch_missing_error
import ramannoodle.io.vasp as vasp_io

TORCH_PRESENT = True
try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset
except UsageError:
TORCH_PRESENT = False

# These map between file formats and appropriate IO functions.
_PHONON_READERS = {
Expand Down Expand Up @@ -193,7 +194,7 @@ def read_structure_and_polarizability(
def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
file_format: str,
) -> dataset.PolarizabilityDataset:
) -> "PolarizabilityDataset":
"""Read polarizability dataset from files.
Parameters
Expand All @@ -214,6 +215,8 @@ def read_polarizability_dataset(
IncompatibleFileException
File is incompatible with the dataset.
"""
if not TORCH_PRESENT:
raise get_torch_missing_error()
try:
return _POLARIZABILITY_DATASET_READERS[file_format](filepaths)
except KeyError as exc:
Expand Down
19 changes: 10 additions & 9 deletions ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
verify_positions,
verify_list_len,
IncompatibleStructureException,
get_torch_missing_error,
UsageError,
)
from ramannoodle.globals import ATOM_SYMBOLS

TORCH_PRESENT = True
try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset
except UsageError:
TORCH_PRESENT = False


def _skip_file_until_line_contains(file: TextIO, content: str) -> str:
Expand Down Expand Up @@ -99,7 +102,7 @@ def _read_polarizability_dataset(
[str | Path],
tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]],
],
) -> dataset.PolarizabilityDataset:
) -> "PolarizabilityDataset":
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand All @@ -118,11 +121,9 @@ def _read_polarizability_dataset(
File has an unexpected format.
IncompatibleFileException
File is incompatible with the dataset.
ModuleNotFoundError
Torch installation could not be found.
"""
if not dataset.TORCH_PRESENT:
raise ModuleNotFoundError("torch installation not found")
if not TORCH_PRESENT:
raise get_torch_missing_error()
filepaths = pathify_as_list(filepaths)

lattices: list[NDArray[np.float64]] = []
Expand Down Expand Up @@ -151,7 +152,7 @@ def _read_polarizability_dataset(
positions_list.append(positions)
polarizabilities.append(polarizability)

return dataset.PolarizabilityDataset(
return PolarizabilityDataset(
np.array(lattices),
atomic_numbers_list,
np.array(positions_list),
Expand Down
10 changes: 5 additions & 5 deletions ramannoodle/io/vasp/outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
)
from ramannoodle.exceptions import InvalidFileException, NoMatchingLineFoundException
from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS
from ramannoodle.exceptions import get_type_error
from ramannoodle.exceptions import get_type_error, UsageError
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset
except UsageError:
pass


# Utilities for OUTCAR. Warning: some of these functions partially read files.
Expand Down Expand Up @@ -404,7 +404,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> dataset.PolarizabilityDataset:
) -> "PolarizabilityDataset":
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
10 changes: 5 additions & 5 deletions ramannoodle/io/vasp/vasprun.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from numpy.typing import NDArray

from ramannoodle.io.io_utils import pathify, _read_polarizability_dataset
from ramannoodle.exceptions import InvalidFileException
from ramannoodle.exceptions import InvalidFileException, UsageError
from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS
from ramannoodle.dynamics.phonon import Phonons
from ramannoodle.dynamics.trajectory import Trajectory
from ramannoodle.structure.reference import ReferenceStructure

try:
from ramannoodle.polarizability.torch import dataset
except ModuleNotFoundError:
import ramannoodle.polarizability.torch.dummy_dataset as dataset # type: ignore
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset
except UsageError:
pass


def _get_root_element(file: TextIO) -> Element:
Expand Down Expand Up @@ -199,7 +199,7 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
) -> dataset.PolarizabilityDataset:
) -> "PolarizabilityDataset":
"""Read polarizability dataset from OUTCAR files.
Parameters
Expand Down
19 changes: 14 additions & 5 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
import numpy as np
from numpy.typing import NDArray

import torch
from torch import Tensor
from torch.utils.data import Dataset
from ramannoodle.exceptions import (
verify_ndarray_shape,
verify_list_len,
get_type_error,
get_torch_missing_error,
)

try:
import torch
from torch import Tensor
from torch.utils.data import Dataset
import ramannoodle.polarizability.torch.utils as rn_torch_utils
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error
import ramannoodle.polarizability.torch.utils as rn_torch_utils

TORCH_PRESENT = True

Expand Down
44 changes: 0 additions & 44 deletions ramannoodle/polarizability/torch/dummy_dataset.py

This file was deleted.

41 changes: 23 additions & 18 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,30 @@

import typing

import torch
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Embedding,
Linear,
ModuleList,
Sequential,
Module,
LayerNorm,
)

from torch_geometric.nn.inits import reset
from torch_geometric.nn.models.dimenet import triplets
from torch_geometric.nn.models.schnet import ShiftedSoftplus
from torch_geometric.utils import scatter

from ramannoodle.structure.reference import ReferenceStructure
import ramannoodle.polarizability.torch.utils as rn_torch_utils
from ramannoodle.exceptions import get_torch_missing_error

try:
import torch
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Embedding,
Linear,
ModuleList,
Sequential,
Module,
LayerNorm,
)

from torch_geometric.nn.inits import reset
from torch_geometric.nn.models.dimenet import triplets
from torch_geometric.nn.models.schnet import ShiftedSoftplus
from torch_geometric.utils import scatter

import ramannoodle.polarizability.torch.utils as rn_torch_utils
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

# pylint: disable=not-callable

Expand Down
9 changes: 6 additions & 3 deletions ramannoodle/polarizability/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from typing import Sequence

import torch
from torch import Tensor
from ramannoodle.exceptions import get_type_error, get_torch_missing_error

from ramannoodle.exceptions import get_type_error
try:
import torch
from torch import Tensor
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

# pylint complains about torch.norm
# pylint: disable=not-callable
Expand Down

0 comments on commit 8aa0287

Please sign in to comment.