Skip to content

Commit

Permalink
improved symmetry arg handling, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Aug 5, 2024
1 parent e3bf651 commit 71ba081
Show file tree
Hide file tree
Showing 10 changed files with 324 additions and 79 deletions.
2 changes: 1 addition & 1 deletion docs/coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/tests-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 17 additions & 4 deletions ramannoodle/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def get_type_error(name: str, value: Any, correct_type: str) -> TypeError:
return TypeError(f"{name} should have type {correct_type}, not {wrong_type}")


def get_shape_error(
name: str, array: NDArray, desired_shape: Sequence[int | None]
) -> ValueError:
"""Return ValueError for an ndarray with the wrong shape."""
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(desired_shape)}"
return ValueError(f"{name} has wrong shape: {shape_spec}")


def verify_ndarray(name: str, array: NDArray) -> None:
"""Verify type of NDArray .
Expand All @@ -81,11 +89,16 @@ def verify_ndarray_shape(
"""
try:
if len(shape) != array.ndim:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
raise get_shape_error(name, array, shape)
for d1, d2 in zip(array.shape, shape, strict=True):
if d2 is not None and d1 != d2:
shape_spec = f"{_shape_string(array.shape)} != {_shape_string(shape)}"
raise ValueError(f"{name} has wrong shape: {shape_spec}")
raise get_shape_error(name, array, shape)
except AttributeError as exc:
raise get_type_error(name, array, "ndarray") from exc


def verify_positions(name: str, array: NDArray) -> None:
"""Verify fractional positions according to dimensions and boundary conditions."""
verify_ndarray_shape(name, array, (None, 3))
if (0 > array).any() or (array > 1.0).any():
raise ValueError(f"{name} has coordinates that are not between 0 and 1")
12 changes: 7 additions & 5 deletions ramannoodle/symmetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def __init__( # pylint: disable=too-many-arguments
) -> None:
verify_ndarray_shape("atomic_numbers", atomic_numbers, (None,))
verify_ndarray_shape("lattice", lattice, (3, 3))
verify_ndarray_shape("fractional_positions", fractional_positions, (None, 3))
verify_ndarray_shape(
"fractional_positions", fractional_positions, (len(atomic_numbers), 3)
)

self._atomic_numbers = atomic_numbers
self._lattice = lattice
Expand Down Expand Up @@ -79,7 +81,7 @@ def get_equivalent_displacements(
transform the parameter `displacements` into that degree of freedom.
"""
assert (displacement >= -0.5).all() and (displacement <= 0.5).all()
displacement = symmetry_utils.apply_pbc_displacement(displacement)

ref_positions = symmetry_utils.displace_fractional_positions(
self._fractional_positions, displacement
Expand Down Expand Up @@ -152,9 +154,9 @@ def get_cartesian_displacement(
fractional_displacement
2D array with shape (N,3) where N is the number of atoms
"""
assert (fractional_displacement >= -0.5).all() and (
fractional_displacement <= 0.5
).all()
fractional_displacement = symmetry_utils.apply_pbc_displacement(
fractional_displacement
)

return fractional_displacement @ self._lattice

Expand Down
161 changes: 118 additions & 43 deletions ramannoodle/symmetry/symmetry_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,54 @@
"""Utility functions relevant to symmetry."""

from typing import Iterable

import numpy as np
from numpy.typing import NDArray

from ..exceptions import (
get_type_error,
SymmetryException,
verify_positions,
get_shape_error,
)


def are_collinear(vector_1: NDArray[np.float64], vector_2: NDArray[np.float64]) -> bool:
"""Return whether or not two vectors are collinear."""
vector_1_copy = vector_1 / np.linalg.norm(vector_1)
vector_2_copy = vector_2 / np.linalg.norm(vector_2)
dot_product = vector_1_copy.dot(vector_2_copy)
result: bool = np.isclose(dot_product, 1).all() or np.isclose(dot_product, -1).all()
return result
"""Return whether or not two vectors are collinear.
Parameters
----------
vector_1
ndarray with shape (M,)
vector_2
ndarray with shape (M,)
Raises
------
TypeError
ValueError
"""
try:
vector_1 = vector_1 / float(np.linalg.norm(vector_1))
except TypeError as exc:
raise get_type_error("vector_1", vector_1, "ndarray") from exc
try:
vector_2 = vector_2 / float(np.linalg.norm(vector_2))
except TypeError as exc:
raise get_type_error("vector_2", vector_2, "ndarray") from exc
try:
dot_product = vector_1.dot(vector_2)
except ValueError as exc:
length_expr = f"{len(vector_1)} != {len(vector_2)}"
raise ValueError(
f"vector_1 and vector_2 have different lengths: {length_expr}"
) from exc
return bool(np.isclose(dot_product, 1).all() or np.isclose(dot_product, -1).all())


def is_orthogonal_to_all(
vector_1: NDArray[np.float64], vectors: list[NDArray[np.float64]]
vector_1: NDArray[np.float64], vectors: Iterable[NDArray[np.float64]]
) -> int:
"""Check whether a given vector is orthogonal to a list of others.
Expand All @@ -23,23 +57,31 @@ def is_orthogonal_to_all(
int
first index of non-orthogonal vector, otherwise -1
Raises
------
TypeError
"""
# This implementation could be made more efficient but readability would
# be sacrificed .
vector_1_copy = vector_1 / np.linalg.norm(vector_1)
# This implementation could be made more efficient.
try:
vector_1 = vector_1 / float(np.linalg.norm(vector_1))
except TypeError as exc:
raise get_type_error("vector_1", vector_1, "ndarray") from exc

for index, vector_2 in enumerate(vectors):
vector_2_copy = vector_2 / np.linalg.norm(vector_2)
if not np.isclose(
np.dot(vector_1_copy.flatten(), vector_2_copy.flatten()) + 1, 1
).all():
try:
vector_2 = vector_2 / np.linalg.norm(vector_2)
except TypeError as exc:
raise get_type_error(f"vectors[{index}]", vector_2, "ndarray") from exc

if not np.isclose(np.dot(vector_1.flatten(), vector_2.flatten()) + 1, 1).all():
return index

return -1


def is_collinear_with_all(
vector_1: NDArray[np.float64], vectors: list[NDArray[np.float64]]
vector_1: NDArray[np.float64], vectors: Iterable[NDArray[np.float64]]
) -> int:
"""Check if a given vector is collinear to a list of others.
Expand All @@ -49,8 +91,7 @@ def is_collinear_with_all(
first index of non-collinear vector, otherwise -1
"""
# This implementation could be made more efficient but readability would
# be sacrificed.
# This implementation could be made more efficient.
for index, vector_2 in enumerate(vectors):
if not are_collinear(vector_1.flatten(), vector_2.flatten()):
return index
Expand All @@ -69,8 +110,7 @@ def is_non_collinear_with_all(
first index of collinear vector, otherwise -1
"""
# This implementation could be made more efficient but readability would
# be sacrificed.
# This implementation could be made more efficient.
for index, vector_2 in enumerate(vectors):
if are_collinear(vector_1.flatten(), vector_2.flatten()):
return index
Expand All @@ -83,11 +123,17 @@ def compute_permutation_matrices(
translations: NDArray[np.float64],
fractional_positions: NDArray[np.float64],
) -> NDArray[np.float64]:
"""Expresses a series of rotation/translations as permutation matrices."""
"""Expresses a series of rotation/translations as permutation matrices.
Raises
------
SymmetryException
"""
permutation_matrices = []
for rotation, translation in zip(rotations, translations):
permutation_matrices.append(
get_fractional_positions_permutation_matrix(
_get_fractional_positions_permutation_matrix(
fractional_positions,
transform_fractional_positions(
fractional_positions, rotation, translation
Expand All @@ -97,14 +143,19 @@ def compute_permutation_matrices(
return np.array(permutation_matrices)


def get_fractional_positions_permutation_matrix(
def _get_fractional_positions_permutation_matrix(
reference: NDArray[np.float64], permuted: NDArray[np.float64]
) -> NDArray[np.float64]:
"""Calculate a permutation matrix given permuted fractional positions."""
assert (0 <= reference).all() and (reference <= 1.0).all()
assert (0 <= permuted).all() and (permuted <= 1.0).all()
"""Calculate a permutation matrix given permuted fractional positions.
Raises
------
SymmetryException
"""
reference = apply_pbc(reference)
permuted = apply_pbc(permuted)

# Perhaps not the best implementation, but it'll do for now
# This implementation is VERY slow.
permutation_matrix = np.zeros((len(reference), len(reference)))

for ref_index, ref_position in enumerate(reference):
Expand All @@ -114,7 +165,9 @@ def get_fractional_positions_permutation_matrix(
if distance < 0.001:
permutation_matrix[ref_index][permuted_index] = 1
break
assert np.isclose(np.sum(permutation_matrix, axis=1), 1).all()

if not np.isclose(np.sum(permutation_matrix, axis=1), 1).all():
raise SymmetryException("permutation matrix could not be found")
return permutation_matrix


Expand All @@ -123,11 +176,22 @@ def transform_fractional_positions(
rotation: NDArray[np.float64],
translation: NDArray[np.float64],
) -> NDArray[np.float64]:
"""Transform fractional coordinates under periodic boundary conditions."""
assert (0 <= positions).all() and (positions <= 1.0).all()
rotated = positions @ rotation
rotated[rotated < 0.0] += 1
rotated[rotated > 1.0] -= 1
"""Transform fractional coordinates under periodic boundary conditions.
Raises
------
TypeError
ValueError
"""
verify_positions("positions", positions)
positions = apply_pbc(positions)
try:
rotated = positions @ rotation
except TypeError as exc:
raise get_type_error("rotation", rotation, "ndarray") from exc
except ValueError as exc:
raise get_shape_error("rotation", rotation, (3, 3)) from exc
rotated = apply_pbc(rotated)
return displace_fractional_positions(rotated, translation)


Expand All @@ -136,12 +200,10 @@ def displace_fractional_positions(
displacement: NDArray[np.float64],
) -> NDArray[np.float64]:
"""Add fractional positions together under periodic boundary conditions."""
assert (0 <= positions).all() and (positions <= 1.0).all()
positions = apply_pbc(positions)
displacement = apply_pbc_displacement(displacement)

result = positions + displacement
result[result < 0.0] += 1
result[result > 1.0] -= 1
return result
return apply_pbc(positions + displacement)


def calculate_displacement(
Expand All @@ -153,10 +215,23 @@ def calculate_displacement(
Returns a displacement.
"""
assert (0 <= positions_1).all() and (positions_1 <= 1.0).all()
assert (0 <= positions_2).all() and (positions_2 <= 1.0).all()
positions_1 = apply_pbc(positions_1)
positions_2 = apply_pbc(positions_2)

return apply_pbc_displacement(positions_1 - positions_2)


def apply_pbc(positions: NDArray[np.float64]) -> NDArray[np.float64]:
"""Return fractional positions such that all coordinates are b/t 0 and 1."""
try:
return positions - positions // 1
except TypeError as exc:
raise get_type_error("positions", positions, "ndarray") from exc


difference = positions_1 - positions_2
difference[difference > 0.5] -= 1.0
difference[difference < -0.5] += 1.0
return difference
def apply_pbc_displacement(displacement: NDArray[np.float64]) -> NDArray[np.float64]:
"""Return fractional displacement such as all coordinates are b/t -0.5 and 0.5."""
try:
return np.where(displacement % 1 > 0.5, displacement % 1 - 1, displacement % 1)
except TypeError as exc:
raise get_type_error("displacement", displacement, "ndarray") from exc
6 changes: 3 additions & 3 deletions test/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
],
)
def test_find_duplicates(vectors: NDArray[np.float64], known: bool) -> None:
"""Test for find_duplicates (success)."""
"""Test find_duplicates (normal)."""
assert find_duplicates(vectors) == known


Expand All @@ -38,7 +38,7 @@ def test_find_duplicates(vectors: NDArray[np.float64], known: bool) -> None:
def test_find_duplicates_exception(
vectors: NDArray[np.float64], exception_type: Type[Exception], in_reason: str
) -> None:
"""Test for find_duplicates (exception)."""
"""Test find_duplicates (exception)."""
with pytest.raises(exception_type) as error:
find_duplicates(vectors)
assert in_reason in str(error.value)
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_add_dof_from_files_exception(
exception_type: Type[Exception],
in_reason: str,
) -> None:
"""Test exceptions in add_dof_from_files."""
"""Test add_dof_from_files (exception)."""
symmetry = outcar_symmetry_fixture
model = InterpolationPolarizabilityModel(symmetry, np.zeros((3, 3)))
with pytest.raises(exception_type) as error:
Expand Down
Loading

0 comments on commit 71ba081

Please sign in to comment.