implemented batching in gnn model
wolearyc committed Aug 19, 2024
1 parent 3e7cf05 commit 6579683
Showing 2 changed files with 118 additions and 64 deletions.
145 changes: 83 additions & 62 deletions ramannoodle/polarizability/gnn.py
@@ -129,7 +129,6 @@ def __init__(self, hidden_node_channels: int, hidden_edge_channels: int):
self.lin_c1 = Linear(
hidden_node_channels + hidden_edge_channels,
2 * hidden_node_channels,
dtype=Tensor,
)

self.bn_c1 = BatchNorm1d(2 * hidden_node_channels)
@@ -317,47 +316,6 @@ def reset_parameters(self) -> None:
edge_block.reset_parameters()
reset(self.polarizability_predictor)

def forward( # pylint: disable=too-many-locals
self,
lattice: Tensor,
atomic_numbers: Tensor,
positions: Tensor,
# batch: OptTensor = None,
) -> Tensor:
"""Forward pass."""
edge_index, unit_vec, dist = _radius_graph_pbc(lattice, positions, self.cutoff)

i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
edge_index, num_nodes=atomic_numbers.size(0)
)

# Calculate distances and unit vector:
# dist = (positions[i] - positions[j]).pow(2).sum(dim=-1).sqrt()
# unit_vec = (positions[i] - positions[j]) / dist.view(-1, 1)

# Embedding blocks:
node_emb = self.node_embedding(atomic_numbers)
edge_emb = self.edge_embedding(dist)

# Message passing blocks:
for node_block, edge_block in zip(self.node_blocks, self.edge_blocks):
node_emb = node_block(node_emb, edge_emb, i)
edge_emb = edge_block(
node_emb, edge_emb, i, j, idx_i, idx_j, idx_k, idx_ji, idx_kj
)

# Polarizability prediction block:
polarizability = self.polarizability_predictor(edge_emb)
polarizability = self._get_polarizability_tensors(polarizability)
rotation = self._get_rotations(unit_vec)

polarizability = torch.matmul(rotation, polarizability)
polarizability = torch.matmul(polarizability, torch.linalg.inv(rotation))

polarizability = self._get_polarizability_vectors(polarizability)

return torch.mean(polarizability, dim=0)

def _get_polarizability_tensors(self, x: Tensor) -> Tensor:
"""X should have size (_,6)."""
indices = torch.tensor(
Expand Down Expand Up @@ -412,11 +370,89 @@ def _get_rotations(self, destination: Tensor) -> Tensor:
rotation_matrix += a1 * b1[:, None, None]
return rotation_matrix

def forward( # pylint: disable=too-many-locals
self,
lattice: Tensor,
atomic_numbers: Tensor,
positions: Tensor,
) -> Tensor:
"""Forward pass.
Parameters
----------
lattice
Å | Tensor with size [S,3,3] where S is the number of samples.
atomic_numbers
Tensor with size [S,N] where N is the number of atoms.
positions
Unitless | Tensor with size [S,N,3].
"""
edge_index, unit_vec, dist = _radius_graph_pbc(lattice, positions, self.cutoff)
atomic_numbers = atomic_numbers.flatten()

i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
edge_index[[1, 2]], num_nodes=atomic_numbers.size(0)
)

# Embedding blocks:
node_emb = self.node_embedding(atomic_numbers)
edge_emb = self.edge_embedding(dist)

# Message passing blocks:
for node_block, edge_block in zip(self.node_blocks, self.edge_blocks):
node_emb = node_block(node_emb, edge_emb, i)
edge_emb = edge_block(
node_emb, edge_emb, i, j, idx_i, idx_j, idx_k, idx_ji, idx_kj
)

# Polarizability prediction block:
edge_polarizability = self.polarizability_predictor(edge_emb)
edge_polarizability = self._get_polarizability_tensors(edge_polarizability)
rotation = self._get_rotations(unit_vec)

edge_polarizability = torch.matmul(rotation, edge_polarizability)
edge_polarizability = torch.matmul(
edge_polarizability, torch.linalg.inv(rotation)
)

edge_polarizability = self._get_polarizability_vectors(edge_polarizability)

polarizability = torch.zeros((positions.size(0), 6))
for i in range(positions.size(0)):
mask = (edge_index[0:1] == i).T
count = torch.sum(mask)
polarizability[i] = torch.sum(edge_polarizability * mask, dim=0)
polarizability[i] /= count
return polarizability


def _radius_graph_pbc(
lattice: Tensor, positions: Tensor, cutoff: float
) -> tuple[Tensor, Tensor, Tensor]:
"""Generate edge indexes and unit vectors."""
"""Generate structure graphs with edge indexes, unit vectors, and distances.
Parameters
----------
lattice
Å | Tensor with size [S,3,3] where S is the number of samples.
positions
Unitless | Tensor with size [S,N,3] where N is the number of atoms.
cutoff
Å | Edge cutoff distance.
Returns
-------
:
3-tuple.
First element is edge indexes, a tensor of size [3,X] where X is the number of
edges. This tensor defines S non-interconnected graphs making up a batch. The
first row defines the graph index. The second and third rows define the actual
edge indexes used by ``triplet``.
Second element is cartesian unit vectors, a tensor of size [X,3].
Third element is distances, a tensor of side [X,1].
"""
num_samples = lattice.size(0)
num_atoms = positions.size(1)

@@ -429,35 +465,20 @@ def _radius_graph_pbc(
cart_displacement = displacement.matmul(expanded_lattice)
cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1))

# Compute edge matrix
# Compute adjacency matrix
adjacency_matrix = cart_distance_matrix <= cutoff
not_self_loop = ~torch.eye(adjacency_matrix.size(-1), dtype=torch.bool).expand(
num_samples, -1, -1
)
adjacency_matrix = torch.logical_and(adjacency_matrix, not_self_loop)

# Edge indexes
edge_indexes = torch.nonzero(adjacency_matrix).T

# Cart_unit_vectors are flattened
cart_unit_vectors = cart_displacement[*edge_indexes]
cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None]

distances = cart_distance_matrix[*edge_indexes].view(-1, 1)

edge_indexes = edge_indexes[[1, 2]] + edge_indexes[0] * num_atoms
# Disconnect all S graphs
edge_indexes[1] += edge_indexes[0] * num_atoms
edge_indexes[2] += edge_indexes[0] * num_atoms

return edge_indexes, cart_unit_vectors, distances


# def test_get_rotation_matrices():
# destination = torch.randn((30,3))
# mat = model._get_rotation_matrices(destination)
# vec1_rot = mat.matmul(torch.tensor([1.0, 0.0, 0.0]))
#
# rotated = vec1_rot
# known = destination / torch.linalg.norm(destination, dim=1).view(-1, 1)
#
# assert torch.allclose(rotated, known, atol = 1e-5)
# print("Success!")
# test_get_rotation_matrices()
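As an aside, the batching scheme in `_radius_graph_pbc` above works by keeping the sample (graph) index in the first row of the edge-index tensor and offsetting the per-sample node indexes by `sample_index * num_atoms`, so the S graphs never share node ids. A minimal standalone sketch with made-up edge indexes (not code from this commit):

```python
import torch

num_atoms = 3

# Hypothetical edge indexes: row 0 is the sample (graph) index, rows 1 and 2
# are source/destination node indexes *within* each sample.
edge_indexes = torch.tensor(
    [
        [0, 0, 1, 1],  # sample index
        [0, 1, 0, 2],  # source node within its sample
        [1, 0, 2, 0],  # destination node within its sample
    ]
)

# Offsetting rows 1 and 2 by (sample index * num_atoms) gives every node a
# globally unique id, so the graphs stay disconnected after flattening.
edge_indexes[1] += edge_indexes[0] * num_atoms
edge_indexes[2] += edge_indexes[0] * num_atoms

print(edge_indexes[[1, 2]])
# tensor([[0, 1, 3, 5],
#         [1, 0, 5, 3]])
```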
37 changes: 35 additions & 2 deletions test/tests/test_gnn.py
@@ -21,7 +21,7 @@ def test_get_rotations() -> None:

def test_radius_graph_pbc() -> None:
"""Test _radius_graph_pbc (normal)."""
for batch_size in range(1, 3):
for batch_size in range(1, 4):
# Generate random data.
num_atoms = 40
lattice = torch.eye(3) * 10
@@ -41,10 +41,43 @@ def test_radius_graph_pbc() -> None:
ei, uv, d = _radius_graph_pbc(
lattice[i : i + 1], positions[i : i + 1], cutoff=3
)
edge_index.append(ei + num_atoms * i)
ei[0] = i
ei[[1, 2]] += num_atoms * i
edge_index.append(ei)
unit_vector.append(uv)
distance.append(d)

assert torch.allclose(batch_edge_index, torch.concat(edge_index, dim=1))
assert torch.allclose(batch_unit_vector, torch.concat(unit_vector, dim=0))
assert torch.allclose(batch_distance, torch.concat(distance, dim=0))


def test_batch_polarizability() -> None:
"""Test of batch functions for forward pass (normal)."""
for batch_size in range(1, 4):
model = PotGNN(5, 5, 5, 5)
model.eval()

# Generate random data.
num_atoms = 40
lattice = torch.eye(3) * 10
atomic_numbers = torch.randint(1, 10, (num_atoms,))

batch_lattices = lattice.expand(batch_size, 3, 3)
batch_atomic_numbers = atomic_numbers.expand(batch_size, num_atoms)
batch_positions = torch.randn(batch_size, num_atoms, 3)
batch_polarizability = model.forward(
batch_lattices, batch_atomic_numbers, batch_positions
)

# Individual calls
polarizabilities = torch.zeros((batch_size, 6))
for i in range(batch_size):
polarizability = model.forward(
batch_lattices[i : i + 1],
batch_atomic_numbers[i : i + 1],
batch_positions[i : i + 1],
)
polarizabilities[i] = polarizability[0]

assert torch.allclose(batch_polarizability, polarizabilities)
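For context, a hypothetical usage sketch of the batched forward pass this test exercises; the `PotGNN(5, 5, 5, 5)` constructor arguments and the import path simply mirror the test above and are not meant as recommended hyperparameters. Note that `model.eval()` matters here: in training mode the `BatchNorm1d` layers normalize with per-batch statistics, so batched and per-sample outputs would not match.

```python
import torch

from ramannoodle.polarizability.gnn import PotGNN  # assumed import path

model = PotGNN(5, 5, 5, 5)
model.eval()  # use BatchNorm running statistics so output is deterministic

batch_size, num_atoms = 4, 40
lattice = (torch.eye(3) * 10).expand(batch_size, 3, 3)          # [S,3,3], Å
atomic_numbers = torch.randint(1, 10, (batch_size, num_atoms))  # [S,N]
positions = torch.randn(batch_size, num_atoms, 3)               # [S,N,3], unitless

polarizability = model.forward(lattice, atomic_numbers, positions)
print(polarizability.shape)  # torch.Size([4, 6]): one polarizability vector per sample
```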
