LGI implementation
Old-Shatterhand committed Oct 22, 2024
1 parent 2fe83c6 commit c8f2988
Showing 9 changed files with 422 additions and 10 deletions.
18 changes: 18 additions & 0 deletions configs/lgi/test.yaml
@@ -0,0 +1,18 @@
seed: 42
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
origin: path/to/dataset
model:
  glycan_encoder:
    name: gifflar
    feat_dim: 128
    hidden_dim: 1024
    batch_size: 256
    num_layers: 8
    pooling: global_pool
  lectin_encoder:
    name: esm
    layer_num: 11
  epochs: 100
  learning_rate: 0.001
  optimizer: Adam
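
For reference, a minimal sketch (not part of the commit) of how a config laid out like this could be loaded with PyYAML; the key paths below assume the nesting shown above:

import yaml  # PyYAML

with open("configs/lgi/test.yaml") as f:
    cfg = yaml.safe_load(f)

# The key paths mirror the nesting shown in the diff above.
print(cfg["model"]["glycan_encoder"]["hidden_dim"])  # 1024
print(cfg["model"]["lectin_encoder"]["name"])        # esm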
22 changes: 18 additions & 4 deletions experiments/aquire_lgi_dataset.py
@@ -1,13 +1,27 @@
-from glycowork.glycan_data.loader import glycan_binding as lgi
 import pickle
 
+import pandas as pd
+import numpy as np
+from glycowork.glycan_data.loader import glycan_binding as lgi
+
+
 # Use stack to convert to a Series with a MultiIndex
 lgi.index = lgi["target"]
 lgi.drop(columns=["target", "protein"], inplace=True)
 s = lgi.stack()
 
+glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns[:-2])}
+glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns[:-2])})
+
+lectins = {f"Lec{i:04d}": aa_seq for i, aa_seq in enumerate(lgi.index)}
+lectins.update({aa_seq: f"Lec{i:04d}" for i, aa_seq in enumerate(lgi.index)})
+
-# Convert the Series to a list of triplets (row, col, val)
-triplets = [(i, j, val) for (i, j), val in s.items()]
+data = []
+splits = np.random.choice(s.index, len(s))
+for i, ((aa_seq, iupac), val) in enumerate(s.items()):
+    data.append((lectins[aa_seq], glycans[iupac], val, splits[i]))
+    if i == 1000:
+        break
+
-print(len(triplets))
 with open("lgi_data.pkl", "wb") as f:
     pickle.dump((data, lectins, glycans), f)
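
For context, a minimal sketch (not part of the commit) of how the pickled tuple written by this script could be read back; the unpacking mirrors the pickle.dump((data, lectins, glycans), f) call above:

import pickle

# data: list of (lectin_id, glycan_id, interaction_value, split) tuples;
# lectins / glycans: bidirectional {ID <-> sequence} lookup dicts.
with open("lgi_data.pkl", "rb") as f:
    data, lectins, glycans = pickle.load(f)

lectin_id, glycan_id, value, split = data[0]
print(lectins[lectin_id][:30], glycans[glycan_id], value, split)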
140 changes: 140 additions & 0 deletions experiments/lgi_model.py
@@ -0,0 +1,140 @@
from pathlib import Path
from typing import Any, Literal

import torch
from pytorch_lightning import LightningModule
from torch_geometric.data import HeteroData

from experiments.protein_encoding import ENCODER_MAP, EMBED_SIZES
from gifflar.data.utils import GlycanStorage
from gifflar.model.base import GlycanGIN
from gifflar.model.baselines.sweetnet import SweetNetLightning
from gifflar.model.utils import GIFFLARPooling, get_prediction_head


class LectinStorage(GlycanStorage):
    def __init__(self, lectin_encoder: str, le_layer_num: int, path: str | None = None):
        """
        Initialize the wrapper around a dict that caches lectin embeddings.

        Args:
            lectin_encoder: Name of the protein language model used to embed lectins
            le_layer_num: The layer of the encoder from which to take the embeddings
            path: Path to the storage directory. If a matching pickle file exists there, it is used to
                fill this object; otherwise, such a file will be created.
        """
        self.path = Path(path or "data") / f"{lectin_encoder}_{le_layer_num}.pkl"
        self.encoder = ENCODER_MAP[lectin_encoder](le_layer_num)
        self.data = self._load()

    def query(self, aa_seq: str) -> torch.Tensor:
        """Embed an amino-acid sequence, computing it once and caching the result."""
        if aa_seq not in self.data:
            try:
                self.data[aa_seq] = self.encoder(aa_seq)
            except Exception:
                self.data[aa_seq] = None
        return self.data[aa_seq]


class LGI_Model(LightningModule):
    def __init__(
        self,
        glycan_encoder: GlycanGIN | SweetNetLightning,
        lectin_encoder: str,
        le_layer_num: int,
        **kwargs: Any,
    ):
        """
        Initialize the LGI model, a model for predicting lectin-glycan interactions.

        Args:
            glycan_encoder: The glycan encoder model
            lectin_encoder: The name of the lectin encoder model
            le_layer_num: The layer of the lectin encoder to extract embeddings from
            kwargs: Additional arguments
        """
        super().__init__()
        self.glycan_encoder = glycan_encoder
        self.glycan_pooling = GIFFLARPooling("global_mean")
        self.lectin_encoder = lectin_encoder
        self.le_layer_num = le_layer_num

        self.lectin_embeddings = LectinStorage(lectin_encoder, le_layer_num)
        self.combined_dim = glycan_encoder.hidden_dim + EMBED_SIZES[lectin_encoder]

        self.head, self.loss, self.metrics = get_prediction_head(self.combined_dim, 1, "regression")

    def forward(self, data: HeteroData) -> dict[str, torch.Tensor]:
        glycan_node_embed = self.glycan_encoder(data)
        glycan_graph_embed = self.glycan_pooling(glycan_node_embed, data.batch_dict)
        lectin_embed = self.lectin_embeddings.query(data["aa_seq"])
        combined = torch.cat([glycan_graph_embed, lectin_embed], dim=-1)
        pred = self.head(combined)

        return {
            "glycan_node_embed": glycan_node_embed,
            "glycan_graph_embed": glycan_graph_embed,
            "lectin_embed": lectin_embed,
            "pred": pred,
        }

    def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
        """
        Compute the step shared between training, validation, and testing.

        Args:
            batch: The batch to process
            stage: The stage the model is in (train/val/test)

        Returns:
            A dictionary containing the forward pass outputs, the labels, and the loss
        """
        fwd_dict = self(batch)
        fwd_dict["labels"] = batch["y"]
        fwd_dict["loss"] = self.loss(fwd_dict["pred"], fwd_dict["labels"])
        self.metrics[stage].update(fwd_dict["pred"], fwd_dict["labels"])
        self.log(f"{stage}/loss", fwd_dict["loss"])

        return fwd_dict

    def training_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
        """Compute the training step of the model."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
        """Compute the validation step of the model."""
        return self.shared_step(batch, "val")

    def test_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]:
        """Compute the testing step of the model."""
        return self.shared_step(batch, "test")

    def shared_end(self, stage: Literal["train", "val", "test"]):
        """
        Compute and log the metrics at the end of an epoch, then reset them.

        Args:
            stage: The stage the model is in (train/val/test)
        """
        metrics = self.metrics[stage].compute()
        self.log_dict(metrics)
        self.metrics[stage].reset()

    def on_train_epoch_end(self) -> None:
        """Log the metrics at the end of the training epoch."""
        self.shared_end("train")

    def on_validation_epoch_end(self) -> None:
        """Log the metrics at the end of the validation epoch."""
        self.shared_end("val")

    def on_test_epoch_end(self) -> None:
        """Log the metrics at the end of the testing epoch."""
        self.shared_end("test")

    def configure_optimizers(self):
        """Configure the optimizer and the learning rate scheduler of the model."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {
            "optimizer": optimizer,
            "lr_scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5),
            "monitor": "val/loss",
        }
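
As a usage illustration only (the sequence is hypothetical, and the encoder name follows the keys of ENCODER_MAP in experiments/protein_encoding.py rather than the lowercase esm used in the config), the embedding cache could be exercised on its own:

from experiments.lgi_model import LectinStorage

# Cache ESM embeddings taken from layer 11; repeated queries for the same
# sequence reuse the stored tensor instead of re-running the encoder.
storage = LectinStorage(lectin_encoder="ESM", le_layer_num=11)
embedding = storage.query("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # hypothetical sequence
print(None if embedding is None else embedding.shape)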
145 changes: 145 additions & 0 deletions experiments/protein_encoding.py
@@ -0,0 +1,145 @@
import re

import torch
from transformers import T5EncoderModel, AutoTokenizer, AutoModel, BertTokenizer, T5Tokenizer


class PLMEncoder:
    def __init__(self, layer_num: int):
        self.layer_num = layer_num

    def forward(self, seq: str) -> torch.Tensor:
        """Embed a single amino-acid sequence; implemented by the subclasses."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Ankh(PLMEncoder):
    def __init__(self, layer_num: int):
        super().__init__(layer_num)
        self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base")
        self.model = T5EncoderModel.from_pretrained("ElnaggarLab/ankh-base")

    def forward(self, seq: str) -> torch.Tensor:
        outputs = self.tokenizer.batch_encode_plus(
            [list(seq)],
            add_special_tokens=True,
            padding=True,
            is_split_into_words=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            ankh = self.model(
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                output_attentions=False,
                output_hidden_states=True,
            )
        # Drop the trailing special token and mean-pool over the sequence dimension
        # to obtain one fixed-size embedding per sequence.
        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]


class ESM(PLMEncoder):
    def __init__(self, layer_num: int):
        super().__init__(layer_num)
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

    def forward(self, seq: str) -> torch.Tensor:
        a = self.tokenizer(seq)
        with torch.no_grad():
            esm = self.model(
                torch.Tensor(a["input_ids"]).long().reshape(1, -1),
                torch.Tensor(a["attention_mask"]).long().reshape(1, -1),
                output_attentions=False,
                output_hidden_states=True,
            )
        # Drop the CLS/EOS tokens and mean-pool over the sequence dimension.
        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


class ProtBERT(PLMEncoder):
    def __init__(self, layer_num: int):
        super().__init__(layer_num)
        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.model = AutoModel.from_pretrained("Rostlab/prot_bert")

    def forward(self, seq: str) -> torch.Tensor:
        sequence_w_spaces = ' '.join(seq)
        encoded_input = self.tokenizer(
            sequence_w_spaces,
            return_tensors='pt'
        )
        with torch.no_grad():
            protbert = self.model(
                **encoded_input,
                output_attentions=False,
                output_hidden_states=True,
            )
        # Drop the [CLS]/[SEP] tokens and mean-pool over the sequence dimension.
        return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


class ProstT5(PLMEncoder):
    def __init__(self, layer_num: int):
        super().__init__(layer_num)
        self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5")
        self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")

    def forward(self, seq: str) -> torch.Tensor:
        seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]]
        seq = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in seq]
        ids = self.tokenizer.batch_encode_plus(
            seq,
            add_special_tokens=True,
            padding="longest",
            return_tensors='pt'
        )
        with torch.no_grad():
            prostt5 = self.model(
                ids.input_ids,
                attention_mask=ids.attention_mask,
                output_attentions=False,
                output_hidden_states=True,
            )
        # Drop the direction prefix and special tokens, then mean-pool over the sequence dimension.
        return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


class AMPLIFY(PLMEncoder):
    def __init__(self, layer_num: int):
        super().__init__(layer_num)
        self.tokenizer = AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True).to("cuda")

    def forward(self, seq: str) -> torch.Tensor:
        a = self.tokenizer.encode(seq, return_tensors="pt").to("cuda")
        with torch.no_grad():
            amplify = self.model(
                a,
                output_attentions=False,
                output_hidden_states=True,
            )
        # Assumed pooling: strip the special tokens and mean-pool over the sequence
        # dimension, mirroring the other encoders.
        return amplify.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


ENCODER_MAP = {
    "Ankh": Ankh,
    "ESM": ESM,
    "ProtBert": ProtBERT,
    "ProstT5": ProstT5,
    "AMPLIFY": AMPLIFY,
}

EMBED_SIZES = {
    "Ankh": 768,
    "ESM": 1280,
    "ProtBert": 1024,
    "ProstT5": 1024,
    "AMPLIFY": ...,
}

MAX_LAYERS = {
    "Ankh": 49,
    "ESM": 34,
    "ProtBert": 31,
    "ProstT5": 25,
    "AMPLIFY": ...,
}
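
As a rough usage sketch (the sequence is made up; layer 11 matches layer_num in configs/lgi/test.yaml), an encoder can be pulled from ENCODER_MAP and applied to a single sequence:

from experiments.protein_encoding import ENCODER_MAP, EMBED_SIZES

# Instantiate the ESM2-650M wrapper and take embeddings from layer 11.
encoder = ENCODER_MAP["ESM"](11)
embedding = encoder("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # hypothetical sequence

# One fixed-size vector per sequence; its length should match EMBED_SIZES["ESM"] (1280).
print(embedding.shape)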