More pooling options and a small test config for them
Old-Shatterhand committed Sep 9, 2024
1 parent 19b3102 commit 81a8346
Showing 2 changed files with 219 additions and 13 deletions.
103 changes: 103 additions & 0 deletions configs/downstream/pooling.yaml
@@ -0,0 +1,103 @@
seed: 42
data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_test
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_pool
datasets:
- name: Immunogenicity
task: classification
- name: Glycosylation
task: classification
- name: Taxonomy_Domain
task: multilabel
pre-transforms:
model:
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: global_mean
suffix: _global
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: local_mean
suffix: _local
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: weighted_mean
suffix: _weighted
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: global_attention
suffix: _attention
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: local_attention
suffix: _local_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: weighted_attention
suffix: _weighted_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: cell_attention
suffix: _cell_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: local_cell_attention
suffix: _local_cell_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: weighted_cell_attention
suffix: _weighted_cell_att
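
The nine model entries above differ only in their pooling mode and result suffix. Below is a minimal sketch (not part of this commit) of how such a config could be consumed, assuming a generic YAML loader; the commented train(...) call is a hypothetical placeholder for the repository's actual entry point.

import yaml

with open("configs/downstream/pooling.yaml") as f:
    config = yaml.safe_load(f)

for model_cfg in config["model"]:
    # Each entry shares the same architecture and only swaps the pooling mode.
    print(model_cfg["pooling"], model_cfg["suffix"])
    # train(model_cfg, config["datasets"], seed=config["seed"])  # hypothetical entry point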
129 changes: 116 additions & 13 deletions gifflar/model/utils.py
@@ -1,12 +1,73 @@
from typing import Literal, Any

import torch
from torch import nn
from torch import nn
from torch_geometric.nn.inits import reset
from torch_geometric.utils import softmax
from torch_geometric.nn import GINConv, global_mean_pool

from gifflar.utils import get_metrics


class GlobalAttention(torch.nn.Module):
r"""Global soft attention layer from the `"Gated Graph Sequence Neural
Networks" <https://arxiv.org/abs/1511.05493>`_ paper
.. math::
\mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left(
h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \odot
h_{\mathbf{\Theta}} ( \mathbf{x}_n ),
where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
\mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
MLPs.
Args:
gate_nn (torch.nn.Module): A neural network :math:`h_{\mathrm{gate}}`
that computes attention scores by mapping node features :obj:`x` of
shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 1]`, *e.g.*,
defined by :class:`torch.nn.Sequential`.
nn (torch.nn.Module, optional): A neural network
:math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of
shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`
before combining them with the attention scores, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
"""
def __init__(self, gate_nn, nn=None):
super().__init__()
self.gate_nn = gate_nn
self.nn = nn

self.reset_parameters()

def reset_parameters(self):
reset(self.gate_nn)
reset(self.nn)


def forward(self, x, batch, size=None):
""""""
x = x.unsqueeze(-1) if x.dim() == 1 else x
size = batch[-1].item() + 1 if size is None else size

gate = self.gate_nn(x).view(-1, 1)
x = self.nn(x) if self.nn is not None else x
assert gate.dim() == x.dim() and gate.size(0) == x.size(0)

gate = softmax(gate, batch, num_nodes=size)

# Sum the gated node features per graph; index_add_ performs the scatter-add over the batch index.
out = torch.zeros(size, x.size(-1), device=x.device, dtype=x.dtype)
out.index_add_(0, batch, gate * x)

return out


def __repr__(self) -> str:
return (f'{self.__class__.__name__}(gate_nn={self.gate_nn}, '
f'nn={self.nn})')


def get_gin_layer(input_dim: int, output_dim: int) -> GINConv:
"""
Get a GIN layer with the specified input and output dimensions
@@ -103,11 +164,44 @@ def __init__(self, mode: str = "global_mean"):
Initialize the GIFFLARPooling class
Args:
mode: The pooling mode to use, one of
- global_mean:
Standard mean pooling over all nodes in the graph
- local_mean:
Standard mean pooling over all nodes of each type, then mean over the node types
- weighted_mean:
Local mean pooling with learnable weights per node type in the second mean-computation
- global_attention:
Standard self-attention mechanism over all nodes in the graph
- local_attention:
Standard self-attention mechanism over all nodes of each type, then mean over the node types
- weighted_attention:
Standard self-attention mechanism over all nodes of each type, then mean with learnable weights
per node type in the cell aggregation.
- cell_attention:
Local mean pooling with self-attention over the cell results
- local_cell_attention:
Local mean pooling with self-attention over the cell results of each type, then mean over the cell types
- weighted_cell_attention:
Local mean pooling with self-attention over the cell results (their means) and learnable weights for aggregation
"""
super().__init__()
self.mode = mode
self.attention = None
self.weights = None
if "weighted" in mode:
self.weights = nn.Parameter(torch.ones(3), requires_grad=True)
if "attention" in mode:
if mode == "global_attention":
self.attention = {"": GlobalAttention(
gate_nn=get_gin_layer(1024, 1),
nn=get_gin_layer(1024, 1024),
)}
else:
self.attention = {key: GlobalAttention(
gate_nn=get_gin_layer(1024, 1),
nn=get_gin_layer(1024, 1024),
) for key in ["atoms", "bonds", "monosacchs"]}

def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
"""
@@ -117,14 +211,23 @@ def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
nodes: Dictionary mapping each node type to its node embeddings
batch_ids: Dictionary mapping each node type to the batch assignment of its nodes
"""
match self.mode:
case "global_mean":
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
)
case "local_mean":
return torch.sum(torch.stack([
global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()
]), dim=0)
case _:
raise ValueError(f"Pooling mode {self.mode} not supported")
nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()}
case "weighted_mean":
nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
case "global_attention":
nodes = {key: self.attention[""](nodes[key], batch_ids[key]) for key in nodes.keys()}
case "local_attention":
nodes = {key: self.attention[key](nodes[key], batch_ids[key]) for key in nodes.keys()}
case "weighted_attention":
nodes = {key: self.attention[key](nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
case "cell_attention":
nodes = {key: self.attention[""](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
case "local_cell_attention":
nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
case "weighted_cell_attention":
nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
)
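
A minimal usage sketch for the pooling utilities above (not part of this commit). It assumes the three node types referenced in forward() and substitutes plain linear layers for the GIN-based gate networks, since these toy tensors carry no edge_index.

import torch
from torch import nn
from gifflar.model.utils import GIFFLARPooling, GlobalAttention

hidden_dim = 16
# Two glycans in one batch; node counts per type are arbitrary toy numbers.
nodes = {
    "atoms": torch.randn(6, hidden_dim),
    "bonds": torch.randn(5, hidden_dim),
    "monosacchs": torch.randn(3, hidden_dim),
}
batch_ids = {
    "atoms": torch.tensor([0, 0, 0, 1, 1, 1]),
    "bonds": torch.tensor([0, 0, 1, 1, 1]),
    "monosacchs": torch.tensor([0, 1, 1]),
}

# Plain mean pooling over all nodes of all types.
pool = GIFFLARPooling(mode="global_mean")
graph_emb = pool(nodes, batch_ids)
print(graph_emb.shape)  # torch.Size([2, 16])

# GlobalAttention over a single node type, with a linear gate instead of a GIN layer.
att = GlobalAttention(gate_nn=nn.Linear(hidden_dim, 1), nn=nn.Linear(hidden_dim, hidden_dim))
atom_emb = att(nodes["atoms"], batch_ids["atoms"])
print(atom_emb.shape)  # torch.Size([2, 16])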
