Skip to content

Commit

Permalink
dataset now works with single lattice, atomic_numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 19, 2024
1 parent de9a8e9 commit 3f500d0
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 80 deletions.
18 changes: 9 additions & 9 deletions ramannoodle/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,35 +126,35 @@ def _read_polarizability_dataset(
raise get_torch_missing_error()
filepaths = pathify_as_list(filepaths)

lattices: list[NDArray[np.float64]] = []
atomic_numbers_list: list[list[int]] = []
lattice = np.zeros((3, 3))
atomic_numbers: list[int] = []
positions_list: list[NDArray[np.float64]] = []
polarizabilities: list[NDArray[np.float64]] = []
for file_index, filepath in tqdm(list(enumerate(filepaths)), unit=" files"):
lattice, atomic_numbers, positions, polarizability = (
read_lattice, read_atomic_numbers, positions, polarizability = (
read_structure_and_polarizability_fn(filepath)
)
if file_index != 0:
if not np.isclose(lattices[0], lattice, atol=1e-5).all():
if not np.isclose(lattice, read_lattice, atol=1e-5).all():
raise IncompatibleStructureException(
f"incompatible lattice: {filepath}"
)
if atomic_numbers_list[0] != atomic_numbers:
if atomic_numbers != read_atomic_numbers:
raise IncompatibleStructureException(
f"incompatible atomic numbers: {filepath}"
)
if positions_list[0].shape != positions.shape: # check, just to be safe
raise IncompatibleStructureException(
f"incompatible atomic positions: {filepath}"
)
lattices.append(lattice)
atomic_numbers_list.append(atomic_numbers)
lattice = read_lattice
atomic_numbers = read_atomic_numbers
positions_list.append(positions)
polarizabilities.append(polarizability)

return PolarizabilityDataset(
np.array(lattices),
atomic_numbers_list,
lattice,
atomic_numbers,
np.array(positions_list),
np.array(polarizabilities),
scale_mode="standard",
Expand Down
119 changes: 48 additions & 71 deletions ramannoodle/polarizability/torch/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Polarizability PyTorch dataset."""

import copy

import numpy as np
from numpy.typing import NDArray

from ramannoodle.exceptions import (
verify_ndarray_shape,
verify_list_len,
get_type_error,
get_torch_missing_error,
)

Expand All @@ -20,8 +17,6 @@
except ModuleNotFoundError as exc:
raise get_torch_missing_error() from exc

TORCH_PRESENT = True


def _scale_and_flatten_polarizabilities(
polarizabilities: Tensor,
Expand Down Expand Up @@ -83,12 +78,12 @@ class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
Parameters
----------
lattices
| (Å) 3D array with shape (S,3,3) where S is the number of samples.
lattice
| (Å) Array with shape (3,3).
atomic_numbers
| List of length S containing lists of length N where N is the number of atoms.
| List of length N where N is the number of atoms.
positions
| (fractional) 3D array with shape (S,N,3).
| (fractional) 3D array with shape (S,N,3) where S is the number of samples.
polarizabilities
| 3D array with shape (S,3,3).
scale_mode
Expand All @@ -99,35 +94,27 @@ class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):

def __init__( # pylint: disable=too-many-arguments
self,
lattices: NDArray[np.float64],
atomic_numbers: list[list[int]],
lattice: NDArray[np.float64],
atomic_numbers: list[int],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str = "standard",
):
verify_ndarray_shape("lattices", lattices, (None, 3, 3))
num_samples = lattices.shape[0]
verify_list_len("atomic_numbers", atomic_numbers, num_samples)
num_atoms = None
for i, sublist in enumerate(atomic_numbers):
verify_list_len(f"atomic_numbers[{i}]", sublist, num_atoms)
if num_atoms is None:
num_atoms = len(sublist)
verify_ndarray_shape("positions", positions, (num_samples, num_atoms, 3))
verify_ndarray_shape(
"polarizabilities", polarizabilities, (num_samples, None, None)
)
# Validate parameter shapes
verify_ndarray_shape("lattice", lattice, (3, 3))
verify_list_len("atomic_numbers", atomic_numbers, None)
num_atoms = len(atomic_numbers)
verify_ndarray_shape("positions", positions, (None, num_atoms, 3))
num_samples = positions.shape[0]
verify_ndarray_shape("polarizabilities", polarizabilities, (num_samples, 3, 3))

default_type = torch.get_default_dtype()
self._lattices = torch.from_numpy(lattices).type(default_type)
try:
self._atomic_numbers = torch.tensor(atomic_numbers).type(torch.int)
except (TypeError, ValueError) as exc:
raise get_type_error(
"atomic_numbers", atomic_numbers, "list[list[int]]"
) from exc
self._positions = torch.from_numpy(positions).type(default_type)
self._polarizabilities = torch.from_numpy(polarizabilities)
self._lattices = torch.tensor(lattice).type(default_type).unsqueeze(0)
self._lattices = self._lattices.expand(num_samples, 3, 3)
self._atomic_numbers = torch.tensor(atomic_numbers).type(torch.int).unsqueeze(0)
self._atomic_numbers = self._atomic_numbers.expand(num_samples, num_atoms)
self._positions = torch.tensor(positions).type(default_type)
self._polarizabilities = torch.tensor(polarizabilities)

_, _, scaled = _scale_and_flatten_polarizabilities(
self._polarizabilities, scale_mode=scale_mode
Expand All @@ -145,74 +132,76 @@ def num_samples(self) -> int:
return self._positions.size(0)

@property
def atomic_numbers(self) -> Tensor:
def atomic_numbers(self) -> list[int]:
"""Get (a copy of) atomic numbers.
Returns
-------
:
2D tensor with size [S,N] where S is the number of samples and N is the
number of atoms.
List of length N where N is the number of atoms.
"""
return copy.copy(self._atomic_numbers)
return [int(n) for n in self._atomic_numbers[0]]

@property
def positions(self) -> Tensor:
def positions(self) -> NDArray[np.float64]:
"""Get (a copy of) positions.
Returns
-------
:
3D tensor with size [S,N,3] where S is the number of samples and N is the
Array with shape (S,N,3) where S is the number of samples and N is the
number of atoms.
"""
return self._positions.detach().clone()
return self._positions.detach().clone().numpy()

@property
def polarizabilities(self) -> Tensor:
def polarizabilities(self) -> NDArray[np.float64]:
"""Get (a copy of) polarizabilities.
Returns
-------
:
3D tensor with size [S,3,3] where S is the number of samples.
3D array with shape (S,3,3) where S is the number of samples.
"""
return self._polarizabilities.detach().clone()
return self._polarizabilities.detach().clone().numpy()

@property
def scaled_polarizabilities(self) -> Tensor:
def scaled_polarizabilities(self) -> NDArray[np.float64]:
"""Get (a copy of) scaled polarizabilities.
Returns
-------
:
2D tensor with size [S,6] where S is the number of samples.
2D array with shape (S,6) where S is the number of samples.
"""
return self._scaled_polarizabilities.detach().clone()
return self._scaled_polarizabilities.detach().clone().numpy()

@property
def mean_polarizability(self) -> Tensor:
def mean_polarizability(self) -> NDArray[np.float64]:
"""Get mean polarizability.
Return
------
:
2D tensor with size [3,3].
2D array with shape (3,3).
"""
return self._polarizabilities.mean(0, keepdim=True)
return self._polarizabilities.mean(0, keepdim=True).clone().numpy()

@property
def stddev_polarizability(self) -> Tensor:
def stddev_polarizability(self) -> NDArray[np.float64]:
"""Get standard deviation of polarizabilities.
Return
------
:
2D tensor with size [3,3].
2D array with shape (3,3).
"""
return self._polarizabilities.std(0, unbiased=False, keepdim=True)
result = self._polarizabilities.std(0, unbiased=False, keepdim=True)
return result.clone().numpy()

def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:
def scale_polarizabilities(
self, mean: NDArray[np.float64], stddev: NDArray[np.float64]
) -> None:
"""Standard-scale polarizabilities given a mean and standard deviation.
This method may be used to scale validation or test datasets according
Expand All @@ -221,31 +210,19 @@ def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:
Parameters
----------
mean
| 2D tensor with size [3,3] or 1D tensor.
| Array with shape (3,3).
stddev
| 2D tensor with size [3,3] or 1D tensor.
| Array with shape (3,3).
"""
verify_ndarray_shape("mean", mean, (3, 3))
verify_ndarray_shape("mean", stddev, (3, 3))

_, _, scaled = _scale_and_flatten_polarizabilities(
self._polarizabilities, scale_mode="none"
)
try:
scaled = self._polarizabilities - mean
except TypeError as exc:
raise get_type_error("mean", mean, "Tensor") from exc
except RuntimeError as exc:
raise rn_torch_utils.get_tensor_size_error(
"mean", mean, "[3,3] or [1]"
) from exc
try:
scaled /= stddev
except TypeError as exc:
raise get_type_error("stddev", stddev, "Tensor") from exc
except RuntimeError as exc:
raise rn_torch_utils.get_tensor_size_error(
"stddev", stddev, "[3,3] or [1]"
) from exc

scaled = self._polarizabilities - torch.tensor(mean)
scaled /= stddev
self._scaled_polarizabilities = scaled

def __len__(self) -> int:
Expand Down
46 changes: 46 additions & 0 deletions test/tests/torch/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Testing for PyTorch dataset."""

from typing import Type

import numpy as np
from numpy.typing import NDArray

import pytest

import ramannoodle.io.generic as generic_io
from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset


@pytest.mark.parametrize(
Expand All @@ -28,3 +34,43 @@ def test_load_polarizability_dataset(
assert len(dataset) == len(filepaths)
else:
assert len(dataset) == 1

_ = dataset.atomic_numbers
_ = dataset.mean_polarizability
_ = dataset.num_atoms
_ = dataset.num_samples
_ = dataset.polarizabilities
_ = dataset.positions


@pytest.mark.parametrize(
"lattice, atomic_numbers, positions, polarizabilities, scale_mode, exception_type,"
"in_reason",
[
(
np.zeros((3, 3)),
[1, 2],
np.random.random((2, 2, 3)),
np.random.random((2, 3, 3)),
"invalid_scale_mode",
ValueError,
"unsupported scale mode: invalid_scale_mode",
),
],
)
def test_polarizability_dataset_exception( # pylint: disable=too-many-arguments
lattice: NDArray[np.float64],
atomic_numbers: list[int],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale_mode: str,
exception_type: Type[Exception],
in_reason: str,
) -> None:
"""Test polarizability dataset (exception)."""
with pytest.raises(exception_type) as error:
PolarizabilityDataset(
lattice, atomic_numbers, positions, polarizabilities, scale_mode
)

assert in_reason in str(error.value)

0 comments on commit 3f500d0

Please sign in to comment.