diff --git a/configs/lgi/all.yaml b/configs/lgi/all.yaml
index 8a01495..5aff15b 100644
--- a/configs/lgi/all.yaml
+++ b/configs/lgi/all.yaml
@@ -12,10 +12,10 @@ model:
       hidden_dim: 1024
       num_layers: 8
       pooling: global_mean
-    - name: sweetnet
-      feat_dim: 128
-      hidden_dim: 1024
-      num_layers: 16
+    #- name: sweetnet
+    #  feat_dim: 128
+    #  hidden_dim: 1024
+    #  num_layers: 16
   lectin_encoder:
     - name: ESM
       layer_num: 33
diff --git a/configs/lgi/sweetnet.yaml b/configs/lgi/sweetnet.yaml
new file mode 100644
index 0000000..5aff15b
--- /dev/null
+++ b/configs/lgi/sweetnet.yaml
@@ -0,0 +1,32 @@
+seed: 42
+#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
+#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
+#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
+root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
+logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
+origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
+model:
+  glycan_encoder:
+    - name: gifflar
+      feat_dim: 128
+      hidden_dim: 1024
+      num_layers: 8
+      pooling: global_mean
+    #- name: sweetnet
+    #  feat_dim: 128
+    #  hidden_dim: 1024
+    #  num_layers: 16
+  lectin_encoder:
+    - name: ESM
+      layer_num: 33
+    - name: Ankh
+      layer_num: 48
+    - name: ProtBert
+      layer_num: 30
+    - name: ProstT5
+      layer_num: 24
+  batch_size: 256
+  epochs: 100
+  learning_rate: 0.001
+  optimizer: Adam
+
diff --git a/experiments/lectinoracle.py b/experiments/lectinoracle.py
new file mode 100644
index 0000000..f887036
--- /dev/null
+++ b/experiments/lectinoracle.py
@@ -0,0 +1,104 @@
+from tqdm import tqdm
+import torch
+from torch_geometric.data import Batch
+from torch_geometric.data.data import Data
+from torch_geometric.loader import DataLoader
+from glycowork.ml.model_training import train_model, SAM
+from glycowork.ml.models import prep_model
+
+from gifflar.data.modules import LGI_GDM
+from gifflar.data.datasets import GlycanOnDiskDataset
+from experiments.lgi_model import LectinStorage
+
+le = LectinStorage("ESM", 33)
+
+class LGI_OnDiskDataset(GlycanOnDiskDataset):
+    @property
+    def processed_file_names(self):
+        """Return the list of processed file names."""
+        return [split + ".db" for split in ["train", "val", "test"]]
+
+
+def get_ds(dl, split_idx: int):
+    ds = LGI_OnDiskDataset(root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=split_idx)
+    data = []
+    for x in tqdm(dl):
+        data.append(Data(
+            labels=x["sweetnet_x"],
+            y=x["y"],
+            edge_index=x["sweetnet_edge_index"],
+            aa_seq=x["aa_seq"][0],
+        ))
+        if len(data) == 100:
+            ds.extend(data)
+            del data
+            data = []
+    if len(data) != 0:
+        ds.extend(data)
+        del data
+
+
+def collate_lgi(data):
+    for d in data:
+        d["train_idx"] = le.query(d["aa_seq"])
+
+    offset = 0
+    labels, edges, y, train_idx, batch = [], [], [], [], []
+    for i, d in enumerate(data):
+        labels.append(d["labels"])
+        edges.append(torch.stack([
+            d["edge_index"][0] + offset,
+            d["edge_index"][1] + offset,
+        ]))
+        offset += len(d["labels"])
+        y.append(d["y"])
+        train_idx.append(le.query(d["aa_seq"]))
+        batch += [i for _ in range(len(d["labels"]))]
+
+    labels = torch.cat(labels, dim=0)
+    edges = torch.cat(edges, dim=1)
+    y = torch.stack(y)
+    train_idx = torch.stack(train_idx)
+    batch = torch.tensor(batch)
+
+    return Batch(
+        labels=labels,
+        edge_index=edges,
+        y=y,
+        train_idx=train_idx,
+        batch=batch,
+    )
+
+datamodule = LGI_GDM(
+    root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data", filename="/home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl", hash_code="8b34af2a",
+    batch_size=1, transform=None, pre_transform={"GIFFLARTransform": "", "SweetNetTransform": ""},
+)
+
+#get_ds(datamodule.train_dataloader(), 0)
+#get_ds(datamodule.val_dataloader(), 1)
+#get_ds(datamodule.test_dataloader(), 2)
+
+train_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=0)
+val_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=1)
+
+model = prep_model("LectinOracle", num_classes=1)
+optimizer = torch.optim.Adam(model.parameters())
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+
+m = train_model(
+    model=model,
+    dataloaders={"train": torch.utils.data.DataLoader(train_set, batch_size=128, collate_fn=collate_lgi),
+                 "val": torch.utils.data.DataLoader(val_set, batch_size=128, collate_fn=collate_lgi)},
+    criterion=torch.nn.MSELoss(),
+    optimizer=optimizer,
+    scheduler=scheduler,
+    return_metrics=True,
+    mode="regression",
+    num_epochs=100,
+    patience=100,
+)
+
+import pickle
+
+with open("lectinoracle_metrics.pkl", "wb") as f:
+    pickle.dump(m, f)
diff --git a/experiments/lgi_model.py b/experiments/lgi_model.py
index 948c870..ee27dca 100644
--- a/experiments/lgi_model.py
+++ b/experiments/lgi_model.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import Any, Literal
+import copy
 
 import torch
 from pytorch_lightning import LightningModule
@@ -31,7 +32,8 @@ def query(self, aa_seq: str) -> torch.Tensor:
         if aa_seq not in self.data:
             try:
                 self.data[aa_seq] = self.encoder(aa_seq)
-            except:
+            except Exception as e:
+                print(e)
                 self.data[aa_seq] = None
         return self.data[aa_seq]
 
@@ -59,6 +61,7 @@ def __init__(
             kwargs: Additional arguments
         """
         super().__init__()
 
+        self.glycan_encoder = glycan_encoder
         self.glycan_pooling = GIFFLARPooling("global_mean")
         self.lectin_encoder = lectin_encoder
@@ -140,6 +143,7 @@ def on_train_epoch_end(self) -> None:
     def on_validation_epoch_end(self) -> None:
         """Compute the end of the validation"""
         self.shared_end("val")
+        self.lectin_embeddings.close()
 
     def on_test_epoch_end(self) -> None:
         """Compute the end of the testing"""
diff --git a/experiments/protein_encoding.py b/experiments/protein_encoding.py
index bf8779c..6f0edc6 100644
--- a/experiments/protein_encoding.py
+++ b/experiments/protein_encoding.py
@@ -26,7 +26,6 @@ def forward(self, seq: str) -> torch.Tensor:
         outputs = self.tokenizer.batch_encode_plus(
             [list(seq)],
             add_special_tokens=True,
-            padding=True,
             is_split_into_words=True,
             return_tensors="pt",
         )
@@ -37,7 +36,7 @@ def forward(self, seq: str) -> torch.Tensor:
             output_attentions=False,
             output_hidden_states=True,
         )
-        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]
+        return ankh.hidden_states[self.layer_num][:, 1:].mean(dim=1)[0]
 
 
 class ESM(PLMEncoder):
@@ -68,15 +67,17 @@ def forward(self, seq: str) -> torch.Tensor:
         sequence_w_spaces = ' '.join(seq)
         encoded_input = self.tokenizer(
             sequence_w_spaces,
-            return_tensors='pt'
+            return_tensors='pt',
         )
         with torch.no_grad():
             protbert = self.model(
-                **encoded_input,
+                input_ids=encoded_input["input_ids"].to(self.device),
+                attention_mask=encoded_input["attention_mask"].to(self.device),
+                token_type_ids=encoded_input["token_type_ids"].to(self.device),
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
+        return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
 
 
 class ProstT5(PLMEncoder):
@@ -87,12 +88,11 @@ def __init__(self, layer_num: int):
 
     def forward(self, seq: str) -> torch.Tensor:
         seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]]
-        seq = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in seq]
+        seq = ["<AA2fold>" + " " + s.upper() for s in seq]
         ids = self.tokenizer.batch_encode_plus(
             seq,
             add_special_tokens=True,
-            padding="longest",
-            return_tensors='pt'
+            return_tensors='pt',
         )
         with torch.no_grad():
             prostt5 = self.model(
@@ -101,7 +101,7 @@ def forward(self, seq: str) -> torch.Tensor:
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
+        return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
 
 
 class AMPLIFY(PLMEncoder):
diff --git a/gifflar/data/datasets.py b/gifflar/data/datasets.py
index a6308b3..f6bcf9f 100644
--- a/gifflar/data/datasets.py
+++ b/gifflar/data/datasets.py
@@ -98,7 +98,7 @@ def process_(self, data: list[HeteroData], path_idx: Path | str):
         torch.save((data, self.dataset_args), self.processed_paths[path_idx])
 
 
-class GlycanDataset(GlycanOnDeskDataset):
+class GlycanDataset(GlycanOnDiskDataset):
     def __init__(
         self,
         root: str | Path,
diff --git a/gifflar/model/baselines/sweetnet.py b/gifflar/model/baselines/sweetnet.py
index 5708212..fe80636 100644
--- a/gifflar/model/baselines/sweetnet.py
+++ b/gifflar/model/baselines/sweetnet.py
@@ -3,7 +3,7 @@
 
 import torch
 from torch import nn
-from torch_geometric.nn import global_mean_pool, GraphConv
+from torch_geometric.nn import global_mean_pool, GraphConv, BatchNorm
 from glycowork.glycan_data.loader import lib
 
 from gifflar.data.hetero import HeteroDataBatch
@@ -37,7 +37,11 @@ def __init__(
 
         # Load the untrained model from glycowork
         self.item_embedding = nn.Embedding(len(lib), hidden_dim)
-        self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GraphConv(hidden_dim, hidden_dim)) for l in range(num_layers)]))
+        layers = []
+        for l in range(num_layers):
+            layers.append((f"layer_{l + 1}_gc", GraphConv(hidden_dim, hidden_dim)))
+            layers.append((f"layer_{l + 1}_bn", BatchNorm(hidden_dim)))
+        self.layers = nn.Sequential(OrderedDict(layers))
 
         if self.task is not None:
             del self.head
@@ -71,8 +75,8 @@ def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
         x = self.item_embedding(x)
         x = x.squeeze(1)
-        for layer in self.layers:
-            x = layer(x, edge_index)
+        for l in range(0, len(self.layers), 2):
+            x = self.layers[l + 1](self.layers[l](x, edge_index))
 
         graph_embed = global_mean_pool(x, batch_ids)
 
         pred = None
diff --git a/requirements.txt b/requirements.txt
index 81aba64..f21dd43 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ torchmetrics
 transformers
 sentencepiece
 xformers==0.0.28.post1
+protobuf