Skip to content

Commit

Permalink
added masking, fixed artmodel orthogonal bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Aug 10, 2024
1 parent cddc07d commit d6211ec
Show file tree
Hide file tree
Showing 17 changed files with 597 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ Contents
:caption: Contents:

introduction
notebooks/Basic tutorial
tutorials
generated/api
dev guide
2 changes: 1 addition & 1 deletion docs/source/notebooks/Basic tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@
"\n",
"\n",
"for axis_i, (i,j) in enumerate([(0,0),(1,1),(2,2),(0,1),(0,2),(1,2)]):\n",
" axis = axes[axis_i]\n",
" axis = axes[axis_i] # type: ignore\n",
" for interp,sign,color in zip(target_interps,signs,['pink', 'red', 'orange']):\n",
" if sign < 0:\n",
" print('negative sign')\n",
Expand Down
306 changes: 306 additions & 0 deletions docs/source/notebooks/Masking.ipynb

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions docs/source/tutorials.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Tutorials
=========

.. toctree::
:maxdepth: 2
:caption: Contents:

notebooks/Basic tutorial
notebooks/Masking
2 changes: 2 additions & 0 deletions ramannoodle/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@
"Og": 118,
}

ATOM_SYMBOLS = {b: a for a, b in ATOMIC_NUMBERS.items()}

RAMAN_TENSOR_CENTRAL_DIFFERENCE = 0.001
BOLTZMANN_CONSTANT = 8.617333262e-5 # Units: eV/K

Expand Down
42 changes: 36 additions & 6 deletions ramannoodle/polarizability/art.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,44 @@ def get_specification_tuples(
)
return specification_tuples

def get_dof_indexes(
self, atom_indexes_or_symbols: int | str | list[int | str]
) -> list[int]:
"""Return art (DOF) indexes for certain atoms.
Parameters
----------
atom_indexes_or_symbols
If integer or list of integers, specifies atom indexes. If string or list
of strings, specifies atom symbols. Mixtures of integers and strings are
allowed.
"""
if not isinstance(atom_indexes_or_symbols, list):
atom_indexes_or_symbols = list([atom_indexes_or_symbols])

atom_indexes = []
for item in atom_indexes_or_symbols:
if isinstance(item, str):
atom_indexes += self._structural_symmetry.get_atom_indexes(item)
else:
atom_indexes += [item]
atom_indexes = list(set(atom_indexes))

dof_indexes = []
for atom_index in atom_indexes:
for index, basis_vector in enumerate(self._cartesian_basis_vectors):
direction = basis_vector[atom_index]
if not np.isclose(direction, 0, atol=1e-5).all():
dof_indexes.append(index)
return dof_indexes

def _get_art_directions(self, atom_index: int) -> list[NDArray[np.float64]]:
"""Return specified art direction vectors for an atom."""
directions = []
for basis_vector in self._cartesian_basis_vectors:
direction = basis_vector[atom_index]
if not np.isclose(direction, 0, atol=1e-5).all():
directions.append(direction)
assert len(directions) <= 3
indexes = self.get_dof_indexes(atom_index)
directions = [
self._cartesian_basis_vectors[index][atom_index] for index in indexes
]
return directions

def __repr__(self) -> str:
Expand Down
73 changes: 58 additions & 15 deletions ramannoodle/polarizability/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Polarizability models."""

from pathlib import Path
from typing import Self
import copy

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -65,6 +67,7 @@ def __init__(
self._equilibrium_polarizability = equilibrium_polarizability
self._cartesian_basis_vectors: list[NDArray[np.float64]] = []
self._interpolations: list[BSpline] = []
self._mask: NDArray[np.bool] = np.array([], dtype="bool")

def get_polarizability(
self, cartesian_displacement: NDArray[np.float64]
Expand All @@ -83,8 +86,8 @@ def get_polarizability(
"""
delta_polarizability: NDArray[np.float64] = np.zeros((3, 3))
for basis_vector, interpolation in zip(
self._cartesian_basis_vectors, self._interpolations
for basis_vector, interpolation, mask in zip(
self._cartesian_basis_vectors, self._interpolations, self._mask
):
try:
amplitude = np.dot(
Expand All @@ -97,7 +100,9 @@ def get_polarizability(
"cartesian_displacement has incompatible length "
f"({len(cartesian_displacement)}!={len(basis_vector)})"
) from exc
delta_polarizability += interpolation(amplitude)
delta_polarizability += mask * np.array(
interpolation(amplitude), dtype="float64"
)

return delta_polarizability + self._equilibrium_polarizability

Expand Down Expand Up @@ -126,6 +131,18 @@ def _get_dof( # pylint: disable=too-many-locals
:
3-tuple of the form (basis vectors, interpolation_xs, interpolation_ys)
"""
# Check that the parent displacement is orthogonal to existing basis vectors
parent_cartesian_basis_vector = (
self._structural_symmetry.get_cartesian_displacement(parent_displacement)
)
result = is_orthogonal_to_all(
parent_cartesian_basis_vector, self._cartesian_basis_vectors
)
if result != -1:
raise InvalidDOFException(
f"new dof is not orthogonal with existing dof (index={result})"
)

displacements_and_transformations = (
self._structural_symmetry.get_equivalent_displacements(parent_displacement)
)
Expand Down Expand Up @@ -224,6 +241,7 @@ def _construct_and_add_interpolations(

self._cartesian_basis_vectors += basis_vectors_to_add
self._interpolations += interpolations_to_add
self._mask = np.append(self._mask, [True] * len(basis_vectors_to_add))

def add_dof( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -275,18 +293,6 @@ def add_dof( # pylint: disable=too-many-arguments
"polarizabilities", polarizabilities, (len(amplitudes), 3, 3)
)

# Check that the parent displacement is orthogonal to existing basis vectors
parent_cartesian_basis_vector = (
self._structural_symmetry.get_cartesian_displacement(parent_displacement)
)
result = is_orthogonal_to_all(
parent_cartesian_basis_vector, self._cartesian_basis_vectors
)
if result != -1:
raise InvalidDOFException(
f"new dof is not orthogonal with existing dof (index={result})"
)

# Get information needed for DOF
basis_vectors_to_add, interpolation_xs, interpolation_ys = self._get_dof(
parent_displacement,
Expand Down Expand Up @@ -398,6 +404,43 @@ def _read_dof(
np.array(polarizabilities),
)

def get_mask(self) -> NDArray[np.bool]:
"""Return mask."""
return self._mask

def set_mask(self, mask: NDArray[np.bool]) -> None:
"""Set mask.
..warning:: To avoid unintentional use of masked models, we discourage masking
in-place. Instead, consider using `get masked_model`.
Parameters
----------
mask
1D array of size (N,) where N is the number of specified degrees
of freedom (DOFs). If an element is False, its corresponding DOF will be
"masked" and therefore excluded from polarizability calculations.
"""
verify_ndarray_shape("mask", mask, self._mask.shape)
self._mask = mask

def get_masked_model(self, dof_indexes_to_mask: list[int]) -> Self:
"""Return new model with certain degrees of freedom deactivated.
Model masking allows for the calculation of partial Raman spectra in which only
certain degrees of freedom are considered.
"""
result = copy.deepcopy(self)
new_mask = result.get_mask()
new_mask[:] = True
new_mask[dof_indexes_to_mask] = False
result.set_mask(new_mask)
return result

def unmask(self) -> None:
"""Clear mask, activating all specified DOFs."""
self._mask[:] = True

def __repr__(self) -> str:
"""Return string representation."""
total_dofs = 3 * len(self._structural_symmetry.get_fractional_positions())
Expand Down
25 changes: 24 additions & 1 deletion ramannoodle/symmetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from numpy.typing import NDArray
import spglib

from ..globals import ATOM_SYMBOLS
from . import symmetry_utils
from ..exceptions import SymmetryException, verify_ndarray_shape
from ..exceptions import SymmetryException, verify_ndarray_shape, get_type_error


class StructuralSymmetry:
Expand Down Expand Up @@ -182,3 +183,25 @@ def get_cartesian_displacement(
def get_fractional_positions(self) -> NDArray[np.float64]:
"""Return fractional positions."""
return self._fractional_positions

def get_atom_indexes(self, atom_symbols: str | list[str]) -> list[int]:
"""Return atom indexes with matching symbols.
Parameters
----------
atom_symbols
If integer or list of integers, specifies atom indexes. If string or list
of strings, specifies atom symbols. Mixtures of integers and strings are
allowed.
"""
symbols = [ATOM_SYMBOLS[number] for number in self._atomic_numbers]
indexes = []
if isinstance(atom_symbols, str):
atom_symbols = [atom_symbols]
try:
for index, symbol in enumerate(symbols):
if symbol in atom_symbols:
indexes.append(index)
except TypeError as err:
raise get_type_error("atom_symbols", atom_symbols, "list") from err
return indexes
Binary file not shown.
Binary file added test/data/TiO2/known_art_O_spectrum.npz
Binary file not shown.
Binary file not shown.
Binary file added test/data/TiO2/known_art_Ti_spectrum.npz
Binary file not shown.
Binary file not shown.
48 changes: 36 additions & 12 deletions test/tests/test_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,71 +143,95 @@ def test_add_art_exception(


@pytest.mark.parametrize(
"outcar_symmetry_fixture,outcar_files,exception_type,in_reason",
"outcar_symmetry_fixture,outcar_file_groups,exception_type,in_reason",
[
(
"test/data/STO_RATTLED_OUTCAR",
["test/data/TiO2/Ti5_0.1x_eps_OUTCAR"],
[["test/data/TiO2/Ti5_0.1x_eps_OUTCAR"]],
InvalidDOFException,
"incompatible outcar",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_0.2x_eps_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_0.2x_eps_OUTCAR",
]
],
InvalidDOFException,
"wrong number of amplitudes: 4 != 2",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_m0.1x_eps_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_m0.1x_eps_OUTCAR",
]
],
InvalidDOFException,
"wrong number of amplitudes: 4 != 2",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
"this_outcar_does_not_exist",
[
"this_outcar_does_not_exist",
]
],
FileNotFoundError,
"No such file or directory",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_0.1y_eps_OUTCAR",
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
"test/data/TiO2/Ti5_0.1y_eps_OUTCAR",
]
],
InvalidDOFException,
"is not collinear",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
"test/data/TiO2/O43_0.1z_eps_OUTCAR",
[
"test/data/TiO2/O43_0.1z_eps_OUTCAR",
]
],
InvalidDOFException,
"wrong number of amplitudes: 1 != 2",
),
(
"test/data/TiO2/phonons_OUTCAR",
[
[
"test/data/TiO2/Ti5_0.1x_eps_OUTCAR",
],
[
"test/data/TiO2/Ti5_0.1y_eps_OUTCAR",
],
],
InvalidDOFException,
"is not orthogonal",
),
],
indirect=["outcar_symmetry_fixture"],
)
def test_add_art_from_files_exception(
outcar_symmetry_fixture: StructuralSymmetry,
outcar_files: list[str],
outcar_file_groups: list[str],
exception_type: Type[Exception],
in_reason: str,
) -> None:
"""Test add_dof_from_files (exception)."""
symmetry = outcar_symmetry_fixture
model = ARTModel(symmetry, np.zeros((3, 3)))
with pytest.raises(exception_type) as error:
model.add_art_from_files(outcar_files, "outcar")
for outcar_files in outcar_file_groups:
model.add_art_from_files(outcar_files, "outcar")
assert in_reason in str(error.value)


Expand Down
Loading

0 comments on commit d6211ec

Please sign in to comment.