diff --git a/ramannoodle/io/io_utils.py b/ramannoodle/io/io_utils.py
index 93c03bd..ffa4528 100644
--- a/ramannoodle/io/io_utils.py
+++ b/ramannoodle/io/io_utils.py
@@ -126,20 +126,20 @@ def _read_polarizability_dataset(
         raise get_torch_missing_error()
 
     filepaths = pathify_as_list(filepaths)
-    lattices: list[NDArray[np.float64]] = []
-    atomic_numbers_list: list[list[int]] = []
+    lattice = np.zeros((3, 3))
+    atomic_numbers: list[int] = []
     positions_list: list[NDArray[np.float64]] = []
     polarizabilities: list[NDArray[np.float64]] = []
     for file_index, filepath in tqdm(list(enumerate(filepaths)), unit=" files"):
-        lattice, atomic_numbers, positions, polarizability = (
+        read_lattice, read_atomic_numbers, positions, polarizability = (
             read_structure_and_polarizability_fn(filepath)
         )
         if file_index != 0:
-            if not np.isclose(lattices[0], lattice, atol=1e-5).all():
+            if not np.isclose(lattice, read_lattice, atol=1e-5).all():
                 raise IncompatibleStructureException(
                     f"incompatible lattice: {filepath}"
                 )
-            if atomic_numbers_list[0] != atomic_numbers:
+            if atomic_numbers != read_atomic_numbers:
                 raise IncompatibleStructureException(
                     f"incompatible atomic numbers: {filepath}"
                 )
@@ -147,14 +147,14 @@ def _read_polarizability_dataset(
                 raise IncompatibleStructureException(
                     f"incompatible atomic positions: {filepath}"
                 )
-        lattices.append(lattice)
-        atomic_numbers_list.append(atomic_numbers)
+        lattice = read_lattice
+        atomic_numbers = read_atomic_numbers
         positions_list.append(positions)
         polarizabilities.append(polarizability)
 
     return PolarizabilityDataset(
-        np.array(lattices),
-        atomic_numbers_list,
+        lattice,
+        atomic_numbers,
         np.array(positions_list),
         np.array(polarizabilities),
         scale_mode="standard",
diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py
index 659664f..6822c1b 100644
--- a/ramannoodle/polarizability/torch/dataset.py
+++ b/ramannoodle/polarizability/torch/dataset.py
@@ -1,14 +1,11 @@
 """Polarizability PyTorch dataset."""
 
-import copy
-
 import numpy as np
 from numpy.typing import NDArray
 
 from ramannoodle.exceptions import (
     verify_ndarray_shape,
     verify_list_len,
-    get_type_error,
     get_torch_missing_error,
 )
 
@@ -20,8 +17,6 @@
 except ModuleNotFoundError as exc:
     raise get_torch_missing_error() from exc
 
-TORCH_PRESENT = True
-
 
 def _scale_and_flatten_polarizabilities(
     polarizabilities: Tensor,
@@ -83,12 +78,12 @@ class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
 
     Parameters
    ----------
-    lattices
-        | (Å) 3D array with shape (S,3,3) where S is the number of samples.
+    lattice
+        | (Å) Array with shape (3,3).
     atomic_numbers
-        | List of length S containing lists of length N where N is the number of atoms.
+        | List of length N where N is the number of atoms.
     positions
-        | (fractional) 3D array with shape (S,N,3).
+        | (fractional) 3D array with shape (S,N,3) where S is the number of samples.
     polarizabilities
         | 3D array with shape (S,3,3).
     scale_mode
@@ -99,35 +94,27 @@ class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
 
     def __init__(  # pylint: disable=too-many-arguments
         self,
-        lattices: NDArray[np.float64],
-        atomic_numbers: list[list[int]],
+        lattice: NDArray[np.float64],
+        atomic_numbers: list[int],
         positions: NDArray[np.float64],
         polarizabilities: NDArray[np.float64],
         scale_mode: str = "standard",
     ):
-        verify_ndarray_shape("lattices", lattices, (None, 3, 3))
-        num_samples = lattices.shape[0]
-        verify_list_len("atomic_numbers", atomic_numbers, num_samples)
-        num_atoms = None
-        for i, sublist in enumerate(atomic_numbers):
-            verify_list_len(f"atomic_numbers[{i}]", sublist, num_atoms)
-            if num_atoms is None:
-                num_atoms = len(sublist)
-        verify_ndarray_shape("positions", positions, (num_samples, num_atoms, 3))
-        verify_ndarray_shape(
-            "polarizabilities", polarizabilities, (num_samples, None, None)
-        )
+        # Validate parameter shapes
+        verify_ndarray_shape("lattice", lattice, (3, 3))
+        verify_list_len("atomic_numbers", atomic_numbers, None)
+        num_atoms = len(atomic_numbers)
+        verify_ndarray_shape("positions", positions, (None, num_atoms, 3))
+        num_samples = positions.shape[0]
+        verify_ndarray_shape("polarizabilities", polarizabilities, (num_samples, 3, 3))
 
         default_type = torch.get_default_dtype()
-        self._lattices = torch.from_numpy(lattices).type(default_type)
-        try:
-            self._atomic_numbers = torch.tensor(atomic_numbers).type(torch.int)
-        except (TypeError, ValueError) as exc:
-            raise get_type_error(
-                "atomic_numbers", atomic_numbers, "list[list[int]]"
-            ) from exc
-        self._positions = torch.from_numpy(positions).type(default_type)
-        self._polarizabilities = torch.from_numpy(polarizabilities)
+        self._lattices = torch.tensor(lattice).type(default_type).unsqueeze(0)
+        self._lattices = self._lattices.expand(num_samples, 3, 3)
+        self._atomic_numbers = torch.tensor(atomic_numbers).type(torch.int).unsqueeze(0)
+        self._atomic_numbers = self._atomic_numbers.expand(num_samples, num_atoms)
+        self._positions = torch.tensor(positions).type(default_type)
+        self._polarizabilities = torch.tensor(polarizabilities)
 
         _, _, scaled = _scale_and_flatten_polarizabilities(
             self._polarizabilities, scale_mode=scale_mode
@@ -145,74 +132,76 @@ def num_samples(self) -> int:
         return self._positions.size(0)
 
     @property
-    def atomic_numbers(self) -> Tensor:
+    def atomic_numbers(self) -> list[int]:
         """Get (a copy of) atomic numbers.
 
         Returns
         -------
         :
-            2D tensor with size [S,N] where S is the number of samples and N is the
-            number of atoms.
+            List of length N where N is the number of atoms.
         """
-        return copy.copy(self._atomic_numbers)
+        return [int(n) for n in self._atomic_numbers[0]]
 
     @property
-    def positions(self) -> Tensor:
+    def positions(self) -> NDArray[np.float64]:
         """Get (a copy of) positions.
 
         Returns
         -------
         :
-            3D tensor with size [S,N,3] where S is the number of samples and N is the
+            Array with shape (S,N,3) where S is the number of samples and N is the
             number of atoms.
         """
-        return self._positions.detach().clone()
+        return self._positions.detach().clone().numpy()
 
     @property
-    def polarizabilities(self) -> Tensor:
+    def polarizabilities(self) -> NDArray[np.float64]:
         """Get (a copy of) polarizabilities.
 
         Returns
         -------
         :
-            3D tensor with size [S,3,3] where S is the number of samples.
+            3D array with shape (S,3,3) where S is the number of samples.
         """
-        return self._polarizabilities.detach().clone()
+        return self._polarizabilities.detach().clone().numpy()
 
     @property
-    def scaled_polarizabilities(self) -> Tensor:
+    def scaled_polarizabilities(self) -> NDArray[np.float64]:
         """Get (a copy of) scaled polarizabilities.
 
         Returns
         -------
         :
-            2D tensor with size [S,6] where S is the number of samples.
+            2D array with shape (S,6) where S is the number of samples.
         """
-        return self._scaled_polarizabilities.detach().clone()
+        return self._scaled_polarizabilities.detach().clone().numpy()
 
     @property
-    def mean_polarizability(self) -> Tensor:
+    def mean_polarizability(self) -> NDArray[np.float64]:
         """Get mean polarizability.
 
         Return
         ------
         :
-            2D tensor with size [3,3].
+            2D array with shape (3,3).
         """
-        return self._polarizabilities.mean(0, keepdim=True)
+        return self._polarizabilities.mean(0, keepdim=True).clone().numpy()
 
     @property
-    def stddev_polarizability(self) -> Tensor:
+    def stddev_polarizability(self) -> NDArray[np.float64]:
         """Get standard deviation of polarizabilities.
 
         Return
         ------
         :
-            2D tensor with size [3,3].
+            2D array with shape (3,3).
         """
-        return self._polarizabilities.std(0, unbiased=False, keepdim=True)
+        result = self._polarizabilities.std(0, unbiased=False, keepdim=True)
+        return result.clone().numpy()
 
-    def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:
+    def scale_polarizabilities(
+        self, mean: NDArray[np.float64], stddev: NDArray[np.float64]
+    ) -> None:
         """Standard-scale polarizabilities given a mean and standard deviation.
 
         This method may be used to scale validation or test datasets according
@@ -221,31 +210,19 @@ def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None:
         Parameters
         ----------
         mean
-            | 2D tensor with size [3,3] or 1D tensor.
+            | Array with shape (3,3).
         stddev
-            | 2D tensor with size [3,3] or 1D tensor.
+            | Array with shape (3,3).
 
         """
+        verify_ndarray_shape("mean", mean, (3, 3))
+        verify_ndarray_shape("stddev", stddev, (3, 3))
+
         _, _, scaled = _scale_and_flatten_polarizabilities(
             self._polarizabilities, scale_mode="none"
         )
-        try:
-            scaled = self._polarizabilities - mean
-        except TypeError as exc:
-            raise get_type_error("mean", mean, "Tensor") from exc
-        except RuntimeError as exc:
-            raise rn_torch_utils.get_tensor_size_error(
-                "mean", mean, "[3,3] or [1]"
-            ) from exc
-        try:
-            scaled /= stddev
-        except TypeError as exc:
-            raise get_type_error("stddev", stddev, "Tensor") from exc
-        except RuntimeError as exc:
-            raise rn_torch_utils.get_tensor_size_error(
-                "stddev", stddev, "[3,3] or [1]"
-            ) from exc
-
+        scaled = self._polarizabilities - torch.tensor(mean)
+        scaled /= stddev
         self._scaled_polarizabilities = scaled
 
     def __len__(self) -> int:
diff --git a/test/tests/torch/test_dataset.py b/test/tests/torch/test_dataset.py
index c9133e3..4281845 100644
--- a/test/tests/torch/test_dataset.py
+++ b/test/tests/torch/test_dataset.py
@@ -1,8 +1,14 @@
 """Testing for PyTorch dataset."""
 
+from typing import Type
+
+import numpy as np
+from numpy.typing import NDArray
+
 import pytest
 
 import ramannoodle.io.generic as generic_io
+from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset
 
 
 @pytest.mark.parametrize(
@@ -28,3 +34,43 @@ def test_load_polarizability_dataset(
         assert len(dataset) == len(filepaths)
     else:
         assert len(dataset) == 1
+
+    _ = dataset.atomic_numbers
+    _ = dataset.mean_polarizability
+    _ = dataset.num_atoms
+    _ = dataset.num_samples
+    _ = dataset.polarizabilities
+    _ = dataset.positions
+
+
+@pytest.mark.parametrize(
+    "lattice, atomic_numbers, positions, polarizabilities, scale_mode, exception_type,"
+    "in_reason",
+    [
+        (
+            np.zeros((3, 3)),
+            [1, 2],
+            np.random.random((2, 2, 3)),
+            np.random.random((2, 3, 3)),
+            "invalid_scale_mode",
+            ValueError,
+            "unsupported scale mode: invalid_scale_mode",
+        ),
+    ],
+)
+def test_polarizability_dataset_exception(  # pylint: disable=too-many-arguments
+    lattice: NDArray[np.float64],
+    atomic_numbers: list[int],
+    positions: NDArray[np.float64],
+    polarizabilities: NDArray[np.float64],
+    scale_mode: str,
+    exception_type: Type[Exception],
+    in_reason: str,
+) -> None:
+    """Test polarizability dataset (exception)."""
+    with pytest.raises(exception_type) as error:
+        PolarizabilityDataset(
+            lattice, atomic_numbers, positions, polarizabilities, scale_mode
+        )
+
+    assert in_reason in str(error.value)