batch support for _radius_graph_pbc
wolearyc committed Aug 19, 2024
1 parent 7dda1d4 commit 3e7cf05
Showing 4 changed files with 120 additions and 49 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -38,14 +38,14 @@ repos:
args: [--strict]
additional_dependencies:
[
pytest,types-tabulate
pytest,types-tabulate,torch
]
exclude: "conf.py"
- repo: https://github.com/pycqa/flake8
rev: '7.1.1' # pick a git hash / tag to point to
hooks:
- id: flake8
args: [--max-line-length=88, --docstring-convention=numpy]
args: [--max-line-length=88, --extend-ignore=E203, --docstring-convention=numpy]
additional_dependencies:
[
flake8-bugbear, flake8-docstrings, flake8-absolute-import
2 changes: 2 additions & 0 deletions cspell.json
@@ -7,6 +7,7 @@
"dictionaries": [],
"words": [
"__repr__",
"allclose",
"anharmonic",
"ANSICOLORS",
"argmax",
@@ -57,6 +58,7 @@
"quickstart",
"Raman",
"ramannoodle",
"randn",
"repr",
"schnet",
"Softplus",
113 changes: 66 additions & 47 deletions ramannoodle/polarizability/gnn.py
@@ -4,22 +4,42 @@

from __future__ import annotations

import typing

import numpy as np
from numpy.typing import NDArray

import torch
from torch import Tensor
from torch.nn import BatchNorm1d, Embedding, Linear, ModuleList, Sequential
from torch.nn import BatchNorm1d, Embedding, Linear, ModuleList, Sequential, Module
from torch.utils.data import Dataset

from torch_geometric.nn.inits import reset
from torch_geometric.nn.models.dimenet import triplets
from torch_geometric.nn.models.schnet import ShiftedSoftplus
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter


class PolarizabilityDataset(Dataset):
def _get_scaled_polarizabilities(
polarizabilities: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
"""Compute scaled, flattened (6 member) polarizabilities."""
mean = polarizabilities.mean(0, keepdim=True)
stddev = polarizabilities.std(0, unbiased=False, keepdim=True)
polarizabilities = (polarizabilities - mean) / stddev

scaled_polarizabilities = torch.zeros((polarizabilities.size(0), 6))
scaled_polarizabilities[:, 0] = polarizabilities[:, 0, 0]
scaled_polarizabilities[:, 1] = polarizabilities[:, 1, 1]
scaled_polarizabilities[:, 2] = polarizabilities[:, 2, 2]
scaled_polarizabilities[:, 3] = polarizabilities[:, 0, 1]
scaled_polarizabilities[:, 4] = polarizabilities[:, 0, 2]
scaled_polarizabilities[:, 5] = polarizabilities[:, 1, 2]

return mean, stddev, scaled_polarizabilities
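
A quick sketch of how the new _get_scaled_polarizabilities helper behaves (illustrative data only, not part of this commit's diff):

import torch

from ramannoodle.polarizability.gnn import _get_scaled_polarizabilities

# Batch of random symmetric 3x3 polarizability tensors (illustrative only).
polarizabilities = torch.randn(8, 3, 3)
polarizabilities = 0.5 * (polarizabilities + polarizabilities.transpose(1, 2))

mean, stddev, scaled = _get_scaled_polarizabilities(polarizabilities)

# Each row holds standardized components in the order xx, yy, zz, xy, xz, yz.
assert mean.shape == (1, 3, 3) and stddev.shape == (1, 3, 3)
assert scaled.shape == (8, 6)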


class PolarizabilityDataset(Dataset[tuple[Tensor, Tensor, Tensor, Tensor]]):
"""Pytorch dataset containing atom positions mapped to polarizabilities.
Polarizabilities are scaled and reshaped into tensors with size [6,].
@@ -47,31 +67,16 @@ def __init__(
self._atomic_numbers = torch.tensor(atomic_numbers)
self._positions = torch.from_numpy(positions).float()
self._polarizabilities = torch.from_numpy(polarizabilities).float()

self._mean_polarizability = self._polarizabilities.mean(0, keepdim=True)
self._stddev_polarizability = self._polarizabilities.std(
0, unbiased=False, keepdim=True
)
scaled_polarizabilities = (
self._polarizabilities - self._mean_polarizability
) / self._stddev_polarizability

self._scaled_polarizabilities = torch.zeros(
(scaled_polarizabilities.size()[0], 6)
)

self._scaled_polarizabilities[:, 0] = scaled_polarizabilities[:, 0, 0]
self._scaled_polarizabilities[:, 1] = scaled_polarizabilities[:, 1, 1]
self._scaled_polarizabilities[:, 2] = scaled_polarizabilities[:, 2, 2]
self._scaled_polarizabilities[:, 3] = scaled_polarizabilities[:, 0, 1]
self._scaled_polarizabilities[:, 4] = scaled_polarizabilities[:, 0, 2]
self._scaled_polarizabilities[:, 5] = scaled_polarizabilities[:, 1, 2]
mean, stddev, scaled = _get_scaled_polarizabilities(self._polarizabilities)
self._mean_polarizability = mean
self._stddev_polarizability = stddev
self._scaled_polarizabilities = scaled

def __len__(self) -> int:
"""Get length."""
return len(self._positions)

def __getitem__(self, i):
def __getitem__(self, i: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Get positions, atomic numbers, lattices, and scaled polarizabilities."""
return (
self._lattices[i],
@@ -122,13 +127,15 @@ def __init__(self, hidden_node_channels: int, hidden_edge_channels: int):
# "filter" : c1ij W1 + b1
# "core" : c1ij W2 + b2
self.lin_c1 = Linear(
hidden_node_channels + hidden_edge_channels, 2 * hidden_node_channels
hidden_node_channels + hidden_edge_channels,
2 * hidden_node_channels,
dtype=Tensor,
)

self.bn_c1 = BatchNorm1d(2 * hidden_node_channels)
self.bn = BatchNorm1d(hidden_node_channels)

def reset_parameters(self):
def reset_parameters(self) -> None:
"""Reset model parameters."""
self.lin_c1.reset_parameters()
self.bn_c1.reset_parameters()
@@ -148,7 +155,7 @@ def forward(
)
c1_emb = self.bn(c1_emb)

return (node_embedding + c1_emb).tanh()
return typing.cast(Tensor, (node_embedding + c1_emb).tanh())


class EdgeBlock(torch.nn.Module):
@@ -184,7 +191,7 @@ def __init__(self, hidden_node_channels: int, hidden_edge_channels: int):
self.bn_c2_2 = BatchNorm1d(hidden_edge_channels)
self.bn_c3_2 = BatchNorm1d(hidden_edge_channels)

def reset_parameters(self):
def reset_parameters(self) -> None:
"""Reset model parameters."""
self.lin_c2.reset_parameters()
self.lin_c3.reset_parameters()
@@ -236,10 +243,10 @@ def forward( # pylint: disable=too-many-arguments, too-many-locals
)
c3_emb = self.bn_c3_2(c3_emb)

return (edge_embedding + c2_emb + c3_emb).tanh()
return typing.cast(Tensor, (edge_embedding + c2_emb + c3_emb).tanh())


class PotGNN(torch.nn.Module):
class PotGNN(Module):
r"""POlarizability Tensor Graph Neural Network (PotGNN).
GNN architecture was inspired by the "direct force architecture" developed in Park
@@ -300,7 +307,7 @@ def __init__(
Linear(hidden_edge_channels, 6),
)

def reset_parameters(self):
def reset_parameters(self) -> None:
"""Reset model parameters."""
reset(self.node_embedding)
self.edge_embedding.reset_parameters()
@@ -342,7 +349,7 @@ def forward( # pylint: disable=too-many-locals
# Polarizability prediction block:
polarizability = self.polarizability_predictor(edge_emb)
polarizability = self._get_polarizability_tensors(polarizability)
rotation = self._get_rotation_matrices(unit_vec)
rotation = self._get_rotations(unit_vec)

polarizability = torch.matmul(rotation, polarizability)
polarizability = torch.matmul(polarizability, torch.linalg.inv(rotation))
@@ -367,10 +374,10 @@ def _get_polarizability_vectors(self, x: Tensor) -> Tensor:
indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T
return x[:, indices[0], indices[1]]
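
The advanced indexing above pulls the six independent components of each symmetric 3x3 tensor into a flat vector; a standalone illustration with made-up values (not part of this commit's diff):

import torch

x = torch.arange(18.0).reshape(2, 3, 3)  # two toy 3x3 matrices
indices = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]]).T
vectors = x[:, indices[0], indices[1]]  # ordering: xx, yy, zz, xy, xz, yz

expected = torch.tensor([[0.0, 4.0, 8.0, 1.0, 2.0, 5.0], [9.0, 13.0, 17.0, 10.0, 11.0, 14.0]])
assert torch.equal(vectors, expected)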

def _get_rotation_matrices(self, destination: Tensor):
def _get_rotations(self, destination: Tensor) -> Tensor:
"""Get rotation matrices.
The source vector is set to (1,0,0). I think this should be in cartesian??
The source vector is set to (1,0,0).
Parameters
----------
@@ -406,28 +413,40 @@ def _get_rotation_matrices(self, destination: Tensor):
return rotation_matrix
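
For reference, the rotation that carries (1, 0, 0) onto an arbitrary unit vector can be written with Rodrigues' formula; an unbatched sketch (not the batched implementation in this diff, names are illustrative):

import torch

def rotation_from_x(destination: torch.Tensor) -> torch.Tensor:
    """Rotation matrix mapping (1, 0, 0) onto the unit vector destination."""
    source = torch.tensor([1.0, 0.0, 0.0])
    axis = torch.linalg.cross(source, destination)  # unnormalized rotation axis
    cosine = torch.dot(source, destination)
    ax, ay, az = axis.tolist()
    skew = torch.tensor([[0.0, -az, ay], [az, 0.0, -ax], [-ay, ax, 0.0]])
    # Undefined when destination == (-1, 0, 0), where 1 + cosine == 0.
    return torch.eye(3) + skew + (skew @ skew) / (1.0 + cosine)

destination = torch.tensor([0.0, 1.0, 0.0])
assert torch.allclose(rotation_from_x(destination) @ torch.tensor([1.0, 0.0, 0.0]), destination)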


def _radius_graph_pbc(lattice: Tensor, positions: Tensor, cutoff: float):
def _radius_graph_pbc(
lattice: Tensor, positions: Tensor, cutoff: float
) -> tuple[Tensor, Tensor, Tensor]:
"""Generate edge indexes and unit vectors."""
num_samples = lattice.size(0)
num_atoms = positions.size(1)

# Compute pairwise distance matrix.
positions_a = positions.unsqueeze(1)
positions_b = positions.unsqueeze(0)
displacement = positions_b - positions_a
displacement = positions.unsqueeze(1) - positions.unsqueeze(2)
displacement = torch.where(
displacement % 1 > 0.5, displacement % 1 - 1, displacement % 1
)
cart_displacement = displacement.matmul(lattice)
expanded_lattice = lattice[:, None, :, :].expand(-1, displacement.size(1), -1, -1)
cart_displacement = displacement.matmul(expanded_lattice)
cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1))

# Valid edges
valid_edge = cart_distance_matrix <= cutoff
not_self_loop = ~torch.eye(valid_edge.size()[0], dtype=torch.bool)
valid_edges = torch.logical_and(valid_edge, not_self_loop)
# Compute edge matrix
adjacency_matrix = cart_distance_matrix <= cutoff
not_self_loop = ~torch.eye(adjacency_matrix.size(-1), dtype=torch.bool).expand(
num_samples, -1, -1
)
adjacency_matrix = torch.logical_and(adjacency_matrix, not_self_loop)

# Edge indexes
edge_indexes = torch.nonzero(adjacency_matrix).T

# Cart_unit_vectors are flattened
cart_unit_vectors = cart_displacement[*edge_indexes]
cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None]

distances = cart_distance_matrix[*edge_indexes].view(-1, 1)

edge_indexes = edge_indexes[[1, 2]] + edge_indexes[0] * num_atoms

# Edge indexes and unit vectors
edge_indexes = torch.nonzero(valid_edges).T
cart_unit_vectors = cart_displacement[edge_indexes[0], edge_indexes[1]]
cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=1)[:, None]
distances = cart_distance_matrix[edge_indexes[0], edge_indexes[1]].view(-1, 1)
return edge_indexes, cart_unit_vectors, distances
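
A quick sketch of the batched call (illustrative sizes, not part of this commit's diff): positions are fractional coordinates, matching the modulo-based minimum-image wrap above, which folds each displacement component into (-0.5, 0.5] before converting to Cartesian.

import torch

from ramannoodle.polarizability.gnn import _radius_graph_pbc

# Two samples, 40 atoms each, cubic 10 Å cells, fractional coordinates in [0, 1).
lattice = (torch.eye(3) * 10.0).expand(2, 3, 3)
positions = torch.rand(2, 40, 3)

edge_indexes, unit_vectors, distances = _radius_graph_pbc(lattice, positions, cutoff=3.0)

# edge_indexes has shape [2, num_edges]; atom indexes of sample i are offset by
# i * 40, so all samples share one flattened node index space.
print(edge_indexes.shape, unit_vectors.shape, distances.shape)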


50 changes: 50 additions & 0 deletions test/tests/test_gnn.py
@@ -0,0 +1,50 @@
"""Testing for GNN-based models."""

import torch

from ramannoodle.polarizability.gnn import PotGNN, _radius_graph_pbc

# pylint: disable=protected-access, too-many-arguments, not-callable


def test_get_rotations() -> None:
"""Test _get_rotations (normal)."""
model = PotGNN(5, 5, 5, 5)
unit_vector = torch.randn((40, 3))

rotation = model._get_rotations(unit_vector)
rotated = rotation.matmul(torch.tensor([1.0, 0.0, 0.0]))
known = unit_vector / torch.linalg.norm(unit_vector, dim=1).view(-1, 1)

assert torch.allclose(rotated, known, atol=1e-5)


def test_radius_graph_pbc() -> None:
"""Test _radius_graph_pbc (normal)."""
for batch_size in range(1, 3):
# Generate random data.
num_atoms = 40
lattice = torch.eye(3) * 10
lattice = lattice.expand(batch_size, 3, 3)
positions = torch.randn(batch_size, num_atoms, 3)

# Batched graph
batch_edge_index, batch_unit_vector, batch_distance = _radius_graph_pbc(
lattice, positions, cutoff=3
)

# Individual graphs concatenated together
edge_index = []
unit_vector = []
distance = []
for i in range(batch_size):
ei, uv, d = _radius_graph_pbc(
lattice[i : i + 1], positions[i : i + 1], cutoff=3
)
edge_index.append(ei + num_atoms * i)
unit_vector.append(uv)
distance.append(d)

assert torch.allclose(batch_edge_index, torch.concat(edge_index, dim=1))
assert torch.allclose(batch_unit_vector, torch.concat(unit_vector, dim=0))
assert torch.allclose(batch_distance, torch.concat(distance, dim=0))
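
A tiny hand-checkable case in the same spirit (illustrative; not part of the new test file): two atoms 2 Å apart along x in a 10 Å cubic cell yield exactly one edge in each direction.

import torch

from ramannoodle.polarizability.gnn import _radius_graph_pbc

lattice = torch.eye(3).unsqueeze(0) * 10.0
positions = torch.tensor([[[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]])  # fractional coordinates

edge_index, unit_vector, distance = _radius_graph_pbc(lattice, positions, cutoff=3.0)

assert edge_index.shape == (2, 2)  # edges 0 -> 1 and 1 -> 0
assert torch.allclose(distance, torch.full((2, 1), 2.0))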
