From e211984e3a407031f0a61efd5d96814184b0ef78 Mon Sep 17 00:00:00 2001
From: Roman Joeres
Date: Mon, 19 Aug 2024 15:02:31 +0200
Subject: [PATCH] Debugging for final downstream comparison

---
 configs/downstream/all.yaml    | 65 ++++++++++++++++++----------
 configs/downstream/both.yaml   |  2 +-
 configs/downstream/gnngly.yaml | 17 ---------
 configs/downstream/lppe.yaml   |  2 +-
 configs/downstream/rwpe.yaml   |  2 +-
 gifflar/baselines/gnngly.py    | 38 ++++----------
 gifflar/baselines/sweetnet.py  | 47 ++++++++++--------
 gifflar/model.py               | 18 +++++---
 gifflar/pretransforms.py       |  2 +-
 9 files changed, 78 insertions(+), 115 deletions(-)
 delete mode 100644 configs/downstream/gnngly.yaml

diff --git a/configs/downstream/all.yaml b/configs/downstream/all.yaml
index 98f1fbd..e9255b2 100644
--- a/configs/downstream/all.yaml
+++ b/configs/downstream/all.yaml
@@ -3,12 +3,14 @@ data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
 root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
 logs_dir: logs
 datasets:
-  #- name: Immunogenicity
-  #  task: classification
-  #- name: Glycosylation
-  #  task: classification
+  - name: Immunogenicity
+    task: classification
+  - name: Glycosylation
+    task: classification
   - name: Taxonomy_Domain
     task: multilabel
+  #- name: Taxonomy_Kingdom
+  #  task: multilabel
   #- name: Taxonomy_Phylum
   #  task: multilabel
   #- name: Taxonomy_Class
@@ -24,7 +26,7 @@ datasets:
 pre-transforms:
 model:
   #- name: rf
-  #  n_estimators: 500
+  #  n_estimators: 500
   #  n_jobs: -1
   #  random_state: 42
   #- name: svm
@@ -35,32 +37,33 @@
     hidden_dim: 1024
     batch_size: 256
     num_layers: 3
-    epochs: 1
+    epochs: 100
     patience: 30
     learning_rate: 0
     optimizer: Adam
-  - name: sweetnet
-    hidden_dim: 1024
-    batch_size: 512
-    epochs: 1
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
-  - name: gnngly
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 5
-    epochs: 1
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
-  - name: gifflar
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 8
-    epochs: 1
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
+  #- name: sweetnet
+  #  hidden_dim: 1024
+  #  batch_size: 512
+  #  num_layers: 16
+  #  epochs: 1
+  #  patience: 30
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
+  #- name: gnngly
+  #  hidden_dim: 1024
+  #  batch_size: 512
+  #  num_layers: 8
+  #  epochs: 1
+  #  patience: 30
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
+  #- name: gifflar
+  #  hidden_dim: 1024
+  #  batch_size: 256
+  #  num_layers: 8
+  #  epochs: 100
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
diff --git a/configs/downstream/both.yaml b/configs/downstream/both.yaml
index b9c6c54..43d5d1a 100644
--- a/configs/downstream/both.yaml
+++ b/configs/downstream/both.yaml
@@ -31,7 +31,7 @@ pre-transforms:
 model:
   - name: gifflar
     hidden_dim: 1024
-    batch_size: 512
+    batch_size: 256
     num_layers: 8
     epochs: 100
     learning_rate: 0.001
diff --git a/configs/downstream/gnngly.yaml b/configs/downstream/gnngly.yaml
deleted file mode 100644
index d381b90..0000000
--- a/configs/downstream/gnngly.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-seed: 42
-data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
-root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
-logs_dir: logs
-datasets:
-  - name: Taxonomy_Domain
-    task: multilabel
-pre-transforms:
-model:
-  - name: gnngly
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 5
-    epochs: 1
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
diff --git a/configs/downstream/lppe.yaml b/configs/downstream/lppe.yaml
index fba5f63..f66ba25 100644
--- a/configs/downstream/lppe.yaml
+++ b/configs/downstream/lppe.yaml
@@ -29,7 +29,7 @@ pre-transforms:
 model:
   - name: gifflar
     hidden_dim: 1024
-    batch_size: 512
+    batch_size: 256
     num_layers: 8
     epochs: 100
     learning_rate: 0.001
diff --git a/configs/downstream/rwpe.yaml b/configs/downstream/rwpe.yaml
index 7daa5c4..708d4bb 100644
--- a/configs/downstream/rwpe.yaml
+++ b/configs/downstream/rwpe.yaml
@@ -29,7 +29,7 @@ pre-transforms:
 model:
   - name: gifflar
     hidden_dim: 1024
-    batch_size: 512
+    batch_size: 256
    num_layers: 8
     epochs: 100
     learning_rate: 0.001
diff --git a/gifflar/baselines/gnngly.py b/gifflar/baselines/gnngly.py
index a16f1bb..76183e2 100644
--- a/gifflar/baselines/gnngly.py
+++ b/gifflar/baselines/gnngly.py
@@ -1,8 +1,9 @@
 from typing import Dict, Literal
+from collections import OrderedDict
 
 import torch
 from torch import nn
-from torch_geometric.nn import GCNConv, global_mean_pool, Sequential
+from torch_geometric.nn import GCNConv, global_mean_pool
 
 from gifflar.model import DownstreamGGIN
 
@@ -30,39 +31,15 @@ def __init__(self, hidden_dim: int, num_layers: int, output_dim: int,
         """
         Initialize the model following the papers description.
         """
-        super().__init__(14, output_dim, task)
+        super().__init__(hidden_dim, output_dim, task)
 
         del self.convs
 
-        del self.head
-        self.layers = Sequential('x, edge_index', [GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)])
-
-        # Five layers of plain graph convolution with a hidden dimension of 14.
-        # self.layers = [
-        #     GCNConv(133, 14),
-        #     GCNConv(14, 14),
-        #     GCNConv(14, 14),
-        #     GCNConv(14, 14),
-        #     GCNConv(14, 14),
-        # ]
+        self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GCNConv((133 if l == 0 else hidden_dim), hidden_dim)) for l in range(num_layers)]))
 
         # ASSUMPTION: mean pooling
         self.pooling = global_mean_pool
 
-        # ASSUMPTION: a prediction head that seems quite elaborate given the other parts of the paper
-        # self.head = nn.Sequential(
-        #     nn.Dropout(0.1),
-        #     nn.Linear(14, 64),
-        #     nn.PReLU(),
-        #     nn.BatchNorm1d(64),
-        #     nn.Linear(64, output_dim),
-        # )
-
-    def to(self, device):
-        super(GNNGLY, self).to(device)
-        # self.layers = [l.to(device) for l in self.layers]
-
-
     def forward(self, batch):
         """
         Forward the data though the model.
@@ -79,12 +56,11 @@ def forward(self, batch):
         edge_index = batch["gnngly_edge_index"]
 
         # Propagate the data through the model
-        # for layer in self.layers:
-        #     x = layer(x, edge_index)
+        for layer in self.layers:
+            x = layer(x, edge_index)
 
         # Compute the graph embeddings and make the final prediction based on this
-        node_embed = self.layers(x, edge_index)
-        graph_embed = self.pooling(node_embed, batch_ids)
+        graph_embed = self.pooling(x, batch_ids)
         pred = self.head(graph_embed)
 
         return {
diff --git a/gifflar/baselines/sweetnet.py b/gifflar/baselines/sweetnet.py
index 7d815e5..6795a7a 100644
--- a/gifflar/baselines/sweetnet.py
+++ b/gifflar/baselines/sweetnet.py
@@ -1,10 +1,12 @@
 from typing import Literal, Dict
+from collections import OrderedDict
 
 import torch
 from glycowork.ml.models import prep_model
 from torch import nn
 from torch_geometric.nn import global_mean_pool, GraphConv, Sequential
 import torch.nn.functional as F
+from glycowork.glycan_data.loader import lib
 
 from gifflar.data import HeteroDataBatch
 from gifflar.model import DownstreamGGIN
@@ -27,9 +29,9 @@ def __init__(self, hidden_dim: int, output_dim: int, num_layers: int,
         del self.head
 
         # Load the untrained model from glycowork
-        # self.model = prep_model("SweetNet", output_dim, hidden_dim=hidden_dim)
-        self.layers = Sequential('x, edge_index', [GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)])
-
+        self.item_embedding = nn.Embedding(len(lib), hidden_dim)
+        self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GraphConv(hidden_dim, hidden_dim)) for l in range(num_layers)]))
+
         self.head = nn.Sequential(
             nn.Linear(hidden_dim, 1024),
             nn.BatchNorm1d(1024),
@@ -38,9 +40,9 @@ def __init__(self, hidden_dim: int, output_dim: int, num_layers: int,
             nn.BatchNorm1d(128),
             nn.LeakyReLU(),
             nn.Dropout(0.5),
-            nn.Linear(64, output_dim),
+            nn.Linear(128, output_dim),
         )
-
+
     def forward(self, batch: HeteroDataBatch) -> Dict[str, torch.Tensor]:
         """
         Forward the data though the model.
@@ -52,33 +54,22 @@ def forward(self, batch: HeteroDataBatch) -> Dict[str, torch.Tensor]:
             Dict holding the node embeddings, the graph embedding, and the final model prediction
         """
         # Extract monosaccharide graph from the heterogeneous graph
-        x = batch["x_dict"]["monosacchs"]
-        batch_ids = batch["batch_dict"]["monosacchs"]
-        edge_index = batch["edge_index_dict"]["monosacchs", "boundary", "monosacchs"]
+        x = batch["sweetnet_x"]
+        batch_ids = batch["sweetnet_batch"]
+        edge_index = batch["sweetnet_edge_index"]
 
         # Getting node features
-        # x = self.model.item_embedding(x)
-        # x = x.squeeze(1)
-        #
-        # Graph convolution operations
-        # x = F.leaky_relu(self.model.conv1(x, edge_index))
-        # x = F.leaky_relu(self.model.conv2(x, edge_index))
-        # node_embeds = F.leaky_relu(self.model.conv3(x, edge_index))
-        # graph_embed = global_mean_pool(node_embeds, batch_ids)
-        #
-        # Fully connected part
-        # x = self.model.act1(self.model.bn1(self.model.lin1(graph_embed)))
-        # x_out = self.model.bn2(self.model.lin2(x))
-        # x = F.dropout(self.model.act2(x_out), p=0.5, training=self.model.training)
-        #
-        # x = self.model.lin3(x)
+        x = self.item_embedding(x)
+        x = x.squeeze(1)
+
+        for layer in self.layers:
+            x = layer(x, edge_index)
 
-        node_embeds = self.layers(x, edge_index)
-        graph_embed = global_mean_pool(node_embeds, batch_ids)
-        x = self.head(graph_embed)
+        graph_embed = global_mean_pool(x, batch_ids)
+        pred = self.head(graph_embed)
 
         return {
-            "node_embed": node_embeds,
+            "node_embed": x,
             "graph_embed": graph_embed,
-            "preds": x,
+            "preds": pred,
         }
diff --git a/gifflar/model.py b/gifflar/model.py
index 39d66eb..f9af4e3 100644
--- a/gifflar/model.py
+++ b/gifflar/model.py
@@ -33,6 +33,16 @@ def get_gin_layer(hidden_dim: int):
     }
 
 
+class MultiEmbedding(nn.Module):
+    def __init__(self, embeddings: dict[str, nn.Embedding]):
+        super().__init__()
+        for name, embedding in embeddings.items():
+            setattr(self, name, embedding)
+
+    def forward(self, input_, name):
+        return getattr(self, name).forward(input_)
+
+
 class GlycanGIN(LightningModule):
     def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression", "classification", "multilabel"],
                  pre_transform_args: Optional[Dict] = None):
@@ -44,11 +54,11 @@ def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression",
                 self.addendum.append(pre_transforms[name].attr_name)
                 rand_dim -= args["dim"]
 
-        self.embedding = {
+        self.embedding = MultiEmbedding({
             "atoms": nn.Embedding(len(atom_map) + 2, rand_dim),
             "bonds": nn.Embedding(len(bond_map) + 2, rand_dim),
             "monosacchs": nn.Embedding(len(lib) + 2, rand_dim),
-        }
+        })
 
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
@@ -69,7 +79,7 @@ def __init__(self, hidden_dim: int, num_layers: int, task: Literal["regression",
     def forward(self, batch):
         for key in batch.x_dict.keys():
             # Compute random encodings for the atom type and include positional encodings
-            pes = [self.embedding[key](batch.x_dict[key])]
+            pes = [self.embedding.forward(batch.x_dict[key], key)]
             for pe in self.addendum:
                 pes.append(batch[f"{key}_{pe}"])
 
@@ -119,7 +129,7 @@ def to(self, device):
         super(DownstreamGGIN, self).to(device)
         for split, metric in self.metrics.items():
             self.metrics[split] = metric.to(device)
-        self.embedding = {k: e.to(device) for k, e in self.embedding.items()}
+        # self.embedding = {k: e.to(device) for k, e in self.embedding.items()}
 
     def forward(self, batch):
         node_embed, graph_embed = super().forward(batch)
diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py
index a6a2bda..b90a3dc 100644
--- a/gifflar/pretransforms.py
+++ b/gifflar/pretransforms.py
@@ -261,7 +261,7 @@ def forward(self, data: Union[Data, HeteroData]):
         else:
             # data = [transform(d) for d in data]
             t_data = []
-            for d in tqdm(data, leave=False):
+            for d in tqdm(data, desc=str(transform), leave=False):
                 t_data.append(transform(d))
             data = t_data
             # s_bar.update(1)
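
Note (editorial, not part of the patch): the layer stacks in gnngly.py and sweetnet.py above are now built with nn.Sequential(OrderedDict(...)) purely as a registered module container, while forward() iterates over the layers by hand, because nn.Sequential.forward() only threads a single argument and a PyG convolution needs both x and edge_index. A minimal, self-contained sketch of that pattern follows; the class TinyGCNStack and the toy tensors are hypothetical and only illustrate the idea.

from collections import OrderedDict

import torch
from torch import nn
from torch_geometric.nn import GCNConv, global_mean_pool


class TinyGCNStack(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int):
        super().__init__()
        # Registering the convolutions in an nn.Sequential(OrderedDict) gives them
        # stable names (layer1, layer2, ...) and moves them with .to(device),
        # mirroring the (133 if l == 0 else hidden_dim) construction in the patch.
        self.layers = nn.Sequential(OrderedDict([
            (f"layer{l + 1}", GCNConv(input_dim if l == 0 else hidden_dim, hidden_dim))
            for l in range(num_layers)
        ]))

    def forward(self, x, edge_index, batch_ids):
        # Iterate manually; calling self.layers(x, edge_index) would fail because
        # nn.Sequential only forwards one positional argument.
        for layer in self.layers:
            x = layer(x, edge_index)
        return global_mean_pool(x, batch_ids)


# Hypothetical toy input: four nodes with 133-dim features forming a single graph.
x = torch.randn(4, 133)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
batch_ids = torch.zeros(4, dtype=torch.long)
print(TinyGCNStack(133, 14, 5)(x, edge_index, batch_ids).shape)  # torch.Size([1, 14])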