diff --git a/ramannoodle/polarizability/gnn.py b/ramannoodle/polarizability/gnn.py index f26acbf..1dce8d9 100644 --- a/ramannoodle/polarizability/gnn.py +++ b/ramannoodle/polarizability/gnn.py @@ -129,7 +129,6 @@ def __init__(self, hidden_node_channels: int, hidden_edge_channels: int): self.lin_c1 = Linear( hidden_node_channels + hidden_edge_channels, 2 * hidden_node_channels, - dtype=Tensor, ) self.bn_c1 = BatchNorm1d(2 * hidden_node_channels) @@ -317,47 +316,6 @@ def reset_parameters(self) -> None: edge_block.reset_parameters() reset(self.polarizability_predictor) - def forward( # pylint: disable=too-many-locals - self, - lattice: Tensor, - atomic_numbers: Tensor, - positions: Tensor, - # batch: OptTensor = None, - ) -> Tensor: - """Forward pass.""" - edge_index, unit_vec, dist = _radius_graph_pbc(lattice, positions, self.cutoff) - - i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( - edge_index, num_nodes=atomic_numbers.size(0) - ) - - # Calculate distances and unit vector: - # dist = (positions[i] - positions[j]).pow(2).sum(dim=-1).sqrt() - # unit_vec = (positions[i] - positions[j]) / dist.view(-1, 1) - - # Embedding blocks: - node_emb = self.node_embedding(atomic_numbers) - edge_emb = self.edge_embedding(dist) - - # Message passing blocks: - for node_block, edge_block in zip(self.node_blocks, self.edge_blocks): - node_emb = node_block(node_emb, edge_emb, i) - edge_emb = edge_block( - node_emb, edge_emb, i, j, idx_i, idx_j, idx_k, idx_ji, idx_kj - ) - - # Polarizability prediction block: - polarizability = self.polarizability_predictor(edge_emb) - polarizability = self._get_polarizability_tensors(polarizability) - rotation = self._get_rotations(unit_vec) - - polarizability = torch.matmul(rotation, polarizability) - polarizability = torch.matmul(polarizability, torch.linalg.inv(rotation)) - - polarizability = self._get_polarizability_vectors(polarizability) - - return torch.mean(polarizability, dim=0) - def _get_polarizability_tensors(self, x: Tensor) -> Tensor: """X should have size (_,6).""" indices = torch.tensor( @@ -412,11 +370,89 @@ def _get_rotations(self, destination: Tensor) -> Tensor: rotation_matrix += a1 * b1[:, None, None] return rotation_matrix + def forward( # pylint: disable=too-many-locals + self, + lattice: Tensor, + atomic_numbers: Tensor, + positions: Tensor, + ) -> Tensor: + """Forward pass. + + Parameters + ---------- + lattice + Å | Tensor with size [S,3,3] where S is the number of samples. + atomic_numbers + Tensor with size [S,N] where N is the number of atoms. + positions + Unitless | Tensor with size [S,N,3]. + + """ + edge_index, unit_vec, dist = _radius_graph_pbc(lattice, positions, self.cutoff) + atomic_numbers = atomic_numbers.flatten() + + i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( + edge_index[[1, 2]], num_nodes=atomic_numbers.size(0) + ) + + # Embedding blocks: + node_emb = self.node_embedding(atomic_numbers) + edge_emb = self.edge_embedding(dist) + + # Message passing blocks: + for node_block, edge_block in zip(self.node_blocks, self.edge_blocks): + node_emb = node_block(node_emb, edge_emb, i) + edge_emb = edge_block( + node_emb, edge_emb, i, j, idx_i, idx_j, idx_k, idx_ji, idx_kj + ) + + # Polarizability prediction block: + edge_polarizability = self.polarizability_predictor(edge_emb) + edge_polarizability = self._get_polarizability_tensors(edge_polarizability) + rotation = self._get_rotations(unit_vec) + + edge_polarizability = torch.matmul(rotation, edge_polarizability) + edge_polarizability = torch.matmul( + edge_polarizability, torch.linalg.inv(rotation) + ) + + edge_polarizability = self._get_polarizability_vectors(edge_polarizability) + + polarizability = torch.zeros((positions.size(0), 6)) + for i in range(positions.size(0)): + mask = (edge_index[0:1] == i).T + count = torch.sum(mask) + polarizability[i] = torch.sum(edge_polarizability * mask, dim=0) + polarizability[i] /= count + return polarizability + def _radius_graph_pbc( lattice: Tensor, positions: Tensor, cutoff: float ) -> tuple[Tensor, Tensor, Tensor]: - """Generate edge indexes and unit vectors.""" + """Generate structure graphs with edge indexes, unit vectors, and distances. + + Parameters + ---------- + lattice + Å | Tensor with size [S,3,3] where S is the number of samples. + positions + Unitless | Tensor with size [S,N,3] where N is the number of atoms. + cutoff + Å | Edge cutoff distance. + + Returns + ------- + : + 3-tuple. + First element is edge indexes, a tensor of size [3,X] where X is the number of + edges. This tensor defines S non-interconnected graphs making up a batch. The + first row defines the graph index. The second and third rows define the actual + edge indexes used by ``triplet``. + Second element is cartesian unit vectors, a tensor of size [X,3]. + Third element is distances, a tensor of side [X,1]. + + """ num_samples = lattice.size(0) num_atoms = positions.size(1) @@ -429,35 +465,20 @@ def _radius_graph_pbc( cart_displacement = displacement.matmul(expanded_lattice) cart_distance_matrix = torch.sqrt(torch.sum(cart_displacement**2, dim=-1)) - # Compute edge matrix + # Compute adjacency matrix adjacency_matrix = cart_distance_matrix <= cutoff not_self_loop = ~torch.eye(adjacency_matrix.size(-1), dtype=torch.bool).expand( num_samples, -1, -1 ) adjacency_matrix = torch.logical_and(adjacency_matrix, not_self_loop) - # Edge indexes edge_indexes = torch.nonzero(adjacency_matrix).T - - # Cart_unit_vectors are flattened cart_unit_vectors = cart_displacement[*edge_indexes] cart_unit_vectors /= torch.linalg.norm(cart_unit_vectors, dim=-1)[:, None] - distances = cart_distance_matrix[*edge_indexes].view(-1, 1) - edge_indexes = edge_indexes[[1, 2]] + edge_indexes[0] * num_atoms + # Disconnect all S graphs + edge_indexes[1] += edge_indexes[0] * num_atoms + edge_indexes[2] += edge_indexes[0] * num_atoms return edge_indexes, cart_unit_vectors, distances - - -# def test_get_rotation_matrices(): -# destination = torch.randn((30,3)) -# mat = model._get_rotation_matrices(destination) -# vec1_rot = mat.matmul(torch.tensor([1.0, 0.0, 0.0])) -# -# rotated = vec1_rot -# known = destination / torch.linalg.norm(destination, dim=1).view(-1, 1) -# -# assert torch.allclose(rotated, known, atol = 1e-5) -# print("Success!") -# test_get_rotation_matrices() diff --git a/test/tests/test_gnn.py b/test/tests/test_gnn.py index 1e69877..370d9bd 100644 --- a/test/tests/test_gnn.py +++ b/test/tests/test_gnn.py @@ -21,7 +21,7 @@ def test_get_rotations() -> None: def test_radius_graph_pbc() -> None: """Test _radius_graph_pbc (normal).""" - for batch_size in range(1, 3): + for batch_size in range(1, 4): # Generate random data. num_atoms = 40 lattice = torch.eye(3) * 10 @@ -41,10 +41,43 @@ def test_radius_graph_pbc() -> None: ei, uv, d = _radius_graph_pbc( lattice[i : i + 1], positions[i : i + 1], cutoff=3 ) - edge_index.append(ei + num_atoms * i) + ei[0] = i + ei[[1, 2]] += num_atoms * i + edge_index.append(ei) unit_vector.append(uv) distance.append(d) assert torch.allclose(batch_edge_index, torch.concat(edge_index, dim=1)) assert torch.allclose(batch_unit_vector, torch.concat(unit_vector, dim=0)) assert torch.allclose(batch_distance, torch.concat(distance, dim=0)) + + +def test_batch_polarizability() -> None: + """Test of batch functions for forward pass (normal).""" + for batch_size in range(1, 4): + model = PotGNN(5, 5, 5, 5) + model.eval() + + # Generate random data. + num_atoms = 40 + lattice = torch.eye(3) * 10 + atomic_numbers = torch.randint(1, 10, (num_atoms,)) + + batch_lattices = lattice.expand(batch_size, 3, 3) + batch_atomic_numbers = atomic_numbers.expand(batch_size, num_atoms) + batch_positions = torch.randn(batch_size, num_atoms, 3) + batch_polarizability = model.forward( + batch_lattices, batch_atomic_numbers, batch_positions + ) + + # Individual calls + polarizabilities = torch.zeros((batch_size, 6)) + for i in range(batch_size): + polarizability = model.forward( + batch_lattices[i : i + 1], + batch_atomic_numbers[i : i + 1], + batch_positions[i : i + 1], + ) + polarizabilities[i] = polarizability[0] + + assert torch.allclose(batch_polarizability, polarizabilities)