Bug fixing in pooling functions
Roman Joeres authored and committed Sep 10, 2024
1 parent 81a8346 commit 3d5e03f
Showing 4 changed files with 136 additions and 76 deletions.
36 changes: 18 additions & 18 deletions configs/downstream/all.yaml
@@ -1,28 +1,28 @@
seed: 42
data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_test
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_test
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data_feat
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/clean_logs
datasets:
- name: Immunogenicity
task: classification
- name: Glycosylation
task: classification
- name: Taxonomy_Domain
task: multilabel
# - name: Taxonomy_Kingdom
# task: multilabel
# - name: Taxonomy_Phylum
# task: multilabel
# - name: Taxonomy_Class
# task: multilabel
# - name: Taxonomy_Order
# task: multilabel
# - name: Taxonomy_Family
# task: multilabel
# - name: Taxonomy_Genus
# task: multilabel
# - name: Taxonomy_Species
# task: multilabel
- name: Taxonomy_Kingdom
task: multilabel
- name: Taxonomy_Phylum
task: multilabel
- name: Taxonomy_Class
task: multilabel
- name: Taxonomy_Order
task: multilabel
- name: Taxonomy_Family
task: multilabel
- name: Taxonomy_Genus
task: multilabel
- name: Taxonomy_Species
task: multilabel
pre-transforms:
model:
# - name: rf
@@ -79,5 +79,5 @@ model:
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: local_mean
suffix: _128_8_local
pooling: weighted_attention
suffix: _128_8_la_pool
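
For context: the pooling value configured here is forwarded to the downstream model and selects the GIFFLARPooling mode (see gifflar/model/downstream.py below), while suffix only labels the run. A minimal Python sketch of reading that key, assuming PyYAML and the list layout shown above; the loop itself is illustrative and not part of this commit:

import yaml

with open("configs/downstream/all.yaml") as f:
    cfg = yaml.safe_load(f)

for model_cfg in cfg["model"]:
    # e.g. name: gifflar, pooling: weighted_attention, suffix: _128_8_la_pool
    pooling_mode = model_cfg.get("pooling", "global_mean")  # same default as downstream.py
    print(model_cfg.get("name"), "->", pooling_mode)
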
32 changes: 1 addition & 31 deletions configs/downstream/pooling.yaml
@@ -68,36 +68,6 @@ model:
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
optimizer: Adam
pooling: weighted_attention
suffix: _weighted_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: cell_attention
suffix: _cell_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: local_cell_attention
suffix: _local_cell_att
- name: gifflar
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
optimizer: Adam
pooling: weighted_cell_attention
suffix: _weighted_cell_att
4 changes: 3 additions & 1 deletion gifflar/model/downstream.py
@@ -37,7 +37,8 @@ def __init__(
super().__init__(feat_dim, hidden_dim, num_layers, batch_size, pre_transform_args)
self.output_dim = output_dim

self.pooling = GIFFLARPooling()
self.pooling = GIFFLARPooling(kwargs.get("pooling", "global_mean"))
# self.add_module('pooling', GIFFLARPooling(kwargs.get("pooling", "global_mean")))
self.task = task

self.head, self.loss, self.metrics = get_prediction_head(hidden_dim, output_dim, task)
@@ -53,6 +54,7 @@ def to(self, device: torch.device) -> "DownstreamGGIN":
self: The model moved to the specified device
"""
super(DownstreamGGIN, self).to(device)
self.pooling.to(device)
for split, metric in self.metrics.items():
self.metrics[split] = metric.to(device)
return self
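
The explicit self.pooling.to(device) call, together with the MultiGlobalAttention wrapper added in gifflar/model/utils.py below, works around a common PyTorch pitfall: modules stored in a plain Python dict are not registered as submodules, so Module.to() and the optimizer never see their parameters. A minimal sketch of that behaviour (standard PyTorch semantics, not code from this repository); nn.ModuleDict would be the built-in way to get automatic registration:

import torch
from torch import nn

class PlainDictPool(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: the linear layer is NOT registered as a submodule.
        self.attention = {"atoms": nn.Linear(8, 1)}

class ModuleDictPool(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers the layer, so .to() and optimizers pick it up.
        self.attention = nn.ModuleDict({"atoms": nn.Linear(8, 1)})

print(len(list(PlainDictPool().parameters())))   # 0
print(len(list(ModuleDictPool().parameters())))  # 2 (weight and bias)
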
140 changes: 114 additions & 26 deletions gifflar/model/utils.py
@@ -5,11 +5,24 @@
from torch_geometric.nn.inits import reset
from torch_geometric.utils import softmax
from torch_geometric.nn import GINConv, global_mean_pool
from torch_scatter import scatter_add

from gifflar.utils import get_metrics


class GlobalAttention(torch.nn.Module):
class MultiGlobalAttention(nn.Module):
def __init__(self, gas: dict):
super().__init__()
for name, ga in gas.items():
setattr(self, name, ga)

def __getitem__(self, item):
if hasattr(self, item):
return getattr(self, item)
return super().__getitem__(item)


class GlobalAttention(nn.Module):
r"""Global soft attention layer from the `"Gated Graph Sequence Neural
Networks" <https://arxiv.org/abs/1511.05493>`_ paper
@@ -44,21 +57,20 @@ def reset_parameters(self):
reset(self.gate_nn)
reset(self.nn)


def forward(self, x, batch, size=None):
""""""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.int64)

x = x.unsqueeze(-1) if x.dim() == 1 else x
size = batch[-1].item() + 1 if size is None else size
size = int(batch.max()) + 1 if size is None else size

gate = self.gate_nn(x).view(-1, 1)
x = self.nn(x) if self.nn is not None else x
assert gate.dim() == x.dim() and gate.size(0) == x.size(0)

gate = softmax(gate, batch, num_nodes=size)

# out = scatter_add(gate * x, index=batch, dim=0) # , dim_size=size)
src = torch.zeros_like(size)
out = src.scatter_add(src=gate * x, index=batch, dim=0)
out = scatter_add(gate * x, batch, dim=0, dim_size=size)

return out
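
The fix above replaces the broken scatter into torch.zeros_like(size) (size is a plain int here, so zeros_like fails) with torch_scatter.scatter_add, which sums the gate-weighted node features per graph. If torch_scatter is unavailable, the same reduction can be written with plain torch.index_add_; a minimal sketch with toy shapes, not part of this diff:

import torch

# Toy example: 5 nodes split across 2 graphs, feature dim 4.
x = torch.randn(5, 4)                    # h(x_n), to be gate-weighted
gate = torch.rand(5, 1)                  # stands in for softmax(gate_nn(x), batch)
batch = torch.tensor([0, 0, 0, 1, 1])
size = int(batch.max()) + 1

# Equivalent to scatter_add(gate * x, batch, dim=0, dim_size=size):
out = torch.zeros(size, x.size(1)).index_add_(0, batch, gate * x)
print(out.shape)                         # torch.Size([2, 4])
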

@@ -190,18 +202,45 @@ def __init__(self, mode: str = "global_mean"):
self.attention = None
self.weights = None
if "weighted" in mode:
print("Create weighting")
self.weights = nn.Parameter(torch.ones(3), requires_grad=True)
if "attention" in mode:
print("Create attention")
if mode == "global_attention":
self.attention = {"": GlobalAttention(
gate_nn=get_gin_layer(1024, 1),
nn=get_gin_layer(1024, 1024),
)}
self.attention = MultiGlobalAttention({"": GlobalAttention(
gate_nn=nn.Linear(1024, 1),
nn=nn.Sequential(
nn.Linear(1024, 1024),
nn.PReLU(),
nn.Dropout(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, 1024),
),
)})
else:
self.attention = MultiGlobalAttention({key: GlobalAttention(
gate_nn=nn.Linear(1024, 1),
nn=nn.Sequential(
nn.Linear(1024, 1024),
nn.PReLU(),
nn.Dropout(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, 1024),
),
) for key in ["atoms", "bonds", "monosacchs"]})

def to(self, device):
if self.weights is not None:
self.weights = self.weights.to(device)
if self.attention is not None:
if self.mode == "global_attention":
self.attention[""].to(device)
else:
self.attention = {key: GlobalAttention(
gate_nn=get_gin_layer(1024, 1),
nn=get_gin_layer(1024, 1024),
) for key in ["atoms", "bonds", "monosacchs"]}
self.attention["atoms"].to(device)
self.attention["bonds"].to(device)
self.attention["monosacchs"].to(device)
super(GIFFLARPooling, self).to(device)
return self

def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
"""
@@ -210,24 +249,73 @@ def forward(self, nodes: dict[str, torch.Tensor], batch_ids: dict[str, torch.Tensor]) -> torch.Tensor:
Args:
x: The input to the pooling layer
"""
device = nodes["atoms"].device
match self.mode:
case "global_mean":
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
)
case "local_mean":
nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()}
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0)
)
case "weighted_mean":
nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0)
)
case "global_attention":
nodes = {key: self.attention[""](nodes[key], batch_ids[key]) for key in nodes.keys()}
return self.attention[""](
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
)
case "local_attention":
nodes = {key: self.attention[key](nodes[key], batch_ids[key]) for key in nodes.keys()}
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0)
)
case "weighted_attention":
nodes = {key: self.attention[key](nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
case "cell_attention":
nodes = {key: self.attention[""](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
case "local_cell_attention":
nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
case "weighted_cell_attention":
nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([batch_ids["atoms"], batch_ids["bonds"], batch_ids["monosacchs"]], dim=0)
)
device = nodes["atoms"].device
return global_mean_pool(
torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
torch.concat([torch.arange(len(nodes[key])).to(device) for key in ["atoms", "bonds", "monosacchs"]], dim=0)
)
#case "cell_attention":
# nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()}
# return self.attention[""](
# torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
# torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
# )
#case "local_cell_attention":
# nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) for key in nodes.keys()}
# return global_mean_pool(
# torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
# torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
# )
#case "weighted_cell_attention":
# nodes = {key: global_mean_pool(nodes[key], batch_ids[key]) * self.weights[i] for i, key in enumerate(["atoms", "bonds", "monosacchs"])}
# return global_mean_pool(
# torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
# torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
# )
#case "double_attention":
# nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
# return global_mean_pool(
# torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
# torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
# )
#case "weighted_double_attention":
# nodes = {key: self.attention[key](global_mean_pool(nodes[key], batch_ids[key])) for key in nodes.keys()}
# return global_mean_pool(
# torch.concat([nodes["atoms"], nodes["bonds"], nodes["monosacchs"]], dim=0),
# torch.concat([torch.arange(len(nodes["atoms"])), torch.arange(len(nodes["bonds"])), torch.arange(len(nodes["monosacchs"]))], dim=0)
# )
case _:
raise ValueError(f"Pooling method {self.mode} not supported.")
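
Taken together, GIFFLARPooling maps per-node-type feature and batch-index dictionaries to one embedding per graph. A minimal usage sketch with toy tensors, assuming the class is importable from gifflar/model/utils.py; note that the attention-based modes above hard-code a hidden size of 1024, whereas global_mean works for any feature dimension:

import torch
from gifflar.model.utils import GIFFLARPooling  # assumed import path

pool = GIFFLARPooling("global_mean")
hidden = 128
nodes = {key: torch.randn(n, hidden)
         for key, n in [("atoms", 12), ("bonds", 11), ("monosacchs", 3)]}
# All toy nodes belong to a single graph (batch id 0 everywhere).
batch_ids = {key: torch.zeros(feats.size(0), dtype=torch.long)
             for key, feats in nodes.items()}
out = pool(nodes, batch_ids)
print(out.shape)  # torch.Size([1, 128])
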
