Skip to content

Commit

Permalink
refined exception utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Aug 4, 2024
1 parent 55c823e commit e12aef3
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 80 deletions.
62 changes: 62 additions & 0 deletions ramannoodle/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Exceptions and warnings for ramannoodle."""

from typing import Any, Sequence

from numpy.typing import NDArray


class NoMatchingLineFoundException(Exception):
"""Raised when no line can be found in file."""
Expand Down Expand Up @@ -27,3 +31,61 @@ class SymmetryException(Exception):

def __init__(self, reason: str):
pass


def _shape_string(shape: Sequence[int | None]) -> str:
"""Get a string representing a shape.
Maps None --> "_", indicating that this element can
be anything.
"""
result = "("
for i in shape:
if i is None:
result += "_,"
else:
result += f"{i},"
if len(shape) == 1:
return result + ")"
return result[:-1] + ")"


def get_type_error(name: str, value: Any, correct_type: str) -> TypeError:
"""Return TypeError for an ndarray argument."""
wrong_type = type(value).__name__
return TypeError(f"{name} should have type {correct_type}, not {wrong_type}")


def verify_ndarray(name: str, array: NDArray) -> None:
"""Verify type of NDArray .
We should avoid calling this function wherever possible (EATF)
"""
try:
_ = array.shape
except AttributeError as exc:
raise get_type_error(name, array, "ndarray") from exc


def verify_ndarray_shape(
name: str, array: NDArray, shape: Sequence[int | None]
) -> None:
"""Verify an NDArray's shape.
We should avoid calling this function whenever possible (EATF).
Parameters
----------
shape
int elements will be checked, None elements will not be.
"""
try:
if len(shape) != array.ndim:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
for d1, d2 in zip(array.shape, shape, strict=True):
if d2 is not None and d1 != d2:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
except AttributeError as exc:
raise get_type_error(name, array, "ndarray") from exc
56 changes: 0 additions & 56 deletions ramannoodle/globals.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""Defines some useful globals."""

from typing import Sequence

from numpy.typing import NDArray

ATOMIC_WEIGHTS = {
"H": 1.008,
"He": 4.002602,
Expand Down Expand Up @@ -248,55 +244,3 @@

RAMAN_TENSOR_CENTRAL_DIFFERENCE = 0.001
BOLTZMANN_CONSTANT = 8.617333262e-5 # Units: eV/K


def _shape_string(shape: Sequence[int | None]) -> str:
"""Get a string representing a shape.
Maps None --> "_", indicating that this element can
be anything.
"""
result = "("
for i in shape:
if i is None:
result += "_,"
else:
result += f"{i},"
return result[:-1] + ")"


def verify_ndarray(name: str, array: NDArray) -> None:
"""Verify type of NDArray .
We should avoid calling this function wherever possible (EATF)
"""
try:
_ = array.shape
except AttributeError as exc:
wrong_type = type(array).__name__
raise TypeError(f"{name} should be an ndarray, not a {wrong_type}") from exc


def verify_ndarray_shape(
name: str, array: NDArray, shape: Sequence[int | None]
) -> None:
"""Verify an NDArray's shape.
We should avoid calling this function whenever possible (EATF).
Parameters
----------
shape
int elements will be checked, None elements will not be.
"""
try:
if len(shape) != array.ndim:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
for d1, d2 in zip(array.shape, shape, strict=True):
if d2 is not None and d1 != d2:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
except AttributeError as exc:
wrong_type = type(array).__name__
raise TypeError(f"{name} should be an ndarray, not a {wrong_type}") from exc
2 changes: 1 addition & 1 deletion ramannoodle/polarizability/interpolation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from ... import io
from ...io.io_utils import pathify_as_list
from ...globals import verify_ndarray_shape
from ...exceptions import verify_ndarray_shape


def get_amplitude(
Expand Down
5 changes: 3 additions & 2 deletions ramannoodle/polarizability/polarizability_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import numpy as np
from numpy.typing import NDArray, ArrayLike

from ..exceptions import get_type_error


def find_duplicates(vectors: Iterable[ArrayLike]) -> NDArray | None:
"""Return duplicate vector in a list or None if no duplicates found."""
try:
combinations = itertools.combinations(vectors, 2)
except TypeError as exc:
wrong_type = type(vectors).__name__
raise TypeError(f"vectors should be iterable, not {wrong_type}") from exc
raise get_type_error("vectors", vectors, "Iterable") from exc
try:
for vector_1, vector_2 in combinations:
if np.isclose(vector_1, vector_2).all():
Expand Down
2 changes: 1 addition & 1 deletion ramannoodle/spectrum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy.typing import NDArray

from . import spectrum_utils
from ..globals import verify_ndarray_shape
from ..exceptions import verify_ndarray_shape


class PhononRamanSpectrum: # pylint: disable=too-few-public-methods
Expand Down
40 changes: 22 additions & 18 deletions ramannoodle/spectrum/spectrum_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from numpy.typing import NDArray

from ..globals import BOLTZMANN_CONSTANT, verify_ndarray_shape, verify_ndarray
from ..globals import BOLTZMANN_CONSTANT
from ..exceptions import verify_ndarray_shape, verify_ndarray, get_type_error


def get_bose_einstein_correction(
Expand All @@ -27,16 +28,16 @@ def get_bose_einstein_correction(
ValueError
"""
if temperature <= 0:
raise ValueError(f"invalid temperature: {temperature}<=0")
try:
energy = wavenumbers * 29979245800 * 4.1357e-15 # in eV
if temperature <= 0:
raise ValueError(f"invalid temperature: {temperature} <= 0")
except TypeError as exc:
raise get_type_error("temperature", temperature, "float") from exc
try:
energy = wavenumbers * 29979245800.0 * 4.1357e-15 # in eV
return 1 / (1 - np.exp(-energy / (BOLTZMANN_CONSTANT * temperature)))
except TypeError as exc:
wrong_type = type(wavenumbers).__name__
raise TypeError(
f"wavenumbers should be an ndarray, not a {wrong_type}"
) from exc
raise get_type_error("wavenumbers", wavenumbers, "ndarray") from exc


def get_laser_correction(
Expand All @@ -59,15 +60,15 @@ def get_laser_correction(
ValueError
"""
if laser_wavenumber <= 0:
raise ValueError(f"invalid laser wavenumber: {laser_wavenumber}<=0")
try:
if laser_wavenumber <= 0:
raise ValueError(f"invalid laser_wavenumber: {laser_wavenumber} <= 0")
except TypeError as exc:
raise get_type_error("laser_wavenumber", laser_wavenumber, "float") from exc
try:
return ((wavenumbers - laser_wavenumber) / 10000) ** 4 / wavenumbers
except TypeError as exc:
wrong_type = type(wavenumbers).__name__
raise TypeError(
f"wavenumbers should be an ndarray, not a {wrong_type}"
) from exc
raise get_type_error("wavenumbers", wavenumbers, "ndarray") from exc


def convolve_intensities(
Expand Down Expand Up @@ -106,11 +107,14 @@ def convolve_intensities(
out_wavenumbers = np.linspace(
np.min(wavenumbers) - 100, np.max(wavenumbers) + 100, 1000
)
out_wavenumbers = np.array(out_wavenumbers) # to shut the type checker up
verify_ndarray_shape("out_wavenumbers", out_wavenumbers, (None,))
verify_ndarray_shape("wavenumbers", wavenumbers, (None,))
verify_ndarray_shape("intensities", intensities, (len(wavenumbers),))
if width <= 0:
raise ValueError(f"invalid width: {width} <= 0")
try:
if width <= 0:
raise ValueError(f"invalid width: {width} <= 0")
except TypeError as exc:
raise get_type_error("width", width, "float") from exc
verify_ndarray("out_wavenumbers", out_wavenumbers)

convolved_intensities = out_wavenumbers * 0
Expand All @@ -133,6 +137,6 @@ def convolve_intensities(
/ ((wavenumber - out_wavenumbers) ** 2 + (0.5 * width) ** 2)
)
else:
raise ValueError(f"unsupported convolution type: {type}")
raise ValueError(f"unsupported convolution type: {function}")
convolved_intensities += factor * intensity
return (out_wavenumbers, convolved_intensities)
3 changes: 1 addition & 2 deletions ramannoodle/symmetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import spglib

from . import symmetry_utils
from ..exceptions import SymmetryException
from ..globals import verify_ndarray_shape
from ..exceptions import SymmetryException, verify_ndarray_shape


class StructuralSymmetry:
Expand Down

0 comments on commit e12aef3

Please sign in to comment.