diff --git a/ramannoodle/polarizability/torch/__init__.py b/ramannoodle/polarizability/torch/__init__.py new file mode 100644 index 0000000..502e429 --- /dev/null +++ b/ramannoodle/polarizability/torch/__init__.py @@ -0,0 +1 @@ +"""Modules for polarizability model based on graph neural networks.""" diff --git a/ramannoodle/polarizability/torch/dataset.py b/ramannoodle/polarizability/torch/dataset.py new file mode 100644 index 0000000..1abd03d --- /dev/null +++ b/ramannoodle/polarizability/torch/dataset.py @@ -0,0 +1,175 @@ +"""Polarizability PyTorch dataset.""" + +import numpy as np +from numpy.typing import NDArray + +import torch +from torch import Tensor +from torch.utils.data import Dataset + +from ramannoodle.exceptions import verify_ndarray_shape, verify_list_len, get_type_error +import ramannoodle.polarizability.torch.utils as rn_torch_utils + + +def _scale_and_flatten_polarizabilities( + polarizabilities: Tensor, + scale_mode: str, +) -> tuple[Tensor, Tensor, Tensor]: + """Scale and flatten polarizabilities. + + 3x3 polarizabilities are flattened into 6-vectors: (xx,yy,zz,xy,xz,yz). + + Parameters + ---------- + polarizabilities + | 3D tensor with size [S,3,3] where S is the number of samples. + scale_mode + | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by + | standard deviation), and ``"none"`` (no scaling). + + Returns + ------- + : + 3-tuple: + 0. | mean -- + | Element-wise mean of polarizabilities. + #. | standard deviation -- + | Element-wise standard deviation of polarizabilities. + #. | polarizability vectors -- + | 2D tensor with size [S,6]. + + """ + rn_torch_utils.verify_tensor_size( + "polarizabilities", polarizabilities, [None, 3, 3] + ) + + mean = polarizabilities.mean(0, keepdim=True) + stddev = polarizabilities.std(0, unbiased=False, keepdim=True) + if scale_mode == "standard": + polarizabilities = (polarizabilities - mean) / stddev + elif scale_mode == "stddev": + polarizabilities = (polarizabilities - mean) / stddev + mean + elif scale_mode != "none": + raise ValueError(f"unsupported scale mode: {scale_mode}") + + scaled_polarizabilities = torch.zeros((polarizabilities.size(0), 6)) + scaled_polarizabilities[:, 0] = polarizabilities[:, 0, 0] + scaled_polarizabilities[:, 1] = polarizabilities[:, 1, 1] + scaled_polarizabilities[:, 2] = polarizabilities[:, 2, 2] + scaled_polarizabilities[:, 3] = polarizabilities[:, 0, 1] + scaled_polarizabilities[:, 4] = polarizabilities[:, 0, 2] + scaled_polarizabilities[:, 5] = polarizabilities[:, 1, 2] + + return mean, stddev, scaled_polarizabilities + + +class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]): + """PyTorch dataset of atomic structures and polarizabilities. + + Polarizabilities are scaled and flattened into vectors containing the six + independent tensor components. + + Parameters + ---------- + lattices + | (Å) 3D array with shape (S,3,3) where S is the number of samples. + atomic_numbers + | List of length S containing lists of length N, where N is the number of atoms. + positions + | (fractional) 3D array with shape (S,N,3). + polarizabilities + | 3D array with shape (S,3,3). + scale_mode + | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by + | standard deviation), and ``"none"`` (no scaling). + + """ + + def __init__( # pylint: disable=too-many-arguments + self, + lattices: NDArray[np.float64], + atomic_numbers: list[list[int]], + positions: NDArray[np.float64], + polarizabilities: NDArray[np.float64], + scale_mode: str = "standard", + ): + verify_ndarray_shape("lattices", lattices, (None, 3, 3)) + num_samples = lattices.shape[0] + verify_list_len("atomic_numbers", atomic_numbers, num_samples) + num_atoms = None + for i, sublist in enumerate(atomic_numbers): + verify_list_len(f"atomic_numbers[{i}]", sublist, num_atoms) + if num_atoms is None: + num_atoms = len(sublist) + verify_ndarray_shape("positions", positions, (num_samples, num_atoms, 3)) + verify_ndarray_shape( + "polarizabilities", polarizabilities, (num_samples, None, None) + ) + + default_type = torch.get_default_dtype() + self._lattices = torch.from_numpy(lattices).type(default_type) + try: + self._atomic_numbers = torch.tensor(atomic_numbers).type(torch.int) + except (TypeError, ValueError) as exc: + raise get_type_error( + "atomic_numbers", atomic_numbers, "list[list[int]]" + ) from exc + self._positions = torch.from_numpy(positions).type(default_type) + self._polarizabilities = torch.from_numpy(polarizabilities) + + mean, stddev, scaled = _scale_and_flatten_polarizabilities( + self._polarizabilities, scale_mode=scale_mode + ) + self._mean_polarizability = mean.type(default_type) + self._stddev_polarizability = stddev.type(default_type) + self._scaled_polarizabilities = scaled.type(default_type) + + def scale_polarizabilities(self, mean: Tensor, stddev: Tensor) -> None: + """Standard-scale polarizabilities given a mean and standard deviation. + + This method may be used to scale validation/test datasets according + to the mean and standard deviation of the training set, as is best practice. + This method does **not** update ... + + Parameters + ---------- + mean + | 2D tensor with size [3,3] or 1D tensor. + stddev + | 2D tensor with size [3,3] or 1D tensor. + + """ + _, _, scaled = _scale_and_flatten_polarizabilities( + self._polarizabilities, scale_mode="none" + ) + try: + scaled = self._polarizabilities - mean + except TypeError as exc: + raise get_type_error("mean", mean, "Tensor") from exc + except RuntimeError as exc: + raise rn_torch_utils.get_tensor_size_error( + "mean", mean, "[3,3] or [1]" + ) from exc + try: + scaled /= stddev + except TypeError as exc: + raise get_type_error("stddev", stddev, "Tensor") from exc + except RuntimeError as exc: + raise rn_torch_utils.get_tensor_size_error( + "stddev", stddev, "[3,3] or [1]" + ) from exc + + self._scaled_polarizabilities = scaled + + def __len__(self) -> int: + """Get number of samples.""" + return len(self._positions) + + def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Get lattice, atomic numbers, positions, and scaled polarizabilities.""" + return ( + self._lattices[i], + self._atomic_numbers[i], + self._positions[i], + self._scaled_polarizabilities[i], + ) diff --git a/ramannoodle/polarizability/gnn.py b/ramannoodle/polarizability/torch/gnn.py similarity index 60% rename from ramannoodle/polarizability/gnn.py rename to ramannoodle/polarizability/torch/gnn.py index 0b70465..f8feb8d 100644 --- a/ramannoodle/polarizability/gnn.py +++ b/ramannoodle/polarizability/torch/gnn.py @@ -1,14 +1,8 @@ """Polarizability model based on a graph neural network (GNN).""" -# pylint: disable=not-callable - from __future__ import annotations import typing -from typing import Sequence - -import numpy as np -from numpy.typing import NDArray import torch from torch import Tensor @@ -21,7 +15,6 @@ Module, LayerNorm, ) -from torch.utils.data import Dataset from torch_geometric.nn.inits import reset from torch_geometric.nn.models.dimenet import triplets @@ -29,186 +22,9 @@ from torch_geometric.utils import scatter from ramannoodle.structure.reference import ReferenceStructure -from ramannoodle.exceptions import get_type_error - - -def _get_tensor_size_str(size: Sequence[int | None]) -> str: - """Get a string representing a tensor size. - - "_" indicates a dimension can be any size. - - Parameters - ---------- - size - | None indicates dimension can be any size. - """ - result = "[" - for i in size: - if i is None: - result += "_," - else: - result += f"{i}," - if len(size) == 1: - return result + "]" - return result[:-1] + "]" - - -def _get_size_error_tensor(name: str, tensor: Tensor, desired_size: str) -> ValueError: - """Get ValueError indicating a pytorch Tensor has the wrong size.""" - try: - shape_spec = f"{_get_tensor_size_str(tensor.size())} != {desired_size}" - except AttributeError as exc: - raise get_type_error("tensor", tensor, "Tensor") from exc - return ValueError(f"{name} has wrong size: {shape_spec}") - - -def _scale_and_flatten_polarizabilities( - polarizabilities: Tensor, - scale_mode: str, -) -> tuple[Tensor, Tensor, Tensor]: - """Scale and flatten polarizabilities. - - 3x3 polarizabilities are flattened into 6-vectors: (xx,yy,zz,xy,xz,yz). +import ramannoodle.polarizability.torch.utils as rn_torch_utils - Parameters - ---------- - polarizabilities - | 3D tensor with size [S,3,3] where S is the number of samples. - scale_mode - | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by - | standard deviation), and ``"none"`` (no scaling). - - Returns - ------- - : - 3-tuple: - 0. | mean -- - | Element-wise mean of polarizabilities. - #. | standard deviation -- - | Element-wise standard deviation of polarizabilities. - #. | polarizability vectors -- - | 2D tensor with size [S,6]. - - """ - mean = polarizabilities.mean(0, keepdim=True) - stddev = polarizabilities.std(0, unbiased=False, keepdim=True) - if scale_mode == "standard": - polarizabilities = (polarizabilities - mean) / stddev - elif scale_mode == "stddev": - polarizabilities = (polarizabilities - mean) / stddev + mean - elif scale_mode != "none": - raise ValueError("invalid scale option") - - scaled_polarizabilities = torch.zeros((polarizabilities.size(0), 6)) - scaled_polarizabilities[:, 0] = polarizabilities[:, 0, 0] - scaled_polarizabilities[:, 1] = polarizabilities[:, 1, 1] - scaled_polarizabilities[:, 2] = polarizabilities[:, 2, 2] - scaled_polarizabilities[:, 3] = polarizabilities[:, 0, 1] - scaled_polarizabilities[:, 4] = polarizabilities[:, 0, 2] - scaled_polarizabilities[:, 5] = polarizabilities[:, 1, 2] - - return mean, stddev, scaled_polarizabilities - - -def _get_rotations(targets: Tensor) -> Tensor: - """Get rotation matrices from (1,0,0) to target vectors. - - Parameters - ---------- - targets - | 2D tensor with size [S,3]. Vectors do not need to be normalized. - - Returns - ------- - : - 3D tensor with size [S,3,3]. - """ - reference = torch.zeros(targets.size()) - reference[:, 0] = 1 - - a = reference / torch.linalg.norm(reference, dim=1).view(-1, 1) - b = targets / torch.linalg.norm(targets, dim=1).view(-1, 1) - - v = torch.linalg.cross(a, b) # This will be (0,0,0) if a == b - c = torch.linalg.vecdot(a, b) - s = torch.linalg.norm(v, dim=1) # This will be zero if a == b - - k_matrix = torch.zeros((len(v), 3, 3)) - k_matrix[:, 0, 1] = -v[:, 2] - k_matrix[:, 0, 2] = v[:, 1] - k_matrix[:, 1, 0] = v[:, 2] - k_matrix[:, 1, 2] = -v[:, 0] - k_matrix[:, 2, 0] = -v[:, 1] - k_matrix[:, 2, 1] = v[:, 0] - - rotations = torch.zeros((len(v), 3, 3)) - rotations[:] = torch.eye(3) # Rotation starts as identity - rotations += k_matrix - a1 = k_matrix.matmul(k_matrix) - b1 = (1 - c) / (s**2) - rotations += a1 * b1[:, None, None] - - # a==b implies s==0, which implies rotation should be identity/ - rotations[s == 0] = torch.eye(3) - - return rotations - - -class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]): - """PyTorch dataset of atomic structures and polarizabilities. - - Polarizabilities are scaled and flattened into vectors containing the six - independent tensor components. - - Parameters - ---------- - lattices - | (Å) 3D array with shape (S,3,3) where S is the number of samples. - atomic_numbers - | List of length S containing lists of length N, where N is the number of atoms. - positions - | (fractional) 3D array with shape (S,N,3). - polarizabilities - | 3D array with shape (S,3,3). - scale_mode - | Supports ``"standard"`` (standard scaling), ``"stddev"`` (division by - | standard deviation), and ``"none"`` (no scaling). - - """ - - def __init__( # pylint: disable=too-many-arguments - self, - lattices: NDArray[np.float64], - atomic_numbers: list[list[int]], - positions: NDArray[np.float64], - polarizabilities: NDArray[np.float64], - scale_mode: str = "standard", - ): - default_type = torch.get_default_dtype() - - self._lattices = torch.from_numpy(lattices).type(default_type) - self._atomic_numbers = torch.tensor(atomic_numbers) - self._positions = torch.from_numpy(positions).type(default_type) - - mean, stddev, scaled = _scale_and_flatten_polarizabilities( - torch.from_numpy(polarizabilities), scale_mode=scale_mode - ) - self._mean_polarizability = mean.type(default_type) - self._stddev_polarizability = stddev.type(default_type) - self._polarizabilities = scaled.type(default_type) - - def __len__(self) -> int: - """Get length.""" - return len(self._positions) - - def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Get lattice, atomic numbers, positions, and polarizabilities.""" - return ( - self._lattices[i], - self._atomic_numbers[i], - self._positions[i], - self._polarizabilities[i], - ) +# pylint: disable=not-callable class GaussianFilter(torch.nn.Module): @@ -220,9 +36,9 @@ class GaussianFilter(torch.nn.Module): Parameters ---------- start - | Lower bound for filter input. + | Lower bound. stop - | Upper bound for filter input. + | Upper bound. steps | Number of steps between start and stop. """ @@ -239,12 +55,12 @@ def forward(self, x: Tensor) -> Tensor: Parameters ---------- X - | 1D tensor with shape [D,]. Typically contains interatomic distances. + | 1D tensor with size [D,]. Typically contains interatomic distances. Returns ------- : - 2D tensor with shape [D,steps]. + 2D tensor with size [D,steps]. """ x = x.view(-1, 1) - self.offset.view(1, -1) @@ -274,7 +90,7 @@ def __init__( # Combination of two linear layers, dubbed "core" and "filter": # "filter" : c1ij W1 + b1 # "core" : c1ij W2 + b2 - self.linear_c1 = Linear( + self.c1_linear = Linear( size_node_embedding + size_edge_embedding, 2 * size_node_embedding, ) @@ -284,7 +100,7 @@ def __init__( def reset_parameters(self) -> None: """Reset model parameters.""" - self.linear_c1.reset_parameters() + self.c1_linear.reset_parameters() self.c1_norm.reset_parameters() self.final_norm.reset_parameters() @@ -308,7 +124,7 @@ def forward( 2D tensor with size [N,size_node_embedding]. """ c1 = torch.cat([node_embedding[i], edge_embedding], dim=1) - c1 = self.c1_norm(self.linear_c1(c1)) + c1 = self.c1_norm(self.c1_linear(c1)) c1_filter, c1_core = c1.chunk(2, dim=1) c1_filter = c1_filter.sigmoid() c1_core = c1_core.tanh() @@ -337,29 +153,29 @@ def __init__( # Combination of two linear layers, dubbed "core" and "filter": # "filter" : c2ij W3 + b3 # "core" : c2ij W4 + b4 - self.lin_c2 = Linear(size_node_embedding, 2 * size_edge_embedding) + self.c2_linear = Linear(size_node_embedding, 2 * size_edge_embedding) # Combination of two linear layers, dubbed "core" and "filter": # "filter" : c3ij W5 + b5 # "core" : c3ij W6 + b6 - self.lin_c3 = Linear( + self.c3_linear = Linear( 3 * size_node_embedding + 2 * size_edge_embedding, 2 * size_edge_embedding, ) - self.bn_c2 = LayerNorm(2 * size_edge_embedding) - self.bn_c3 = LayerNorm(2 * size_edge_embedding) - self.bn_c2_2 = LayerNorm(size_edge_embedding) - self.bn_c3_2 = LayerNorm(size_edge_embedding) + self.c2_norm_1 = LayerNorm(2 * size_edge_embedding) + self.c3_norm_1 = LayerNorm(2 * size_edge_embedding) + self.c2_norm_2 = LayerNorm(size_edge_embedding) + self.c3_norm_2 = LayerNorm(size_edge_embedding) def reset_parameters(self) -> None: """Reset model parameters.""" - self.lin_c2.reset_parameters() - self.lin_c3.reset_parameters() - self.bn_c2.reset_parameters() - self.bn_c3.reset_parameters() - self.bn_c2_2.reset_parameters() - self.bn_c3_2.reset_parameters() + self.c2_linear.reset_parameters() + self.c3_linear.reset_parameters() + self.c2_norm_1.reset_parameters() + self.c3_norm_1.reset_parameters() + self.c2_norm_2.reset_parameters() + self.c3_norm_2.reset_parameters() def _get_c2_embedding( self, @@ -369,11 +185,11 @@ def _get_c2_embedding( ) -> Tensor: """Get c2 embedding.""" c2 = node_embedding[i] * node_embedding[j] - c2 = self.bn_c2(self.lin_c2(c2)) + c2 = self.c2_norm_1(self.c2_linear(c2)) c2_filter, c2_core = c2.chunk(2, dim=1) c2_filter = c2_filter.sigmoid() c2_core = c2_core.tanh() - return typing.cast(Tensor, self.bn_c2_2(c2_filter * c2_core)) + return typing.cast(Tensor, self.c2_norm_2(c2_filter * c2_core)) def _get_c3_embedding( # pylint: disable=too-many-arguments self, @@ -396,7 +212,7 @@ def _get_c3_embedding( # pylint: disable=too-many-arguments ], dim=1, ) - c3 = self.bn_c3(self.lin_c3(c3)) + c3 = self.c3_norm_1(self.c3_linear(c3)) c3_filter, c3_core = c3.chunk(2, dim=1) c3_filter = c3_filter.sigmoid() c3_core = c3_core.tanh() @@ -407,7 +223,7 @@ def _get_c3_embedding( # pylint: disable=too-many-arguments dim_size=edge_embedding.size(0), reduce="sum", ) - return typing.cast(Tensor, self.bn_c3_2(c3_emb)) + return typing.cast(Tensor, self.c3_norm_2(c3_emb)) def forward( # pylint: disable=too-many-arguments self, @@ -435,55 +251,6 @@ def forward( # pylint: disable=too-many-arguments return (edge_embedding + c2_embedding + c3_embedding).tanh() -def _polarizability_tensors_to_vectors(polarizability_tensors: Tensor) -> Tensor: - """Convert polarizability vectors to tensors. - - Parameters - ---------- - polarizability_vectors - Tensor with size [S,6]. - - Returns - ------- - : - Symmetric tensor with size [S,3,3]. - """ - indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T - return polarizability_tensors[:, indices[0], indices[1]] - - -def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor: - """Convert polarizability vectors to tensors. - - Parameters - ---------- - polarizability_vectors - Tensor with size [S,6]. - - Returns - ------- - : - Symmetric tensor with size [S,3,3]. - """ - indices = torch.tensor( - [ - [0, 3, 4], - [3, 1, 5], - [4, 5, 2], - ] - ) - try: - return polarizability_vectors[:, indices] - except IndexError as exc: - raise _get_size_error_tensor( - "polarizability_vectors", polarizability_vectors, "[_,6]" - ) from exc - except TypeError as exc: - raise get_type_error( - "polarizability_vectors", polarizability_vectors, "Tensor" - ) from exc - - class PotGNN(Module): # pylint: disable = too-many-instance-attributes r"""POlarizability Tensor Graph Neural Network (PotGNN). @@ -519,16 +286,18 @@ def __init__( # pylint: disable=too-many-arguments self._cutoff = cutoff # Set up graph. - self._ref_edge_indexes, _, self._ref_distances = _radius_graph_pbc( - torch.from_numpy(ref_structure.lattice) - .unsqueeze(0) - .type(default_type) - .to(default_device), - torch.from_numpy(ref_structure.positions) - .unsqueeze(0) - .type(default_type) - .to(default_device), - cutoff, + self._ref_edge_indexes, _, self._ref_distances = ( + rn_torch_utils._radius_graph_pbc( + torch.from_numpy(ref_structure.lattice) + .unsqueeze(0) + .type(default_type) + .to(default_device), + torch.from_numpy(ref_structure.positions) + .unsqueeze(0) + .type(default_type) + .to(default_device), + cutoff, + ) ) self._num_nodes = len(ref_structure.atomic_numbers) self._num_edges = len(self._ref_edge_indexes[1]) @@ -615,17 +384,6 @@ def reset_parameters(self) -> None: edge_block.reset_parameters() reset(self._polarizability_predictor) - def _get_polarizability_tensors(self, x: Tensor) -> Tensor: - """X should have size (_,6).""" - indices = torch.tensor( - [ - [0, 3, 4], - [3, 1, 5], - [4, 5, 2], - ] - ) - return x[:, indices] - def _get_edge_polarizability_tensor( self, x: Tensor ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -708,21 +466,9 @@ def _batch_graph( cart_displacement = displacement.matmul(expanded_lattice) cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1)) - cart_unit_vectors = cart_displacement[ - edge_indexes[0], edge_indexes[1], edge_indexes[2] - ] # python 3.10 complains if we use the unpacking operator (*) - cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None] - distances = cart_distance_matrix[ - edge_indexes[0], edge_indexes[1], edge_indexes[2] - ].view(-1, 1) - # ref_distances = self._ref_distances.repeat((num_samples, 1)) - # distances -= ref_distances - - # Disconnect all S graphs - edge_indexes[1] += edge_indexes[0] * num_atoms - edge_indexes[2] += edge_indexes[0] * num_atoms - - return edge_indexes, cart_unit_vectors, distances + return rn_torch_utils.get_graph_info( + cart_displacement, edge_indexes, cart_distance_matrix, num_atoms + ) def _get_batch_triplets( self, @@ -829,7 +575,7 @@ def forward( # pylint: disable=too-many-locals t1, t2, t3, t4, t5, t6 = self._get_edge_polarizability_tensor( edge_polarizability ) - rotation = _get_rotations(unit_vec) + rotation = rn_torch_utils.get_rotations(unit_vec) inv_rotation = torch.linalg.inv(rotation) t1 = rotation @ t1 @ inv_rotation @@ -864,64 +610,3 @@ def forward( # pylint: disable=too-many-locals polarizability[structure_i] = torch.sum(edge_polarizability * mask, dim=0) polarizability[structure_i] /= count return polarizability - - -def _radius_graph_pbc( - lattice: Tensor, positions: Tensor, cutoff: float -) -> tuple[Tensor, Tensor, Tensor]: - """Generate graph for structures while respecting periodic boundary conditions. - - Parameters - ---------- - lattice - | (Å) 3D tensor with size [S,3,3] where S is the number of samples. - positions - | (fractional) 3D tensor with size [S,N,3] where N is the number of atoms. - cutoff - | Edge cutoff distance. - - Returns - ------- - : - 3-tuple. - First element is edge indexes, a tensor of size [3,X] where X is the number of - edges. This tensor defines S non-interconnected graphs making up a batch. The - first row defines the graph index. The second and third rows define the actual - edge indexes used by ``triplet``. - Second element is cartesian unit vectors, a tensor of size [X,3]. - Third element is distances, a tensor of side [X,1]. - - """ - num_samples = lattice.size(0) - num_atoms = positions.size(1) - - # Compute pairwise distance matrix. - displacement = positions.unsqueeze(1) - positions.unsqueeze(2) - displacement = torch.where( - displacement % 1 > 0.5, displacement % 1 - 1, displacement % 1 - ) - expanded_lattice = lattice[:, None, :, :].expand(-1, displacement.size(1), -1, -1) - cart_displacement = displacement.matmul(expanded_lattice) - cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1)) - - # Compute adjacency matrix - adjacency_matrix = cart_distance_matrix <= cutoff - not_self_loop = ~torch.eye(adjacency_matrix.size(-1), dtype=torch.bool).expand( - num_samples, -1, -1 - ) - adjacency_matrix = torch.logical_and(adjacency_matrix, not_self_loop) - - edge_indexes = torch.nonzero(adjacency_matrix).T - cart_unit_vectors = cart_displacement[ - edge_indexes[0], edge_indexes[1], edge_indexes[2] - ] - cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None] - distances = cart_distance_matrix[ - edge_indexes[0], edge_indexes[1], edge_indexes[2] - ].view(-1, 1) - - # Disconnect all S graphs - edge_indexes[1] += edge_indexes[0] * num_atoms - edge_indexes[2] += edge_indexes[0] * num_atoms - - return edge_indexes, cart_unit_vectors, distances diff --git a/ramannoodle/polarizability/torch/utils.py b/ramannoodle/polarizability/torch/utils.py new file mode 100644 index 0000000..f284df0 --- /dev/null +++ b/ramannoodle/polarizability/torch/utils.py @@ -0,0 +1,223 @@ +"""Utility functions for PyTorch models.""" + +from typing import Sequence + +import torch +from torch import Tensor + +from ramannoodle.exceptions import get_type_error + +# pylint complains about torch.norm +# pylint: disable=not-callable + + +def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor: + """Convert polarizability vectors to tensors. + + Parameters + ---------- + polarizability_vectors + Tensor with size [S,6]. + + Returns + ------- + : + Symmetric tensor with size [S,3,3]. + """ + indices = torch.tensor( + [ + [0, 3, 4], + [3, 1, 5], + [4, 5, 2], + ] + ) + try: + return polarizability_vectors[:, indices] + except IndexError as exc: + raise get_tensor_size_error( + "polarizability_vectors", polarizability_vectors, "[_,6]" + ) from exc + except TypeError as exc: + raise get_type_error( + "polarizability_vectors", polarizability_vectors, "Tensor" + ) from exc + + +def _get_polarizability_tensors(x: Tensor) -> Tensor: + """X should have size (_,6).""" + indices = torch.tensor( + [ + [0, 3, 4], + [3, 1, 5], + [4, 5, 2], + ] + ) + return x[:, indices] + + +def _get_tensor_size_str(size: Sequence[int | None]) -> str: + """Get a string representing a tensor size. + + "_" indicates a dimension can be any size. + + Parameters + ---------- + size + | None indicates dimension can be any size. + """ + result = "[" + for i in size: + if i is None: + result += "_," + else: + result += f"{i}," + if len(size) == 1: + return result + "]" + return result[:-1] + "]" + + +def get_tensor_size_error(name: str, tensor: Tensor, desired_size: str) -> ValueError: + """Get ValueError indicating a PyTorch Tensor has the wrong size.""" + try: + shape_spec = f"{_get_tensor_size_str(tensor.size())} != {desired_size}" + except AttributeError as exc: + raise get_type_error("tensor", tensor, "Tensor") from exc + return ValueError(f"{name} has wrong size: {shape_spec}") + + +def verify_tensor_size(name: str, tensor: Tensor, size: Sequence[int | None]) -> None: + """Verify a PyTorch Tensor's size. + + :meta private: We should avoid calling this function whenever possible (EATF). + + Parameters + ---------- + size + int elements will be checked, None elements will not be. + """ + try: + if len(size) != tensor.ndim: + raise get_tensor_size_error(name, tensor, _get_tensor_size_str(size)) + for d1, d2 in zip(tensor.size(), size, strict=True): + if d2 is not None and d1 != d2: + raise get_tensor_size_error(name, tensor, _get_tensor_size_str(size)) + except AttributeError as exc: + raise get_type_error(name, tensor, "Tensor") from exc + + +def get_rotations(targets: Tensor) -> Tensor: + """Get rotation matrices from (1,0,0) to target vectors. + + Parameters + ---------- + targets + | 2D tensor with size [S,3]. Vectors do not need to be normalized. + + Returns + ------- + : + 3D tensor with size [S,3,3]. + """ + reference = torch.zeros(targets.size()) + reference[:, 0] = 1 + + a = reference / torch.linalg.norm(reference, dim=1).view(-1, 1) + b = targets / torch.linalg.norm(targets, dim=1).view(-1, 1) + + v = torch.linalg.cross(a, b) # This will be (0,0,0) if a == b + c = torch.linalg.vecdot(a, b) + s = torch.linalg.norm(v, dim=1) # This will be zero if a == b + + k_matrix = torch.zeros((len(v), 3, 3)) + k_matrix[:, 0, 1] = -v[:, 2] + k_matrix[:, 0, 2] = v[:, 1] + k_matrix[:, 1, 0] = v[:, 2] + k_matrix[:, 1, 2] = -v[:, 0] + k_matrix[:, 2, 0] = -v[:, 1] + k_matrix[:, 2, 1] = v[:, 0] + + rotations = torch.zeros((len(v), 3, 3)) + rotations[:] = torch.eye(3) # Rotation starts as identity + rotations += k_matrix + a1 = k_matrix.matmul(k_matrix) + b1 = (1 - c) / (s**2) + rotations += a1 * b1[:, None, None] + + # a==b implies s==0, which implies rotation should be identity/ + rotations[s == 0] = torch.eye(3) + + return rotations + + +def get_graph_info( + cart_displacement: Tensor, + edge_indexes: Tensor, + cart_distance_matrix: Tensor, + num_atoms: int, +) -> tuple[Tensor, Tensor, Tensor]: + """Get information on graph.""" + cart_unit_vectors = cart_displacement[ + edge_indexes[0], edge_indexes[1], edge_indexes[2] + ] # python 3.10 complains if we use the unpacking operator (*) + cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None] + distances = cart_distance_matrix[ + edge_indexes[0], edge_indexes[1], edge_indexes[2] + ].view(-1, 1) + # ref_distances = self._ref_distances.repeat((num_samples, 1)) + # distances -= ref_distances + # Disconnect all S graphs + edge_indexes[1] += edge_indexes[0] * num_atoms + edge_indexes[2] += edge_indexes[0] * num_atoms + return edge_indexes, cart_unit_vectors, distances + + +def _radius_graph_pbc( + lattice: Tensor, positions: Tensor, cutoff: float +) -> tuple[Tensor, Tensor, Tensor]: + """Generate graph for structures while respecting periodic boundary conditions. + + Parameters + ---------- + lattice + | (Å) 3D tensor with size [S,3,3] where S is the number of samples. + positions + | (fractional) 3D tensor with size [S,N,3] where N is the number of atoms. + cutoff + | Edge cutoff distance. + + Returns + ------- + : + 3-tuple. + First element is edge indexes, a tensor of size [3,X] where X is the number of + edges. This tensor defines S non-interconnected graphs making up a batch. The + first row defines the graph index. The second and third rows define the actual + edge indexes used by ``triplet``. + Second element is cartesian unit vectors, a tensor of size [X,3]. + Third element is distances, a tensor of side [X,1]. + + """ + num_samples = lattice.size(0) + num_atoms = positions.size(1) + + # Compute pairwise distance matrix. + displacement = positions.unsqueeze(1) - positions.unsqueeze(2) + displacement = torch.where( + displacement % 1 > 0.5, displacement % 1 - 1, displacement % 1 + ) + expanded_lattice = lattice[:, None, :, :].expand(-1, displacement.size(1), -1, -1) + cart_displacement = displacement.matmul(expanded_lattice) + cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1)) + + # Compute adjacency matrix + adjacency_matrix = cart_distance_matrix <= cutoff + not_self_loop = ~torch.eye(adjacency_matrix.size(-1), dtype=torch.bool).expand( + num_samples, -1, -1 + ) + adjacency_matrix = torch.logical_and(adjacency_matrix, not_self_loop) + + edge_indexes = torch.nonzero(adjacency_matrix).T + + return get_graph_info( + cart_displacement, edge_indexes, cart_distance_matrix, num_atoms + ) diff --git a/test/tests/test_gnn.py b/test/tests/test_gnn.py index 8ac1e51..105b740 100644 --- a/test/tests/test_gnn.py +++ b/test/tests/test_gnn.py @@ -1,16 +1,14 @@ """Testing for GNN-based models.""" +import os + import pytest # import numpy as np import torch -from ramannoodle.polarizability.gnn import ( - PotGNN, - _radius_graph_pbc, - _get_rotations, - # polarizability_vectors_to_tensors, -) +from ramannoodle.polarizability.torch.gnn import PotGNN +from ramannoodle.polarizability.torch.utils import _radius_graph_pbc, get_rotations # import ramannoodle.io.vasp as vasp_io # from ramannoodle.structure.structure_utils import apply_pbc @@ -20,10 +18,10 @@ def test_get_rotations() -> None: - """Test _get_rotations (normal).""" + """Test get_rotations (normal).""" unit_vector = torch.randn((40, 3)) - rotation = _get_rotations(unit_vector) + rotation = get_rotations(unit_vector) rotated = torch.matmul(rotation, torch.tensor([1.0, 0.0, 0.0])) known = unit_vector / torch.linalg.norm(unit_vector, dim=1).view(-1, 1) @@ -118,7 +116,9 @@ def test_gpu(poscar_ref_structure_fixture: ReferenceStructure) -> None: device = "cpu" if torch.cuda.is_available(): device = "cuda" - elif torch.backends.mps.is_available(): + # mps backend doesn't work with github runners + # https://github.com/actions/runner-images/issues/9918 + elif torch.backends.mps.is_available() and os.getenv("GITHUB_ACTIONS") == "true": device = "mps" else: return