Commit: initial implementation
wolearyc committed Sep 12, 2024
1 parent 1285fc4 commit 2f73f73
Showing 3 changed files with 460 additions and 234 deletions.
200 changes: 145 additions & 55 deletions ramannoodle/polarizability/gnn.py
@@ -12,7 +12,15 @@

import torch
from torch import Tensor
from torch.nn import BatchNorm1d, Embedding, Linear, ModuleList, Sequential, Module
from torch.nn import (
BatchNorm1d,
Embedding,
Linear,
ModuleList,
Sequential,
Module,
LayerNorm,
)
from torch.utils.data import Dataset

from torch_geometric.nn.inits import reset
@@ -24,11 +32,10 @@
from ramannoodle.exceptions import get_type_error


def _size_string_tensor(size: Sequence[int | None]) -> str:
def _get_tensor_size_str(size: Sequence[int | None]) -> str:
"""Get a string representing a tensor size.
Maps None --> "_", indicating that this element can
be anything.
Maps None --> "_", indicating that this dimension can be any size.
"""
result = "["
for i in size:
@@ -43,12 +50,13 @@ def _size_string_tensor(size: Sequence[int | None]) -> str:

def _get_size_error_tensor(name: str, tensor: Tensor, desired_size: str) -> ValueError:
"""Return ValueError for an pytorch tensor with the wrong size."""
shape_spec = f"{_size_string_tensor(tensor.size())} != {desired_size}"
shape_spec = f"{_get_tensor_size_str(tensor.size())} != {desired_size}"
return ValueError(f"{name} has wrong size: {shape_spec}")
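As a reading aid, here is a hedged sketch of what the renamed helper is expected to produce; the opening bracket is visible above, but the comma separator is an assumption, since the loop body is collapsed in the diff:

def _get_tensor_size_str_sketch(size):
    # Hypothetical re-implementation: None becomes "_", other dims print as-is.
    return "[" + ",".join("_" if i is None else str(i) for i in size) + "]"

assert _get_tensor_size_str_sketch([10, None, 3]) == "[10,_,3]"
# _get_size_error_tensor would then embed this in a message such as:
#   positions has wrong size: [10,3] != [_,3,3]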


def _get_scaled_polarizabilities(
polarizabilities: Tensor,
scale: str,
) -> tuple[Tensor, Tensor, Tensor]:
"""Compute standard-scaled and flattened (6 member) polarizabilities.
@@ -67,8 +75,13 @@ def _get_scaled_polarizabilities(
"""
mean = polarizabilities.mean(0, keepdim=True)
stddev = torch.max(polarizabilities.std(0, unbiased=False, keepdim=True))
polarizabilities = (polarizabilities - mean) / stddev
stddev = polarizabilities.std(0, unbiased=False, keepdim=True)
if scale == "standard":
polarizabilities = (polarizabilities - mean) / stddev
elif scale == "stddev":
polarizabilities = (polarizabilities - mean) / stddev + mean
elif scale != "none":
raise ValueError("invalid scale option")

scaled_polarizabilities = torch.zeros((polarizabilities.size(0), 6))
scaled_polarizabilities[:, 0] = polarizabilities[:, 0, 0]
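For reference, a standalone sketch of the three scale options added above (toy data; mirrors the visible branches):

import torch

polarizabilities = torch.randn(8, 3, 3)  # toy batch of 3x3 tensors
mean = polarizabilities.mean(0, keepdim=True)
stddev = polarizabilities.std(0, unbiased=False, keepdim=True)

scaled = {
    "standard": (polarizabilities - mean) / stddev,       # zero mean, unit spread
    "stddev": (polarizabilities - mean) / stddev + mean,  # unit spread, mean kept
    "none": polarizabilities,                             # passthrough
}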
@@ -124,6 +137,16 @@ def _get_rotations(targets: Tensor) -> Tensor:
return rotations


def _get_atom_types(atomic_numbers: list[list[int]]) -> Tensor:
"""Convert atomic numbers into a one-hot encoding."""
ndarray = np.array(atomic_numbers)
unique_atomic_numbers = set(ndarray.flatten())

for i, atomic_number in enumerate(unique_atomic_numbers):
ndarray[ndarray == atomic_number] = i
return torch.tensor(ndarray)
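A hedged usage sketch of this module-private helper; the exact labels depend on Python's set iteration order, so sorting the unique values (e.g. with np.unique) would make the mapping deterministic:

from ramannoodle.polarizability.gnn import _get_atom_types

atom_types = _get_atom_types([[22, 8, 8], [22, 8, 8]])  # e.g. two TiO2-like cells
# -> tensor([[1, 0, 0], [1, 0, 0]]) if 8 happens to be enumerated first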


class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
"""PyTorch dataset of atomic structures and polarizabilities.
@@ -143,34 +166,36 @@ class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
lattices: NDArray[np.float64],
atomic_numbers: list[list[int]],
positions: NDArray[np.float64],
polarizabilities: NDArray[np.float64],
scale: str = "none",
):
default_type = torch.get_default_dtype()
self._lattices = torch.from_numpy(lattices).type(default_type)
self._atomic_numbers = torch.tensor(atomic_numbers)
self._positions = torch.from_numpy(positions).type(default_type)
self._polarizabilities = torch.from_numpy(polarizabilities).type(default_type)
mean, stddev, scaled = _get_scaled_polarizabilities(self._polarizabilities)
mean, stddev, scaled = _get_scaled_polarizabilities(
torch.from_numpy(polarizabilities), scale=scale
)
self._mean_polarizability = mean.type(default_type)
self._stddev_polarizability = stddev.type(default_type)
self._scaled_polarizabilities = scaled.type(default_type)
self._polarizabilities = scaled.type(default_type)

def __len__(self) -> int:
"""Get length."""
return len(self._positions)

def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Get lattice, atomic numbers, positions, and scaled polarizabilities."""
"""Get lattice, atomic numbers, positions, and polarizabilities."""
return (
self._lattices[i],
self._atomic_numbers[i],
self._positions[i],
self._scaled_polarizabilities[i],
self._polarizabilities[i],
)
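A hedged end-to-end sketch of constructing the dataset (array shapes are assumptions inferred from the diff, with S = 10 structures of N = 3 atoms each):

import numpy as np
from torch.utils.data import DataLoader
from ramannoodle.polarizability.gnn import PolarizabilityDataset

lattices = np.random.rand(10, 3, 3)          # S x 3 x 3 lattice vectors
atomic_numbers = [[22, 8, 8]] * 10           # S lists of atomic numbers
positions = np.random.rand(10, 3, 3)         # S x N x 3 fractional positions
polarizabilities = np.random.rand(10, 3, 3)  # S x 3 x 3 tensors

dataset = PolarizabilityDataset(
    lattices, atomic_numbers, positions, polarizabilities, scale="standard"
)
loader = DataLoader(dataset, batch_size=4)  # yields (lattice, numbers, positions, polarizability)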


@@ -216,7 +241,11 @@ class NodeBlock(torch.nn.Module):
s41524-021-00543-3. The architecture has been modified to use layer normalization.
"""

def __init__(self, size_node_embedding: int, size_edge_embedding: int):
def __init__(
self,
size_node_embedding: int,
size_edge_embedding: int,
):
super().__init__()

# Combination of two linear layers, dubbed "core" and "filter":
@@ -227,8 +256,8 @@ def __init__(self, size_node_embedding: int, size_edge_embedding: int):
2 * size_node_embedding,
)

self.bn_c1 = BatchNorm1d(2 * size_node_embedding)
self.bn = BatchNorm1d(size_node_embedding)
self.bn_c1 = LayerNorm(2 * size_node_embedding)
self.bn = LayerNorm(size_node_embedding)

def reset_parameters(self) -> None:
"""Reset model parameters."""
@@ -263,7 +292,11 @@ class EdgeBlock(torch.nn.Module):
s41524-021-00543-3. The architecture has been modified to use layer normalization.
"""

def __init__(self, size_node_embedding: int, size_edge_embedding: int):
def __init__(
self,
size_node_embedding: int,
size_edge_embedding: int,
):
super().__init__()

# Combination of two linear layers, dubbed "core" and "filter":
@@ -279,10 +312,10 @@ def __init__(self, size_node_embedding: int, size_edge_embedding: int):
2 * size_edge_embedding,
)

self.bn_c2 = BatchNorm1d(2 * size_edge_embedding)
self.bn_c3 = BatchNorm1d(2 * size_edge_embedding)
self.bn_c2_2 = BatchNorm1d(size_edge_embedding)
self.bn_c3_2 = BatchNorm1d(size_edge_embedding)
self.bn_c2 = LayerNorm(2 * size_edge_embedding)
self.bn_c3 = LayerNorm(2 * size_edge_embedding)
self.bn_c2_2 = LayerNorm(size_edge_embedding)
self.bn_c3_2 = LayerNorm(size_edge_embedding)

def reset_parameters(self) -> None:
"""Reset model parameters."""
@@ -416,19 +449,6 @@ def polarizability_vectors_to_tensors(polarizability_vectors: Tensor) -> Tensor:
) from exc


def _get_edge_polarizability_tensor(x: Tensor) -> tuple[Tensor, Tensor]:
"""X should have size (_,2)."""
diag_polarizability = torch.zeros((x.size(0), 3, 3))
diag_polarizability[:, 0, 0] = x[:, 0]
diag_polarizability[:, 1, 1] = x[:, 1]
diag_polarizability[:, 2, 2] = x[:, 1]
off_diag_polarizability = torch.zeros((x.size(0), 3, 3))
off_diag_polarizability[:, 0, 0] = x[:, 2]
off_diag_polarizability[:, 1, 1] = x[:, 3]
off_diag_polarizability[:, 2, 2] = x[:, 3]
return diag_polarizability, off_diag_polarizability


class PotGNN(Module): # pylint: disable = too-many-instance-attributes
r"""POlarizability Tensor Graph Neural Network (PotGNN).
@@ -468,8 +488,14 @@ def __init__(  # pylint: disable=too-many-arguments
)
self._cutoff = cutoff

unique_atomic_numbers = set(ref_structure.atomic_numbers)
self._num_atom_types = len(unique_atomic_numbers)
self._atom_types = (torch.zeros(119) - 1).type(torch.int)
for atom_type, atomic_number in enumerate(unique_atomic_numbers):
self._atom_types[atomic_number] = atom_type

self._node_embedding = Sequential(
Embedding(95, size_node_embedding),
Embedding(self._num_atom_types, size_node_embedding),
ShiftedSoftplus(), # nonlinear activation layer
Linear(size_node_embedding, size_node_embedding),
ShiftedSoftplus(), # nonlinear activation layer
@@ -492,19 +518,15 @@

self._polarizability_predictor = Sequential(
Linear(size_edge_embedding, size_edge_embedding),
BatchNorm1d(size_edge_embedding),
ShiftedSoftplus(),
Linear(size_edge_embedding, size_edge_embedding),
ShiftedSoftplus(),
Linear(size_edge_embedding, 4),
Linear(size_edge_embedding, 12),
)

# self._node_polarizability_predictor = Sequential(
# Linear(size_node_embedding, size_node_embedding),
# ShiftedSoftplus(),
# Linear(size_node_embedding, size_node_embedding),
# ShiftedSoftplus(),
# Linear(size_node_embedding, 3),
# )
def _convert_to_atom_type(self, atomic_numbers: Tensor) -> Tensor:
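"""Map atomic numbers to the model's contiguous atom-type indices."""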
return self._atom_types[atomic_numbers]

def reset_parameters(self) -> None:
"""Reset model parameters."""
@@ -517,7 +539,53 @@ def reset_parameters(self) -> None:
reset(self._polarizability_predictor)

def _get_batch_graph(
def _get_polarizability_tensors(self, x: Tensor) -> Tensor:
"""X should have size (_,6)."""
indices = torch.tensor(
[
[0, 3, 4],
[3, 1, 5],
[4, 5, 2],
]
)
return x[:, indices]

def _get_edge_polarizability_tensor(
self, x: Tensor
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""X should have size (_,2)."""
t1 = torch.zeros((x.size(0), 3, 3))
t1[:, 0, 0] = x[:, 0]
t1[:, 1, 1] = x[:, 1]
t1[:, 2, 2] = x[:, 1]
t2 = torch.zeros((x.size(0), 3, 3))
t2[:, 0, 0] = x[:, 2]
t2[:, 1, 1] = x[:, 3]
t2[:, 2, 2] = x[:, 3]
t3 = torch.zeros((x.size(0), 3, 3))
t3[:, 0, 0] = x[:, 4]
t3[:, 1, 1] = x[:, 5]
t3[:, 2, 2] = x[:, 5]
t4 = torch.zeros((x.size(0), 3, 3))
t4[:, 0, 0] = x[:, 6]
t4[:, 1, 1] = x[:, 7]
t4[:, 2, 2] = x[:, 7]
t5 = torch.zeros((x.size(0), 3, 3))
t5[:, 0, 0] = x[:, 8]
t5[:, 1, 1] = x[:, 9]
t5[:, 2, 2] = x[:, 9]
t6 = torch.zeros((x.size(0), 3, 3))
t6[:, 0, 0] = x[:, 10]
t6[:, 1, 1] = x[:, 11]
t6[:, 2, 2] = x[:, 11]
return t1, t2, t3, t4, t5, t6

def _get_polarizability_vectors(self, x: Tensor) -> Tensor:
"""X should have size (_,3,3)."""
indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T
return x[:, indices[0], indices[1]]
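These two index-based helpers are inverses on symmetric tensors. A standalone sketch of the same packing, using the component order [xx, yy, zz, xy, xz, yz] implied by the index tensors:

import torch

vec = torch.arange(6.0).view(1, 6)  # [xx, yy, zz, xy, xz, yz]
pack = torch.tensor([[0, 3, 4], [3, 1, 5], [4, 5, 2]])
tensors = vec[:, pack]  # (1, 3, 3), symmetric by construction

unpack = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T
assert torch.equal(tensors[:, unpack[0], unpack[1]], vec)  # round trip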

def _batch_graph(
self, lattice: Tensor, positions: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
"""Generate edge indexes, unit vectors, and changes in distance.
@@ -571,8 +639,8 @@ def _get_batch_graph(
distances = cart_distance_matrix[
edge_indexes[0], edge_indexes[1], edge_indexes[2]
].view(-1, 1)
ref_distances = self._ref_distances.repeat((num_samples, 1))
distances -= ref_distances
# ref_distances = self._ref_distances.repeat((num_samples, 1))
# distances -= ref_distances

# Disconnect all S graphs
edge_indexes[1] += edge_indexes[0] * num_atoms
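The offset on the last line above is what keeps the S per-sample graphs disconnected once they share a single node index space; a toy illustration (the diff shows the shift for row 1, and presumably row 2 is shifted in the collapsed lines):

import torch

num_atoms = 3
edge_indexes = torch.tensor(
    [[0, 0, 1, 1],   # sample each edge belongs to
     [0, 1, 0, 2],   # source node within the sample
     [1, 0, 2, 0]]   # destination node within the sample
)
edge_indexes[1] += edge_indexes[0] * num_atoms
edge_indexes[2] += edge_indexes[0] * num_atoms
# Sample-0 edges keep nodes 0..2; sample-1 edges now use nodes 3..5.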
@@ -605,15 +673,15 @@ def forward(  # pylint: disable=too-many-locals
:func:`polarizability_vectors_to_tensors`.
"""
edge_index, unit_vec, dist = self._get_batch_graph(lattice, positions)
atomic_numbers = atomic_numbers.flatten()
edge_index, unit_vec, dist = self._batch_graph(lattice, positions)
atom_types = self._convert_to_atom_type(atomic_numbers).flatten()

i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
edge_index[[1, 2]], num_nodes=atomic_numbers.size(0)
edge_index[[1, 2]], num_nodes=atom_types.size(0)
)

# Embedding blocks:
node_emb = self._node_embedding(atomic_numbers)
node_emb = self._node_embedding(atom_types)
edge_emb = self._edge_embedding(dist)

# Message passing blocks:
Expand All @@ -625,13 +693,35 @@ def forward( # pylint: disable=too-many-locals

# Polarizability prediction block:
edge_polarizability = self._polarizability_predictor(edge_emb)
diag, off_diag = _get_edge_polarizability_tensor(edge_polarizability)

t1, t2, t3, t4, t5, t6 = self._get_edge_polarizability_tensor(
edge_polarizability
)
rotation = _get_rotations(unit_vec)

diag = rotation @ diag @ torch.linalg.inv(rotation)
off_diag = rotation @ off_diag @ torch.linalg.inv(rotation)
edge_polarizability = diag * torch.eye(3) + off_diag * -(torch.eye(3) - 1)
edge_polarizability = _polarizability_tensors_to_vectors(edge_polarizability)
t1 = rotation @ t1 @ torch.linalg.inv(rotation)
t2 = rotation @ t2 @ torch.linalg.inv(rotation)
t3 = rotation @ t3 @ torch.linalg.inv(rotation)
t4 = rotation @ t4 @ torch.linalg.inv(rotation)
t5 = rotation @ t5 @ torch.linalg.inv(rotation)
t6 = rotation @ t6 @ torch.linalg.inv(rotation)

mask_1 = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 0]])
mask_2 = torch.tensor([[0, 0, 1], [0, 0, 0], [1, 0, 0]])
mask_3 = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]])
mask_4 = torch.tensor([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
mask_5 = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
mask_6 = torch.tensor([[0, 0, 0], [0, 0, 0], [0, 0, 1]])

edge_polarizability = (
t1 * mask_1
+ t2 * mask_2
+ t3 * mask_3
+ t4 * mask_4
+ t5 * mask_5
+ t6 * mask_6
)
edge_polarizability = self._get_polarizability_vectors(edge_polarizability)
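To summarize the block above: each predicted pair (a, b, b) forms a tensor that is axially symmetric along x, is rotated onto its edge direction, and then contributes exactly one symmetric component pair through a binary mask; the six masks partition the 3x3 entries. A condensed sketch of the masking step, with toy tensors standing in for the rotated t1..t6:

import torch

t = [torch.full((2, 3, 3), float(k + 1)) for k in range(6)]  # toy rotated tensors
masks = [
    torch.tensor(m, dtype=torch.get_default_dtype())
    for m in (
        [[0, 1, 0], [1, 0, 0], [0, 0, 0]],  # xy / yx
        [[0, 0, 1], [0, 0, 0], [1, 0, 0]],  # xz / zx
        [[0, 0, 0], [0, 0, 1], [0, 1, 0]],  # yz / zy
        [[1, 0, 0], [0, 0, 0], [0, 0, 0]],  # xx
        [[0, 0, 0], [0, 1, 0], [0, 0, 0]],  # yy
        [[0, 0, 0], [0, 0, 0], [0, 0, 1]],  # zz
    )
]
edge_polarizability = sum(tk * mk for tk, mk in zip(t, masks))

Since the rotations are built from unit vectors and should be orthogonal, rotation.transpose(-1, -2) could replace torch.linalg.inv(rotation) here, which would be cheaper and more numerically stable.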

# Isolate polarizabilities from batch graphs.
polarizability = torch.zeros((positions.size(0), 6))
