Debugging for final downstream comparison
Roman Joeres authored and Roman Joeres committed Aug 19, 2024
1 parent 709aca9 commit e211984
Showing 9 changed files with 78 additions and 115 deletions.
65 changes: 34 additions & 31 deletions configs/downstream/all.yaml
@@ -3,12 +3,14 @@ data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
logs_dir: logs
datasets:
#- name: Immunogenicity
# task: classification
#- name: Glycosylation
# task: classification
- name: Immunogenicity
task: classification
- name: Glycosylation
task: classification
- name: Taxonomy_Domain
task: multilabel
#- name: Taxonomy_Kingdom
# task: multilabel
#- name: Taxonomy_Phylum
# task: multilabel
#- name: Taxonomy_Class
@@ -24,7 +26,7 @@ datasets:
pre-transforms:
model:
#- name: rf
# n_estimators: 500
# n_estimators: 500
# n_jobs: -1
# random_state: 42
#- name: svm
@@ -35,32 +37,33 @@ model:
hidden_dim: 1024
batch_size: 256
num_layers: 3
epochs: 1
epochs: 100
patience: 30
learning_rate: 0
optimizer: Adam
- name: sweetnet
hidden_dim: 1024
batch_size: 512
epochs: 1
patience: 30
learning_rate: 0.001
optimizer: Adam
suffix:
- name: gnngly
hidden_dim: 1024
batch_size: 512
num_layers: 5
epochs: 1
patience: 30
learning_rate: 0.001
optimizer: Adam
suffix:
- name: gifflar
hidden_dim: 1024
batch_size: 512
num_layers: 8
epochs: 1
learning_rate: 0.001
optimizer: Adam
suffix:
#- name: sweetnet
# hidden_dim: 1024
# batch_size: 512
# num_layers: 16
# epochs: 1
# patience: 30
# learning_rate: 0.001
# optimizer: Adam
# suffix:
#- name: gnngly
# hidden_dim: 1024
# batch_size: 512
# num_layers: 8
# epochs: 1
# patience: 30
# learning_rate: 0.001
# optimizer: Adam
# suffix:
#- name: gifflar
# hidden_dim: 1024
# batch_size: 256
# num_layers: 8
# epochs: 100
# learning_rate: 0.001
# optimizer: Adam
# suffix:
2 changes: 1 addition & 1 deletion configs/downstream/both.yaml
@@ -31,7 +31,7 @@ pre-transforms:
model:
- name: gifflar
hidden_dim: 1024
batch_size: 512
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
17 changes: 0 additions & 17 deletions configs/downstream/gnngly.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/downstream/lppe.yaml
@@ -29,7 +29,7 @@ pre-transforms:
model:
- name: gifflar
hidden_dim: 1024
batch_size: 512
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
2 changes: 1 addition & 1 deletion configs/downstream/rwpe.yaml
@@ -29,7 +29,7 @@ pre-transforms:
model:
- name: gifflar
hidden_dim: 1024
batch_size: 512
batch_size: 256
num_layers: 8
epochs: 100
learning_rate: 0.001
38 changes: 7 additions & 31 deletions gifflar/baselines/gnngly.py
@@ -1,8 +1,9 @@
from typing import Dict, Literal
from collections import OrderedDict

import torch
from torch import nn
from torch_geometric.nn import GCNConv, global_mean_pool, Sequential
from torch_geometric.nn import GCNConv, global_mean_pool

from gifflar.model import DownstreamGGIN

@@ -30,39 +31,15 @@ def __init__(self, hidden_dim: int, num_layers: int, output_dim: int,
"""
Initialize the model following the paper's description.
"""
super().__init__(14, output_dim, task)
super().__init__(hidden_dim, output_dim, task)

del self.convs
del self.head

self.layers = Sequential('x, edge_index', [GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)])

# Five layers of plain graph convolution with a hidden dimension of 14.
# self.layers = [
# GCNConv(133, 14),
# GCNConv(14, 14),
# GCNConv(14, 14),
# GCNConv(14, 14),
# GCNConv(14, 14),
# ]
self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GCNConv((133 if l == 0 else hidden_dim), hidden_dim)) for l in range(num_layers)]))

# ASSUMPTION: mean pooling
self.pooling = global_mean_pool

# ASSUMPTION: a prediction head that seems quite elaborate given the other parts of the paper
# self.head = nn.Sequential(
# nn.Dropout(0.1),
# nn.Linear(14, 64),
# nn.PReLU(),
# nn.BatchNorm1d(64),
# nn.Linear(64, output_dim),
# )

def to(self, device):
super(GNNGLY, self).to(device)
# self.layers = [l.to(device) for l in self.layers]


def forward(self, batch):
"""
Forward the data through the model.
@@ -79,12 +56,11 @@ def forward(self, batch):
edge_index = batch["gnngly_edge_index"]

# Propagate the data through the model
# for layer in self.layers:
# x = layer(x, edge_index)
for layer in self.layers:
x = layer(x, edge_index)

# Compute the graph embeddings and make the final prediction based on this
node_embed = self.layers(x, edge_index)
graph_embed = self.pooling(node_embed, batch_ids)
graph_embed = self.pooling(x, batch_ids)
pred = self.head(graph_embed)

return {
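
Note on the GNNGLY change above: nn.Sequential cannot forward the pair (x, edge_index) through GCNConv layers in a single call, so the new code uses it only as an ordered container of named layers and iterates it manually. A minimal, self-contained sketch of that pattern (the layer count and feature sizes below are illustrative, not taken from the repository):

from collections import OrderedDict

import torch
from torch import nn
from torch_geometric.nn import GCNConv

# Register the convolutions as named submodules; the Sequential is never called directly.
layers = nn.Sequential(OrderedDict(
    [(f"layer{i + 1}", GCNConv(133 if i == 0 else 14, 14)) for i in range(5)]
))

x = torch.rand(7, 133)                                    # 7 nodes, 133 input features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])   # toy connectivity

# Forward manually because each GCNConv needs both x and edge_index.
for layer in layers:
    x = layer(x, edge_index)
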
47 changes: 19 additions & 28 deletions gifflar/baselines/sweetnet.py
@@ -1,10 +1,12 @@
from typing import Literal, Dict
from collections import OrderedDict

import torch
from glycowork.ml.models import prep_model
from torch import nn
from torch_geometric.nn import global_mean_pool, GraphConv, Sequential
import torch.nn.functional as F
from glycowork.glycan_data.loader import lib

from gifflar.data import HeteroDataBatch
from gifflar.model import DownstreamGGIN
@@ -27,9 +29,9 @@ def __init__(self, hidden_dim: int, output_dim: int, num_layers: int,
del self.head

# Load the untrained model from glycowork
# self.model = prep_model("SweetNet", output_dim, hidden_dim=hidden_dim)
self.layers = Sequential('x, edge_index', [GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)])

self.item_embedding = nn.Embedding(len(lib), hidden_dim)
self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GraphConv(hidden_dim, hidden_dim)) for l in range(num_layers)]))
self.head = nn.Sequential(
nn.Linear(hidden_dim, 1024),
nn.BatchNorm1d(1024),
@@ -38,9 +40,9 @@ def __init__(self, hidden_dim: int, output_dim: int, num_layers: int,
nn.BatchNorm1d(128),
nn.LeakyReLU(),
nn.Dropout(0.5),
nn.Linear(64, output_dim),
nn.Linear(128, output_dim),
)

def forward(self, batch: HeteroDataBatch) -> Dict[str, torch.Tensor]:
"""
Forward the data through the model.
@@ -52,33 +54,22 @@ def forward(self, batch):
Dict holding the node embeddings, the graph embedding, and the final model prediction
"""
# Extract monosaccharide graph from the heterogeneous graph
x = batch["x_dict"]["monosacchs"]
batch_ids = batch["batch_dict"]["monosacchs"]
edge_index = batch["edge_index_dict"]["monosacchs", "boundary", "monosacchs"]
x = batch["sweetnet_x"]
batch_ids = batch["sweetnet_batch"]
edge_index = batch["sweetnet_edge_index"]

# Getting node features
# x = self.model.item_embedding(x)
# x = x.squeeze(1)
#
# Graph convolution operations
# x = F.leaky_relu(self.model.conv1(x, edge_index))
# x = F.leaky_relu(self.model.conv2(x, edge_index))
# node_embeds = F.leaky_relu(self.model.conv3(x, edge_index))
# graph_embed = global_mean_pool(node_embeds, batch_ids)
#
# Fully connected part
# x = self.model.act1(self.model.bn1(self.model.lin1(graph_embed)))
# x_out = self.model.bn2(self.model.lin2(x))
# x = F.dropout(self.model.act2(x_out), p=0.5, training=self.model.training)
#
# x = self.model.lin3(x)
x = self.item_embedding(x)
x = x.squeeze(1)

for layer in self.layers:
x = layer(x, edge_index)

node_embeds = self.layers(x, edge_index)
graph_embed = global_mean_pool(node_embeds, batch_ids)
x = self.head(graph_embed)
graph_embed = global_mean_pool(x, batch_ids)
pred = self.head(graph_embed)

return {
"node_embed": node_embeds,
"node_embed": x,
"graph_embed": graph_embed,
"preds": x,
"preds": pred,
}
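
The head change above fixes a shape mismatch: the block feeding the final projection ends in 128 features, so the last layer must be Linear(128, output_dim) rather than Linear(64, output_dim). A small dimensionality check under that assumption (the elided middle layers are replaced by a single stand-in, and all sizes besides 128 are illustrative):

import torch
from torch import nn

hidden_dim, output_dim = 1024, 3   # illustrative values

head = nn.Sequential(
    nn.Linear(hidden_dim, 128),    # stand-in for the elided 1024 -> ... -> 128 block
    nn.BatchNorm1d(128),
    nn.LeakyReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, output_dim),    # was Linear(64, output_dim), which cannot consume 128 features
)

graph_embed = torch.rand(4, hidden_dim)   # batch of 4 pooled graph embeddings
pred = head(graph_embed)                  # shape: (4, output_dim)
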
18 changes: 14 additions & 4 deletions gifflar/model.py
@@ -33,6 +33,16 @@ def get_gin_layer(hidden_dim: int):
}


class MultiEmbedding(nn.Module):
def __init__(self, embeddings: dict[str, nn.Embedding]):
super().__init__()
for name, embedding in embeddings.items():
setattr(self, name, embedding)

def forward(self, input_, name):
return getattr(self, name).forward(input_)


class GlycanGIN(LightningModule):
def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression", "classification", "multilabel"],
pre_transform_args: Optional[Dict] = None):
@@ -44,11 +54,11 @@ def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression",
self.addendum.append(pre_transforms[name].attr_name)
rand_dim -= args["dim"]

self.embedding = {
self.embedding = MultiEmbedding({
"atoms": nn.Embedding(len(atom_map) + 2, rand_dim),
"bonds": nn.Embedding(len(bond_map) + 2, rand_dim),
"monosacchs": nn.Embedding(len(lib) + 2, rand_dim),
}
})

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
@@ -69,7 +79,7 @@ def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression",
def forward(self, batch):
for key in batch.x_dict.keys():
# Compute random encodings for the atom type and include positional encodings
pes = [self.embedding[key](batch.x_dict[key])]
pes = [self.embedding.forward(batch.x_dict[key], key)]
for pe in self.addendum:
pes.append(batch[f"{key}_{pe}"])

@@ -119,7 +129,7 @@ def to(self, device):
super(DownstreamGGIN, self).to(device)
for split, metric in self.metrics.items():
self.metrics[split] = metric.to(device)
self.embedding = {k: e.to(device) for k, e in self.embedding.items()}
# self.embedding = {k: e.to(device) for k, e in self.embedding.items()}

def forward(self, batch):
node_embed, graph_embed = super().forward(batch)
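
For context on the MultiEmbedding addition above: embeddings stored in a plain dict are not registered as submodules, so .to(device) and the optimizer never see them, which is what the now-commented line in to() had to work around. A self-contained sketch of the same idea (the table names and dimensions here are toy values, not the repository's):

import torch
from torch import nn

class MultiEmbedding(nn.Module):
    """Holds several embedding tables as registered submodules."""

    def __init__(self, embeddings: dict):
        super().__init__()
        for name, embedding in embeddings.items():
            setattr(self, name, embedding)   # setattr on an nn.Module registers each table

    def forward(self, input_, name):
        return getattr(self, name)(input_)

emb = MultiEmbedding({
    "atoms": nn.Embedding(10, 8),
    "bonds": nn.Embedding(5, 8),
})

out = emb(torch.tensor([1, 2, 3]), "atoms")        # shape: (3, 8)
print([n for n, _ in emb.named_parameters()])      # ['atoms.weight', 'bonds.weight']
# Because the tables are submodules, emb.to(device) moves them automatically.
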
2 changes: 1 addition & 1 deletion gifflar/pretransforms.py
@@ -261,7 +261,7 @@ def forward(self, data: Union[Data, HeteroData]):
else:
# data = [transform(d) for d in data]
t_data = []
for d in tqdm(data, leave=False):
for d in tqdm(data, desc=str(transform), leave=False):
t_data.append(transform(d))
data = t_data
# s_bar.update(1)
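
The pretransforms change above only labels the tqdm progress bar with the transform currently running. A tiny standalone illustration (the ScaleBy transform is a hypothetical stand-in, not part of the repository):

from tqdm import tqdm

class ScaleBy:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, d):
        return d * self.factor

    def __str__(self):
        return f"ScaleBy({self.factor})"

transform = ScaleBy(2)
data = list(range(1000))
# desc=str(transform) names the bar after the transform being applied
data = [transform(d) for d in tqdm(data, desc=str(transform), leave=False)]
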
