diff --git a/configs/lgi/test.yaml b/configs/lgi/test.yaml
new file mode 100644
index 0000000..ac9ad40
--- /dev/null
+++ b/configs/lgi/test.yaml
@@ -0,0 +1,18 @@
+seed: 42
+root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
+logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
+origin: path/to/dataset
+model:
+  glycan_encoder:
+    name: gifflar
+    feat_dim: 128
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 8
+    pooling: global_pool
+  lectin_encoder:
+    name: ESM
+    le_layer_num: 11
+  max_epochs: 100
+  learning_rate: 0.001
+  optimizer: Adam
diff --git a/experiments/aquire_lgi_dataset.py b/experiments/aquire_lgi_dataset.py
index de29286..9773696 100644
--- a/experiments/aquire_lgi_dataset.py
+++ b/experiments/aquire_lgi_dataset.py
@@ -1,13 +1,27 @@
-from glycowork.glycan_data.loader import glycan_binding as lgi
+import pickle
 
-import pandas as pd
 import numpy as np
+from glycowork.glycan_data.loader import glycan_binding as lgi
 
 # Use stack to convert to a Series with a MultiIndex
+lgi.index = lgi["target"]
+lgi.drop(columns=["target", "protein"], inplace=True)
 s = lgi.stack()
 
+glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns)}
+glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns)})
+
+lectins = {f"Lec{i:04d}": aa_seq for i, aa_seq in enumerate(lgi.index)}
+lectins.update({aa_seq: f"Lec{i:04d}" for i, aa_seq in enumerate(lgi.index)})
+
 # Convert the Series to a list of triplets (row, col, val)
-triplets = [(i, j, val) for (i, j), val in s.items()]
+data = []
+splits = np.random.choice(["train", "val", "test"], len(s), p=[0.7, 0.2, 0.1])  # split ratios assumed
+for i, ((aa_seq, iupac), val) in enumerate(s.items()):
+    data.append((lectins[aa_seq], glycans[iupac], val, splits[i]))
+    if i == 1000:
+        break
 
-print(len(triplets))
+with open("lgi_data.pkl", "wb") as f:
+    pickle.dump((data, lectins, glycans), f)
diff --git a/experiments/lgi_model.py b/experiments/lgi_model.py
new file mode 100644
index 0000000..96918ff
--- /dev/null
+++ b/experiments/lgi_model.py
@@ -0,0 +1,140 @@
+from pathlib import Path
+from typing import Any, Literal
+
+import torch
+from pytorch_lightning import LightningModule
+from torch_geometric.data import HeteroData
+
+from experiments.protein_encoding import ENCODER_MAP, EMBED_SIZES
+from gifflar.data.utils import GlycanStorage
+from gifflar.model.base import GlycanGIN
+from gifflar.model.baselines.sweetnet import SweetNetLightning
+from gifflar.model.utils import GIFFLARPooling, get_prediction_head
+
+
+class LectinStorage(GlycanStorage):
+    def __init__(self, lectin_encoder: str, le_layer_num: int, path: str | None = None):
+        """
+        Initialize the wrapper around a dict.
+
+        Args:
+            path: Path to the directory. If there's a matching embedding pickle, it will be used to fill this
+                object, otherwise, such a file will be created.
+        """
+        self.path = Path(path or "data") / f"{lectin_encoder}_{le_layer_num}.pkl"
+        self.encoder = ENCODER_MAP[lectin_encoder](le_layer_num)
+        self.data = self._load()
+
+    def query(self, aa_seq: str) -> torch.Tensor | None:
+        if aa_seq not in self.data:
+            try:
+                self.data[aa_seq] = self.encoder(aa_seq)
+            except Exception:
+                self.data[aa_seq] = None
+        return self.data[aa_seq]
+
+
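+# LGI_Model couples a glycan graph encoder (GIFFLAR or SweetNet) with a pretrained protein language
+# model: glycan node embeddings are mean-pooled into a graph embedding, lectin embeddings are looked
+# up (and cached) through LectinStorage, and the concatenated vector feeds a regression head that
+# predicts the interaction value.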
+class LGI_Model(LightningModule):
+    def __init__(
+        self,
+        glycan_encoder: GlycanGIN | SweetNetLightning,
+        lectin_encoder: str,
+        le_layer_num: int,
+        **kwargs: Any,
+    ):
+        """
+        Initialize the LGI model, a model for predicting lectin-glycan interactions.
+
+        Args:
+            glycan_encoder: The glycan encoder model
+            lectin_encoder: The name of the lectin encoder to use
+            le_layer_num: The layer of the lectin encoder to take the embeddings from
+            kwargs: Additional arguments
+        """
+        super().__init__()
+        self.glycan_encoder = glycan_encoder
+        self.glycan_pooling = GIFFLARPooling("global_mean")
+        self.lectin_encoder = lectin_encoder
+        self.le_layer_num = le_layer_num
+
+        self.lectin_embeddings = LectinStorage(lectin_encoder, le_layer_num)
+        self.combined_dim = glycan_encoder.hidden_dim + EMBED_SIZES[lectin_encoder]
+
+        self.head, self.loss, self.metrics = get_prediction_head(self.combined_dim, 1, "regression")
+
+    def forward(self, data: HeteroData) -> dict[str, torch.Tensor]:
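+        # Encode the glycan graph and pool its node embeddings, look up the embedding of the lectin
+        # sequence stored under "aa_seq" by LGIDataset, then predict the interaction value from the
+        # concatenated representation.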
+        glycan_node_embed = self.glycan_encoder(data)
+        glycan_graph_embed = self.glycan_pooling(glycan_node_embed, data.batch_dict)
+        lectin_embed = self.lectin_embeddings.query(data["aa_seq"])
+        combined = torch.cat([glycan_graph_embed, lectin_embed], dim=-1)
+        pred = self.head(combined)
+
+        return {
+            "glycan_node_embed": glycan_node_embed,
+            "glycan_graph_embed": glycan_graph_embed,
+            "lectin_embed": lectin_embed,
+            "pred": pred,
+        }
+
+    def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
+        """
+        Compute the shared step of the model.
+
+        Args:
+            batch: The batch to process
+            stage: The stage of the model (train, val, or test)
+
+        Returns:
+            A dictionary containing the loss and the predictions
+        """
+        fwd_dict = self(batch)
+        fwd_dict["labels"] = batch["y"]
+        fwd_dict["loss"] = self.loss(fwd_dict["pred"], fwd_dict["labels"])
+        self.metrics[stage].update(fwd_dict["pred"], fwd_dict["labels"])
+        self.log(f"{stage}/loss", fwd_dict["loss"])
+
+        return fwd_dict
+
+    def training_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
+        """Compute the training step of the model"""
+        return self.shared_step(batch, "train")
+
+    def validation_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
+        """Compute the validation step of the model"""
+        return self.shared_step(batch, "val")
+
+    def test_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
+        """Compute the testing step of the model"""
+        return self.shared_step(batch, "test")
+
+    def shared_end(self, stage: Literal["train", "val", "test"]):
+        """
+        Compute and log the metrics at the end of an epoch.
+
+        Args:
+            stage: The stage of the model (train, val, or test)
+        """
+        metrics = self.metrics[stage].compute()
+        self.log_dict(metrics)
+        self.metrics[stage].reset()
+
+    def on_train_epoch_end(self) -> None:
+        """Compute the end of the training epoch"""
+        self.shared_end("train")
+
+    def on_validation_epoch_end(self) -> None:
+        """Compute the end of the validation epoch"""
+        self.shared_end("val")
+
+    def on_test_epoch_end(self) -> None:
+        """Compute the end of the testing epoch"""
+        self.shared_end("test")
+
+    def configure_optimizers(self):
+        """Configure the optimizer and the learning rate scheduler of the model"""
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5),
+            "monitor": "val/loss",
+        }
diff --git a/experiments/protein_encoding.py b/experiments/protein_encoding.py
new file mode 100644
index 0000000..18362f3
--- /dev/null
+++ b/experiments/protein_encoding.py
@@ -0,0 +1,145 @@
+import re
+
+import torch
+from transformers import T5EncoderModel, AutoTokenizer, AutoModel, BertTokenizer, T5Tokenizer
+
+
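+# Each encoder below wraps a pretrained protein language model and maps a single amino-acid sequence
+# to a fixed-size embedding: the hidden states of the requested layer are taken, special tokens are
+# stripped, and the remaining residue embeddings are mean-pooled.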
+class PLMEncoder:
+    def __init__(self, layer_num: int):
+        self.layer_num = layer_num
+
+    def forward(self, seq: str) -> torch.Tensor:
+        pass
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+
+class Ankh(PLMEncoder):
+    def __init__(self, layer_num: int):
+        super().__init__(layer_num)
+        self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base")
+        self.model = T5EncoderModel.from_pretrained("ElnaggarLab/ankh-base")
+
+    def forward(self, seq: str) -> torch.Tensor:
+        outputs = self.tokenizer.batch_encode_plus(
+            [list(seq)],
+            add_special_tokens=True,
+            padding=True,
+            is_split_into_words=True,
+            return_tensors="pt",
+        )
+        with torch.no_grad():
+            ankh = self.model(
+                input_ids=outputs["input_ids"],
+                attention_mask=outputs["attention_mask"],
+                output_attentions=False,
+                output_hidden_states=True,
+            )
+        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]  # mean-pool over residues
+
+
+class ESM(PLMEncoder):
+    def __init__(self, layer_num: int):
+        super().__init__(layer_num)
+        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+    def forward(self, seq: str) -> torch.Tensor:
+        a = self.tokenizer(seq)
+        with torch.no_grad():
+            esm = self.model(
+                torch.Tensor(a["input_ids"]).long().reshape(1, -1),
+                torch.Tensor(a["attention_mask"]).long().reshape(1, -1),
+                output_attentions=False,
+                output_hidden_states=True,
+            )
+        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
+
+
+class ProtBERT(PLMEncoder):
+    def __init__(self, layer_num: int):
+        super().__init__(layer_num)
+        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+        self.model = AutoModel.from_pretrained("Rostlab/prot_bert")
+
+    def forward(self, seq: str) -> torch.Tensor:
+        sequence_w_spaces = ' '.join(seq)
+        encoded_input = self.tokenizer(
+            sequence_w_spaces,
+            return_tensors='pt'
+        )
+        with torch.no_grad():
+            protbert = self.model(
+                **encoded_input,
+                output_attentions=False,
+                output_hidden_states=True,
+            )
+        return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
+
+
+class ProstT5(PLMEncoder):
+    def __init__(self, layer_num: int):
+        super().__init__(layer_num)
+        self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5")
+        self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
+
+    def forward(self, seq: str) -> torch.Tensor:
+        seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]]
+        # ProstT5 expects a direction prefix: "<AA2fold>" for amino acid input, "<fold2AA>" for 3Di input
+        seq = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in seq]
+        ids = self.tokenizer.batch_encode_plus(
+            seq,
+            add_special_tokens=True,
+            padding="longest",
+            return_tensors='pt'
+        )
+        with torch.no_grad():
+            prostt5 = self.model(
+                ids.input_ids,
+                attention_mask=ids.attention_mask,
+                output_attentions=False,
+                output_hidden_states=True,
+            )
+        return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
+
+
+class AMPLIFY(PLMEncoder):
+    def __init__(self, layer_num: int):
+        super().__init__(layer_num)
+        self.tokenizer = AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True)
+        self.model = AutoModel.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True).to("cuda")
+
+    def forward(self, seq: str) -> torch.Tensor:
+        a = self.tokenizer.encode(seq, return_tensors="pt").to("cuda")
+        with torch.no_grad():
+            amplify = self.model(
+                a,
+                output_attentions=False,
+                output_hidden_states=True,
+            )
+        return amplify.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]  # pooling mirrors the other encoders
+
+
+ENCODER_MAP = {
+    "Ankh": Ankh,
+    "ESM": ESM,
+    "ProtBert": ProtBERT,
+    "ProstT5": ProstT5,
+    "AMPLIFY": AMPLIFY,
+}
+
+EMBED_SIZES = {
+    "Ankh": 768,
+    "ESM": 1280,
+    "ProtBert": 1024,
+    "ProstT5": 1024,
+    "AMPLIFY": ...,
+}
+
+MAX_LAYERS = {
+    "Ankh": 49,
+    "ESM": 34,
+    "ProtBert": 31,
+    "ProstT5": 25,
+    "AMPLIFY": ...,
+}
\ No newline at end of file
diff --git a/experiments/train_lgi.py b/experiments/train_lgi.py
new file mode 100644
index 0000000..6969387
--- /dev/null
+++ b/experiments/train_lgi.py
@@ -0,0 +1,61 @@
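+# Train a lectin-glycan interaction model from a YAML config, e.g. (with the repository root on the
+# PYTHONPATH): python experiments/train_lgi.py configs/lgi/test.yaml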
+from argparse import ArgumentParser
+import time
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import RichProgressBar, RichModelSummary
+from pytorch_lightning.loggers import CSVLogger
+from torch_geometric import seed_everything
+
+from experiments.lgi_model import LGI_Model
+from gifflar.data.modules import LGI_GDM
+from gifflar.model.base import GlycanGIN
+from gifflar.model.baselines.sweetnet import SweetNetLightning
+from gifflar.pretransforms import get_pretransforms
+from gifflar.train import setup
+from gifflar.utils import read_yaml_config, hash_dict
+
+
+GLYCAN_ENCODERS = {
+    "gifflar": GlycanGIN,
+    "sweetnet": SweetNetLightning,
+}
+
+
+def main(config):
+    kwargs = read_yaml_config(config)
+    kwargs["pre-transforms"] = {"GIFFLARTransform": "", "SweetNetTransform": ""}
+    kwargs["hash"] = hash_dict(kwargs["pre-transforms"])
+    seed_everything(kwargs["seed"])
+
+    datamodule = LGI_GDM(
+        root=kwargs["root_dir"], filename=kwargs["origin"], hash_code=kwargs["hash"],
+        batch_size=kwargs["model"].get("batch_size", 1), transform=None,
+        pre_transform=get_pretransforms("", **(kwargs["pre-transforms"] or {})),
+    )
+
+    # set up the logger
+    logger = CSVLogger(kwargs["logs_dir"], name=kwargs["model"]["name"] + (kwargs["model"].get("suffix", None) or ""))
+    logger.log_hyperparams(kwargs)
+
+    glycan_encoder = GLYCAN_ENCODERS[kwargs["model"]["glycan_encoder"]["name"]](**kwargs["model"]["glycan_encoder"])
+    model = LGI_Model(
+        glycan_encoder,
+        kwargs["model"]["lectin_encoder"]["name"],
+        kwargs["model"]["lectin_encoder"]["le_layer_num"],
+    )
+
+    trainer = Trainer(
+        callbacks=[RichProgressBar(), RichModelSummary()],
+        logger=logger,
+        max_epochs=kwargs["model"]["max_epochs"],
+        accelerator="cpu",
+    )
+    start = time.time()
+    trainer.fit(model, datamodule)
+    print("Training took", time.time() - start, "s")
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument("config", type=str, help="Path to YAML config file")
+    main(parser.parse_args().config)
diff --git a/gifflar/data/datasets.py b/gifflar/data/datasets.py
index a88cd82..8573653 100644
--- a/gifflar/data/datasets.py
+++ b/gifflar/data/datasets.py
@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 from typing import Union, Optional, Callable, Any
 
@@ -206,3 +207,28 @@ def process(self) -> None:
         print("Processed", sum(len(v) for v in data.values()), "entries")
         for split in self.splits:
             self.process_(data[split], path_idx=self.splits[split])
+
+
+class LGIDataset(DownstreamGDs):
+    def process(self) -> None:
+        """Process the data and store it."""
+        print("Start processing")
+        data = {k: [] for k in self.splits}
+        with open(self.filename, "rb") as f:
+            inter, lectin_map, glycan_map = pickle.load(f)
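+        # The pickle (written by experiments/aquire_lgi_dataset.py) contains a list of
+        # (lectin_id, glycan_id, value, split) tuples plus two-way lectin and glycan ID mappings.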
+
+        # Load the glycan storage to speed up the preprocessing
+        gs = GlycanStorage(Path(self.root).parent)
+        for i, (lectin_id, glycan_id, value, split) in tqdm(enumerate(inter)):
+            d = gs.query(glycan_map[glycan_id])
+            if d is None:
+                continue
+            d["aa_seq"] = lectin_map[lectin_id]
+            d["y"] = torch.tensor(value)
+            d["ID"] = i
+            data[split].append(d)
+
+        gs.close()
+        print("Processed", sum(len(v) for v in data.values()), "entries")
+        for split in self.splits:
+            self.process_(data[split], path_idx=self.splits[split])
diff --git a/gifflar/data/modules.py b/gifflar/data/modules.py
index dc63bfd..608d06b 100644
--- a/gifflar/data/modules.py
+++ b/gifflar/data/modules.py
@@ -5,7 +5,7 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import ConcatDataset, DataLoader
 
-from gifflar.data.datasets import DownstreamGDs, PretrainGDs
+from gifflar.data.datasets import DownstreamGDs, PretrainGDs, LGIDataset
 from gifflar.data.hetero import hetero_collate
 
 
@@ -103,6 +103,8 @@ def __init__(
 
 class DownsteamGDM(GlycanDataModule):
     """DataModule for downstream tasks on glycan data."""
+    ds_class = DownstreamGDs
+
     def __init__(
         self,
         root: str | Path,
@@ -126,15 +128,18 @@ def __init__(
             **dataset_args: Additional arguments to pass to the DownstreamGDs
         """
         super().__init__(batch_size)
-        self.train = DownstreamGDs(
+        self.train = self.ds_class(
            root=root, filename=filename, split="train", hash_code=hash_code, transform=transform,
            pre_transform=pre_transform, **dataset_args,
        )
-        self.val = DownstreamGDs(
+        self.val = self.ds_class(
            root=root, filename=filename, split="val", hash_code=hash_code, transform=transform,
            pre_transform=pre_transform, **dataset_args,
        )
-        self.test = DownstreamGDs(
+        self.test = self.ds_class(
            root=root, filename=filename, split="test", hash_code=hash_code, transform=transform,
            pre_transform=pre_transform, **dataset_args,
        )
+
+
+class LGI_GDM(DownsteamGDM):
+    ds_class = LGIDataset
\ No newline at end of file
diff --git a/gifflar/model/base.py b/gifflar/model/base.py
index a8b024a..0be0188 100644
--- a/gifflar/model/base.py
+++ b/gifflar/model/base.py
@@ -20,7 +20,7 @@ class GlycanGIN(LightningModule):
     def __init__(self, feat_dim: int, hidden_dim: int, num_layers: int, batch_size: int = 32,
-                 pre_transform_args: Optional[dict] = None):
+                 pre_transform_args: Optional[dict] = None, **kwargs: Any):
         """
         Initialize the GlycanGIN model, the base for all DL-models in this package
@@ -29,6 +29,7 @@ def __init__(self, feat_dim: int, hidden_dim: int, num_layers: int, batch_size:
             hidden_dim: The hidden dimension of the model
             num_layers: The number of GIN layers to use
             pre_transform_args: A dictionary of pre-transforms to apply to the input data
+            kwargs: Additional arguments (ignored)
         """
 
         super().__init__()
diff --git a/gifflar/train.py b/gifflar/train.py
index 7abd562..d572753 100644
--- a/gifflar/train.py
+++ b/gifflar/train.py
@@ -58,13 +58,13 @@ def setup(count: int = 4, **kwargs: Any) -> tuple[dict, DownsteamGDM, Logger | N
         pre_transform=get_pretransforms(data_config["name"], **(kwargs["pre-transforms"] or {})),
         **data_config,
     )
     data_config["num_classes"] = datamodule.train.dataset_args["num_classes"]
+    kwargs["dataset"]["filepath"] = str(data_config["filepath"])
     if count == 2:
         return data_config, datamodule, None, None
 
     # set up the logger
     logger = CSVLogger(kwargs["logs_dir"], name=kwargs["model"]["name"] + (kwargs["model"].get("suffix", None) or ""))
-    kwargs["dataset"]["filepath"] = str(data_config["filepath"])
     logger.log_hyperparams(kwargs)
 
     if count == 3: