diff --git a/deps/requirements.txt b/deps/requirements.txt index 4b08ce4..eb23a06 100644 --- a/deps/requirements.txt +++ b/deps/requirements.txt @@ -20,3 +20,4 @@ spglib >= 1.16.4;python_version=='3.12' # minimum working tabulate >= 0.8.8;python_version=='3.10' # minimum working tabulate >= 0.8.8;python_version=='3.11' # minimum working tabulate >= 0.8.8;python_version=='3.12' # minimum working +tqdm diff --git a/ramannoodle/exceptions.py b/ramannoodle/exceptions.py index 2efd177..b900b4d 100644 --- a/ramannoodle/exceptions.py +++ b/ramannoodle/exceptions.py @@ -13,6 +13,10 @@ class InvalidFileException(Exception): """File cannot be read, likely due to due to invalid or unexpected format.""" +class IncompatibleStructureException(Exception): + """Supplied file is incompatible.""" + + class InvalidDOFException(Exception): """A supplied degree of freedom is invalid.""" diff --git a/ramannoodle/io/generic.py b/ramannoodle/io/generic.py index ad300da..b3cfe3a 100644 --- a/ramannoodle/io/generic.py +++ b/ramannoodle/io/generic.py @@ -18,6 +18,7 @@ from ramannoodle.structure.reference import ReferenceStructure import ramannoodle.io.vasp as vasp_io +from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset # These map between file formats and appropriate IO functions. _PHONON_READERS = { @@ -36,6 +37,10 @@ "outcar": vasp_io.outcar.read_structure_and_polarizability, "vasprun.xml": vasp_io.vasprun.read_structure_and_polarizability, } +_POLARIZABILITY_DATASET_READERS = { + "outcar": vasp_io.outcar.read_polarizability_dataset, + "vasprun.xml": vasp_io.vasprun.read_polarizability_dataset, +} _POSITION_READERS = { "poscar": vasp_io.poscar.read_positions, "outcar": vasp_io.outcar.read_positions, @@ -181,6 +186,36 @@ def read_structure_and_polarizability( raise ValueError(f"unsupported format: {file_format}") from exc +def read_polarizability_dataset( + filepaths: str | Path | list[str] | list[Path], + file_format: str, +) -> PolarizabilityDataset: + """Read polarizability dataset from files. + + Parameters + ---------- + filepath + file_format + | Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`) + + Returns + ------- + : + + Raises + ------ + FileNotFoundError + InvalidFileException + File has an unexpected format. + IncompatibleFileException + File is incompatible with the dataset. + """ + try: + return _POLARIZABILITY_DATASET_READERS[file_format](filepaths) + except KeyError as exc: + raise ValueError(f"unsupported format: {file_format}") from exc + + def read_positions( filepath: str | Path, file_format: str, diff --git a/ramannoodle/io/io_utils.py b/ramannoodle/io/io_utils.py index 7833144..0605205 100644 --- a/ramannoodle/io/io_utils.py +++ b/ramannoodle/io/io_utils.py @@ -1,18 +1,21 @@ """Universal IO utility functions.""" -from typing import TextIO +from typing import TextIO, Callable from pathlib import Path import numpy as np from numpy.typing import NDArray +from tqdm import tqdm from ramannoodle.exceptions import ( NoMatchingLineFoundException, verify_ndarray_shape, verify_positions, verify_list_len, + IncompatibleStructureException, ) from ramannoodle.globals import ATOM_SYMBOLS +from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset def _skip_file_until_line_contains(file: TextIO, content: str) -> str: @@ -84,3 +87,66 @@ def verify_trajectory( verify_ndarray_shape("positions_ts", positions_ts, (None, len(atomic_numbers), 3)) if (0 > positions_ts).any() or (positions_ts > 1.0).any(): raise ValueError("positions_ts has coordinates that are not between 0 and 1") + + +def _read_polarizability_dataset( + filepaths: str | Path | list[str] | list[Path], + read_structure_and_polarizability_fn: Callable[ + [str | Path], + tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]], + ], +) -> PolarizabilityDataset: + """Read polarizability dataset from OUTCAR files. + + Parameters + ---------- + filepath + read_structure_and_polarizability_fn + + Returns + ------- + : + + Raises + ------ + FileNotFoundError + InvalidFileException + File has an unexpected format. + IncompatibleFileException + File is incompatible with the dataset. + """ + filepaths = pathify_as_list(filepaths) + + lattices: list[NDArray[np.float64]] = [] + atomic_numbers_list: list[list[int]] = [] + positions_list: list[NDArray[np.float64]] = [] + polarizabilities: list[NDArray[np.float64]] = [] + for file_index, filepath in tqdm(list(enumerate(filepaths)), unit="files"): + lattice, atomic_numbers, positions, polarizability = ( + read_structure_and_polarizability_fn(filepath) + ) + if file_index != 0: + if not np.isclose(lattices[0], lattice, atol=1e-5).all(): + raise IncompatibleStructureException( + f"incompatible lattice: {filepath}" + ) + if atomic_numbers_list[0] != atomic_numbers: + raise IncompatibleStructureException( + f"incompatible atomic numbers: {filepath}" + ) + if positions_list[0].shape != positions.shape: # check, just to be safe + raise IncompatibleStructureException( + f"incompatible atomic positions: {filepath}" + ) + lattices.append(lattice) + atomic_numbers_list.append(atomic_numbers) + positions_list.append(positions) + polarizabilities.append(polarizability) + + return PolarizabilityDataset( + np.array(lattices), + atomic_numbers_list, + np.array(positions_list), + np.array(polarizabilities), + scale_mode="standard", + ) diff --git a/ramannoodle/io/vasp/outcar.py b/ramannoodle/io/vasp/outcar.py index fce946b..b3fbeb8 100644 --- a/ramannoodle/io/vasp/outcar.py +++ b/ramannoodle/io/vasp/outcar.py @@ -5,14 +5,18 @@ import numpy as np from numpy.typing import NDArray - -from ramannoodle.io.io_utils import _skip_file_until_line_contains, pathify +from ramannoodle.io.io_utils import ( + _skip_file_until_line_contains, + pathify, + _read_polarizability_dataset, +) from ramannoodle.exceptions import InvalidFileException, NoMatchingLineFoundException from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS from ramannoodle.exceptions import get_type_error from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure +from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset # Utilities for OUTCAR. Warning: some of these functions partially read files. @@ -394,6 +398,30 @@ def read_structure_and_polarizability( return lattice, atomic_numbers, positions, polarizability +def read_polarizability_dataset( + filepaths: str | Path | list[str] | list[Path], +) -> PolarizabilityDataset: + """Read polarizability dataset from OUTCAR files. + + Parameters + ---------- + filepaths + + Returns + ------- + : + + Raises + ------ + FileNotFoundError + InvalidFileException + File has an unexpected format. + IncompatibleFileException + File is incompatible with the dataset. + """ + return _read_polarizability_dataset(filepaths, read_structure_and_polarizability) + + def read_ref_structure(filepath: str | Path) -> ReferenceStructure: """Read reference structure from a VASP OUTCAR file. diff --git a/ramannoodle/io/vasp/vasprun.py b/ramannoodle/io/vasp/vasprun.py index 9a63e10..742f860 100644 --- a/ramannoodle/io/vasp/vasprun.py +++ b/ramannoodle/io/vasp/vasprun.py @@ -8,12 +8,13 @@ import numpy as np from numpy.typing import NDArray -from ramannoodle.io.io_utils import pathify +from ramannoodle.io.io_utils import pathify, _read_polarizability_dataset from ramannoodle.exceptions import InvalidFileException from ramannoodle.globals import ATOMIC_WEIGHTS, ATOMIC_NUMBERS from ramannoodle.dynamics.phonon import Phonons from ramannoodle.dynamics.trajectory import Trajectory from ramannoodle.structure.reference import ReferenceStructure +from ramannoodle.polarizability.torch.dataset import PolarizabilityDataset def _get_root_element(file: TextIO) -> Element: @@ -192,6 +193,30 @@ def read_structure_and_polarizability( return lattice, atomic_numbers, positions, polarizability +def read_polarizability_dataset( + filepaths: str | Path | list[str] | list[Path], +) -> PolarizabilityDataset: + """Read polarizability dataset from OUTCAR files. + + Parameters + ---------- + filepaths + + Returns + ------- + : + + Raises + ------ + FileNotFoundError + InvalidFileException + File has an unexpected format. + IncompatibleFileException + File is incompatible with the dataset. + """ + return _read_polarizability_dataset(filepaths, read_structure_and_polarizability) + + def read_positions(filepath: str | Path) -> NDArray[np.float64]: """Read fractional positions from a vasprun.xml file. diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py index 1abd03d..b46a868 100644 --- a/ramannoodle/polarizability/torch/dataset.py +++ b/ramannoodle/polarizability/torch/dataset.py @@ -1,5 +1,7 @@ """Polarizability PyTorch dataset.""" +import copy + import numpy as np from numpy.typing import NDArray @@ -117,19 +119,88 @@ def __init__( # pylint: disable=too-many-arguments self._positions = torch.from_numpy(positions).type(default_type) self._polarizabilities = torch.from_numpy(polarizabilities) - mean, stddev, scaled = _scale_and_flatten_polarizabilities( + _, _, scaled = _scale_and_flatten_polarizabilities( self._polarizabilities, scale_mode=scale_mode ) - self._mean_polarizability = mean.type(default_type) - self._stddev_polarizability = stddev.type(default_type) self._scaled_polarizabilities = scaled.type(default_type) + @property + def num_atoms(self) -> int: + """Get number of atoms per sample.""" + return self._positions.size(1) + + @property + def num_samples(self) -> int: + """Get number of samples.""" + return self._positions.size(0) + + @property + def atomic_numbers(self) -> Tensor: + """Get (a copy of) atomic numbers. + + Returns + ------- + : + 2D tensor with size [S,N] where S is the number of samples and N is the + number of atoms. + """ + return copy.copy(self._atomic_numbers) + + @property + def positions(self) -> Tensor: + """Get (a copy of) positions. + + Returns + ------- + : + 3D tensor with size [S,N,3] where S is the number of samples and N is the + number of atoms. + """ + return self._positions.detach().clone() + + @property + def polarizabilities(self) -> Tensor: + """Get (a copy of) polarizabilities. + + Returns + ------- + : + 3D tensor with size [S,3,3] where S is the number of samples. + """ + return self._polarizabilities.detach().clone() + + @property + def scaled_polarizabilities(self) -> Tensor: + """Get (a copy of) scaled polarizabilities. + + Returns + ------- + : + 2D tensor with size [S,6] where S is the number of samples. + """ + return self._scaled_polarizabilities.detach().clone() + + @property + def mean_polarizability(self) -> Tensor: + """Get mean polarizability. + + Return + ------ + : + 2D tensor with size [3,3]. + """ + return self._polarizabilities.mean(0, keepdim=True) + + @property + def stddev_polarizability(self) -> Tensor: + """Get standard deviation of polarizability.""" + return self._polarizabilities.std(0, unbiased=False, keepdim=True) + def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None: """Standard-scale polarizabilities given a mean and standard deviation. - This method may be used to scale validation/test datasets according + This method may be used to scale validation or test datasets according to the mean and standard deviation of the training set, as is best practice. - This method does **not** update ... Parameters ---------- @@ -163,7 +234,7 @@ def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None: def __len__(self) -> int: """Get number of samples.""" - return len(self._positions) + return self.num_samples def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Get lattice, atomic numbers, positions, and scaled polarizabilities.""" diff --git a/test/Testing.ipynb b/test/Testing.ipynb index cd5b660..3f298e9 100644 --- a/test/Testing.ipynb +++ b/test/Testing.ipynb @@ -74,53 +74,266 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 99/99 [00:00<00:00, 125.21it/s]\n", - "100%|██████████| 100/100 [00:00<00:00, 114.93it/s]\n" + "100%|██████████| 99/99 [00:00<00:00, 118.96files/s]\n", + "100%|██████████| 100/100 [00:00<00:00, 117.27files/s]\n" ] } ], "source": [ - "from ramannoodle.polarizability.gnn import PolarizabilityDataset\n", - "\n", - "def read_polarizability_dataset(\n", - " filepaths: list[str] | list[Path],\n", - ") -> PolarizabilityDataset:\n", - " \"\"\"Read polarizability dataset from files.\"\"\"\n", - " filepaths = pathify_as_list(filepaths)\n", - " lattices = []\n", - " atomic_numbers_list = []\n", - " positions_list = []\n", - " polarizabilities = []\n", - " for filepath in tqdm(filepaths):\n", - " try:\n", - " lattice, atomic_numbers, positions, polarizability = (\n", - " vasp_io.outcar.read_structure_and_polarizability(filepath)\n", - " )\n", - " except InvalidFileException as exc:\n", - " raise InvalidFileException(f\"invalid file: {filepath}\") from exc\n", - "\n", - " lattices.append(lattice)\n", - " atomic_numbers_list.append(atomic_numbers)\n", - " positions_list.append(positions)\n", - " polarizabilities.append(polarizability)\n", - " return PolarizabilityDataset(\n", - " np.array(lattices),\n", - " atomic_numbers_list,\n", - " np.array(positions_list),\n", - " np.array(polarizabilities),\n", - " scale_mode = \"standard\"\n", - " )\n", + "import ramannoodle.io.vasp as vasp_io\n", "\n", "import glob\n", "\n", - "train_dataset = read_polarizability_dataset(\n", + "train_dataset = vasp_io.outcar.read_polarizability_dataset(\n", " list(glob.glob(\"/Volumes/Untitled/TiO2_eps/train/*ps*/scratch/OUTCAR\"))\n", ")\n", - "validation_dataset = read_polarizability_dataset(\n", + "validation_dataset = vasp_io.outcar.read_polarizability_dataset(\n", " list(glob.glob(\"/Volumes/Untitled/TiO2_eps/validation/*ps*/scratch/OUTCAR\"))\n", ")" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dfb65dfe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[11.37684345, 0. , 0. ],\n", + " [ 0. , 11.37684345, 0. ],\n", + " [ 0. , 0. , 9.6045742 ]]),\n", + " [22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 22,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8,\n", + " 8],\n", + " array([[0.16666667, 0.16666667, 0.5 ],\n", + " [0.16666667, 0.5 , 0.5 ],\n", + " [0.16666667, 0.83333331, 0.5 ],\n", + " [0.5 , 0.16666667, 0.5 ],\n", + " [0.5 , 0.5 , 0.5 ],\n", + " [0.5 , 0.83333331, 0.5 ],\n", + " [0.83333331, 0.16666667, 0.5 ],\n", + " [0.83333331, 0.5 , 0.5 ],\n", + " [0.83333331, 0.83333331, 0.5 ],\n", + " [0.16666667, 0. , 0.75 ],\n", + " [0.16666667, 0.33333334, 0.75 ],\n", + " [0.16666667, 0.66666669, 0.75 ],\n", + " [0.5 , 0. , 0.75 ],\n", + " [0.5 , 0.33333334, 0.75 ],\n", + " [0.5 , 0.66666669, 0.75 ],\n", + " [0.83333331, 0. , 0.75 ],\n", + " [0.83333331, 0.33333334, 0.75 ],\n", + " [0.83333331, 0.66666669, 0.75 ],\n", + " [0. , 0. , 0. ],\n", + " [0. , 0.33333334, 0. ],\n", + " [0. , 0.66666669, 0. ],\n", + " [0.33333334, 0. , 0. ],\n", + " [0.33333334, 0.33333334, 0. ],\n", + " [0.33333334, 0.66666669, 0. ],\n", + " [0.66666669, 0. , 0. ],\n", + " [0.66666669, 0.33333334, 0. ],\n", + " [0.66666669, 0.66666669, 0. ],\n", + " [0. , 0.16666667, 0.25 ],\n", + " [0. , 0.5 , 0.25 ],\n", + " [0. , 0.83333331, 0.25 ],\n", + " [0.33333334, 0.16666667, 0.25 ],\n", + " [0.33333334, 0.5 , 0.25 ],\n", + " [0.33333334, 0.83333331, 0.25 ],\n", + " [0.66666669, 0.16666667, 0.25 ],\n", + " [0.66666669, 0.5 , 0.25 ],\n", + " [0.66666669, 0.83333331, 0.25 ],\n", + " [0. , 0.16666667, 0.45789859],\n", + " [0. , 0.5 , 0.45789859],\n", + " [0. , 0.83333331, 0.45789859],\n", + " [0.33333334, 0.16666667, 0.45789859],\n", + " [0.33333334, 0.5 , 0.45789859],\n", + " [0.33333334, 0.83333331, 0.45789859],\n", + " [0.66666669, 0.16666667, 0.45789859],\n", + " [0.66666669, 0.5 , 0.45789859],\n", + " [0.66666669, 0.83333331, 0.45789859],\n", + " [0.16666667, 0.16666667, 0.70789862],\n", + " [0.16666667, 0.5 , 0.70789862],\n", + " [0.16666667, 0.83333331, 0.70789862],\n", + " [0.5 , 0.16666667, 0.70789862],\n", + " [0.5 , 0.5 , 0.70789862],\n", + " [0.5 , 0.83333331, 0.70789862],\n", + " [0.83333331, 0.16666667, 0.70789862],\n", + " [0.83333331, 0.5 , 0.70789862],\n", + " [0.83333331, 0.83333331, 0.70789862],\n", + " [0.16666667, 0. , 0.54210138],\n", + " [0.16666667, 0.33333334, 0.54210138],\n", + " [0.16666667, 0.66666669, 0.54210138],\n", + " [0.5 , 0. , 0.54210138],\n", + " [0.5 , 0.33333334, 0.54210138],\n", + " [0.5 , 0.66666669, 0.54210138],\n", + " [0.83333331, 0. , 0.54210138],\n", + " [0.83333331, 0.33333334, 0.54210138],\n", + " [0.83333331, 0.66666669, 0.54210138],\n", + " [0. , 0. , 0.79210138],\n", + " [0. , 0.33333334, 0.79210138],\n", + " [0. , 0.66666669, 0.79210138],\n", + " [0.33333334, 0. , 0.79210138],\n", + " [0.33333334, 0.33333334, 0.79210138],\n", + " [0.33333334, 0.66666669, 0.79210138],\n", + " [0.66666669, 0. , 0.79210138],\n", + " [0.66666669, 0.33333334, 0.79210138],\n", + " [0.66666669, 0.66666669, 0.79210138],\n", + " [0.16666667, 0. , 0.95789862],\n", + " [0.16666667, 0.33333334, 0.95789862],\n", + " [0.16666667, 0.66666669, 0.95789862],\n", + " [0.5 , 0. , 0.95789862],\n", + " [0.5 , 0.33333334, 0.95789862],\n", + " [0.5 , 0.66666669, 0.95789862],\n", + " [0.83333331, 0. , 0.95789862],\n", + " [0.83333331, 0.33333334, 0.95789862],\n", + " [0.83333331, 0.66666669, 0.95789862],\n", + " [0. , 0. , 0.2078986 ],\n", + " [0. , 0.33333334, 0.2078986 ],\n", + " [0. , 0.66666669, 0.2078986 ],\n", + " [0.33333334, 0. , 0.2078986 ],\n", + " [0.33333334, 0.33333334, 0.2078986 ],\n", + " [0.33333334, 0.66666669, 0.2078986 ],\n", + " [0.66666669, 0. , 0.2078986 ],\n", + " [0.66666669, 0.33333334, 0.2078986 ],\n", + " [0.66666669, 0.66666669, 0.2078986 ],\n", + " [0. , 0.16666667, 0.04210141],\n", + " [0. , 0.5 , 0.04210141],\n", + " [0. , 0.83333331, 0.04210141],\n", + " [0.33333334, 0.16666667, 0.04210141],\n", + " [0.33333334, 0.5 , 0.04210141],\n", + " [0.33333334, 0.83333331, 0.04210141],\n", + " [0.66666669, 0.16666667, 0.04210141],\n", + " [0.66666669, 0.5 , 0.04210141],\n", + " [0.66666669, 0.83333331, 0.04210141],\n", + " [0.16666667, 0.16666667, 0.29210141],\n", + " [0.16666667, 0.5 , 0.29210141],\n", + " [0.16666667, 0.83333331, 0.29210141],\n", + " [0.5 , 0.16666667, 0.29210141],\n", + " [0.5 , 0.5 , 0.29210141],\n", + " [0.5 , 0.83333331, 0.29210141],\n", + " [0.83333331, 0.16666667, 0.29210141],\n", + " [0.83333331, 0.5 , 0.29210141],\n", + " [0.83333331, 0.83333331, 0.29210141]]),\n", + " array([[ 6.95882353e+00, -5.75000000e-06, -8.96000000e-06],\n", + " [-2.90000000e-07, 6.95884757e+00, -5.35000000e-06],\n", + " [-4.13000000e-06, -6.04000000e-06, 6.37287749e+00]]))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vasp_io.vasprun.read_structure_and_polarizability(\"data/STO/vasprun.xml\")" + ] + }, { "cell_type": "code", "execution_count": 3, diff --git a/test/tests/torch/test_dataset.py b/test/tests/torch/test_dataset.py new file mode 100644 index 0000000..c9133e3 --- /dev/null +++ b/test/tests/torch/test_dataset.py @@ -0,0 +1,30 @@ +"""Testing for PyTorch dataset.""" + +import pytest + +import ramannoodle.io.generic as generic_io + + +@pytest.mark.parametrize( + "filepaths, file_format", + [ + ( + [ + "test/data/TiO2/O43_0.1x_eps_OUTCAR", + "test/data/TiO2/O43_0.1y_eps_OUTCAR", + "test/data/TiO2/O43_0.1z_eps_OUTCAR", + ], + "outcar", + ), + ("test/data/STO/vasprun.xml", "vasprun.xml"), + ], +) +def test_load_polarizability_dataset( + filepaths: str | list[str], file_format: str +) -> None: + """Test of generic load_polarizability_dataset (normal).""" + dataset = generic_io.read_polarizability_dataset(filepaths, file_format) + if isinstance(filepaths, list): + assert len(dataset) == len(filepaths) + else: + assert len(dataset) == 1