-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:BojarLab/GIFFLAR
- Loading branch information
Showing
9 changed files
with
422 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
seed: 42 | ||
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data | ||
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs | ||
origin: path/to/dataset | ||
model: | ||
glycan_encoder: | ||
name: gifflar | ||
feat_dim: 128 | ||
hidden_dim: 1024 | ||
batch_size: 256 | ||
num_layers: 8 | ||
pooling: global_pool | ||
lectin_encoder: | ||
name: esm | ||
layer_num: 11 | ||
epochs: 100 | ||
learning_rate: 0.001 | ||
optimizer: Adam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,27 @@ | ||
from glycowork.glycan_data.loader import glycan_binding as lgi | ||
import pickle | ||
|
||
import pandas as pd | ||
import numpy as np | ||
from glycowork.glycan_data.loader import glycan_binding as lgi | ||
|
||
|
||
# Use stack to convert to a Series with a MultiIndex | ||
lgi.index = lgi["target"] | ||
lgi.drop(columns=["target", "protein"], inplace=True) | ||
s = lgi.stack() | ||
|
||
glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns[:-2])} | ||
glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns[:-2])}) | ||
|
||
lectins = {f"Lec{i:04d}": aa_seq for i, aa_seq in enumerate(lgi.index)} | ||
lectins.update({aa_seq: f"Lec{i:04d}" for i, aa_seq in enumerate(lgi.index)}) | ||
|
||
# Convert the Series to a list of triplets (row, col, val) | ||
triplets = [(i, j, val) for (i, j), val in s.items()] | ||
data = [] | ||
splits = np.random.choice(s.index, len(s)) | ||
for i, ((aa_seq, iupac), val) in enumerate(s.items()): | ||
data.append((lectins[aa_seq], glycans[iupac], val, splits[i])) | ||
if i == 1000: | ||
break | ||
|
||
print(len(triplets)) | ||
with open("lgi_data.pkl", "wb") as f: | ||
pickle.dump((data, lectins, glycans), f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from pathlib import Path | ||
from typing import Any, Literal | ||
|
||
import torch | ||
from pytorch_lightning import LightningModule | ||
from torch_geometric.data import HeteroData | ||
|
||
from experiments.protein_encoding import ENCODER_MAP, EMBED_SIZES | ||
from gifflar.data.utils import GlycanStorage | ||
from gifflar.model.base import GlycanGIN | ||
from gifflar.model.baselines.sweetnet import SweetNetLightning | ||
from gifflar.model.utils import GIFFLARPooling, get_prediction_head | ||
|
||
|
||
class LectinStorage(GlycanStorage): | ||
def __init__(self, lectin_encoder: str, le_layer_num: int, path: str | None = None): | ||
""" | ||
Initialize the wrapper around a dict. | ||
Args: | ||
path: Path to the directory. If there's a lectin_storage.pkl, it will be used to fill this object, | ||
otherwise, such file will be created. | ||
""" | ||
self.path = Path(path or "data") / f"{lectin_encoder}_{le_layer_num}.pkl" | ||
self.encoder = ENCODER_MAP[lectin_encoder](le_layer_num) | ||
self.data = self._load() | ||
|
||
def query(self, aa_seq: str) -> torch.Tensor: | ||
if aa_seq not in self.data: | ||
try: | ||
self.data[aa_seq] = self.encoder(aa_seq) | ||
except: | ||
self.data[aa_seq] = None | ||
return self.data[aa_seq] | ||
|
||
|
||
class LGI_Model(LightningModule): | ||
def __init__( | ||
self, | ||
glycan_encoder: GlycanGIN | SweetNetLightning, | ||
lectin_encoder: str, | ||
le_layer_num: int, | ||
**kwargs: Any, | ||
): | ||
""" | ||
Initialize the LGI model, a model for predicting lectin-glycan interactions. | ||
Args: | ||
glycan_encoder: The glycan encoder model | ||
lectin_encoder: The lectin encoder model | ||
le_layer_num: The number of layers to use in the lectin encoder | ||
kwargs: Additional arguments | ||
""" | ||
super().__init__() | ||
self.glycan_encoder = glycan_encoder | ||
self.glycan_pooling = GIFFLARPooling("global_mean") | ||
self.lectin_encoder = lectin_encoder | ||
self.le_layer_num = le_layer_num | ||
|
||
self.lectin_embeddings = LectinStorage(lectin_encoder, le_layer_num) | ||
self.combined_dim = glycan_encoder.hidden_dim + EMBED_SIZES[lectin_encoder] | ||
|
||
self.head, self.loss, self.metrics = get_prediction_head(self.combined_dim, 1, "regression") | ||
|
||
def forward(self, data: HeteroData) -> dict[str, torch.Tensor]: | ||
glycan_node_embed = self.glycan_encoder(data) | ||
glycan_graph_embed = self.glycan_pooling(glycan_node_embed, data.batch_dict) | ||
lectin_embed = self.lectin_embeddings.query(data["aa_seq"]) | ||
combined = torch.cat([glycan_graph_embed, lectin_embed], dim=-1) | ||
pred = self.head(combined) | ||
|
||
return { | ||
"glycan_node_embed": glycan_node_embed, | ||
"glycan_graph_embed": glycan_graph_embed, | ||
"lectin_embed": lectin_embed, | ||
"pred": pred, | ||
} | ||
|
||
def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]: | ||
""" | ||
Compute the shared step of the model. | ||
Args: | ||
data: The data to process | ||
stage: The stage of the model | ||
Returns: | ||
A dictionary containing the loss and the metrics | ||
""" | ||
fwd_dict = self(batch) | ||
fwd_dict["labels"] = batch["y"] | ||
fwd_dict["loss"] = self.loss(fwd_dict["pred"], fwd_dict["label"]) | ||
self.metrics[stage].update(fwd_dict["pred"], fwd_dict["label"]) | ||
self.log(f"{stage}/loss", fwd_dict["loss"]) | ||
|
||
return fwd_dict | ||
|
||
def training_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]: | ||
"""Compute the training step of the model""" | ||
return self.shared_step(batch, "train") | ||
|
||
def validation_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]: | ||
"""Compute the validation step of the model""" | ||
return self.shared_step(batch, "val") | ||
|
||
def test_step(self, batch: HeteroData, batch_idx: int) -> dict[str, torch.Tensor]: | ||
"""Compute the testing step of the model""" | ||
return self.shared_step(batch, "test") | ||
|
||
def shared_end(self, stage: Literal["train", "val", "test"]): | ||
""" | ||
Compute the shared end of the model. | ||
Args: | ||
stage: The stage of the model | ||
""" | ||
metrics = self.metrics[stage].compute() | ||
self.log_dict(metrics) | ||
self.metrics[stage].reset() | ||
|
||
def on_train_epoch_end(self) -> None: | ||
"""Compute the end of the training epoch""" | ||
self.shared_end("train") | ||
|
||
def on_validation_epoch_end(self) -> None: | ||
"""Compute the end of the validation""" | ||
self.shared_end("val") | ||
|
||
def on_test_epoch_end(self) -> None: | ||
"""Compute the end of the testing""" | ||
self.shared_end("test") | ||
|
||
def configure_optimizers(self): | ||
"""Configure the optimizer and the learning rate scheduler of the model""" | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return { | ||
"optimizer": optimizer, | ||
"lr_scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5), | ||
"monitor": "val/loss", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import re | ||
|
||
import torch | ||
from transformers import T5EncoderModel, AutoTokenizer, AutoModel, BertTokenizer, T5Tokenizer | ||
|
||
|
||
class PLMEncoder: | ||
def __init__(self, layer_num: int): | ||
self.layer_num = layer_num | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
pass | ||
|
||
def __call__(self, *args, **kwargs): | ||
self.forward(*args, **kwargs) | ||
|
||
|
||
class Ankh(PLMEncoder): | ||
def __init__(self, layer_num: int): | ||
super().__init__(layer_num) | ||
self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base") | ||
self.model = T5EncoderModel.from_pretrained("ElnaggarLab/ankh-base") | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
outputs = self.tokenizer.batch_encode_plus( | ||
[list(seq)], | ||
add_special_tokens=True, | ||
padding=True, | ||
is_split_into_words=True, | ||
return_tensors="pt", | ||
) | ||
with torch.no_grad(): | ||
ankh = self.model( | ||
input_ids=outputs["input_ids"], | ||
attention_mask=outputs["attention_mask"], | ||
output_attentions=False, | ||
output_hidden_states=True, | ||
) | ||
return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=2)[0] | ||
|
||
|
||
class ESM(PLMEncoder): | ||
def __init__(self, layer_num: int): | ||
super().__init__(layer_num) | ||
self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") | ||
self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D") | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
a = self.tokenizer(seq) | ||
with torch.no_grad(): | ||
esm = self.model( | ||
torch.Tensor(a["input_ids"]).long().reshape(1, -1), | ||
torch.Tensor(a["attention_mask"]).long().reshape(1, -1), | ||
output_attentions=False, | ||
output_hidden_states=True, | ||
) | ||
return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0] | ||
|
||
|
||
class ProtBERT(PLMEncoder): | ||
def __init__(self, layer_num: int): | ||
super().__init__(layer_num) | ||
self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) | ||
self.model = AutoModel.from_pretrained("Rostlab/prot_bert") | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
sequence_w_spaces = ' '.join(seq) | ||
encoded_input = self.tokenizer( | ||
sequence_w_spaces, | ||
return_tensors='pt' | ||
) | ||
with torch.no_grad(): | ||
protbert = self.model( | ||
**encoded_input, | ||
output_attentions=False, | ||
output_hidden_states=True, | ||
) | ||
return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0] | ||
|
||
|
||
class ProstT5(PLMEncoder): | ||
def __init__(self, layer_num: int): | ||
super().__init__(layer_num) | ||
self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5") | ||
self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5") | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]] | ||
seq = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in seq] | ||
ids = self.tokenizer.batch_encode_plus( | ||
seq, | ||
add_special_tokens=True, | ||
padding="longest", | ||
return_tensors='pt' | ||
) | ||
with torch.no_grad(): | ||
prostt5 = self.model( | ||
ids.input_ids, | ||
attention_mask=ids.attention_mask, | ||
output_attentions=False, | ||
output_hidden_states=True, | ||
) | ||
return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0] | ||
|
||
|
||
class AMPLIFY(PLMEncoder): | ||
def __init__(self, layer_num: int): | ||
super().__init__(layer_num) | ||
self.tokenizer = AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True) | ||
self.model = AutoModel.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True).to("cuda") | ||
|
||
def forward(self, seq: str) -> torch.Tensor: | ||
a = self.tokenizer.encode(seq, return_tensors="pt").to("cuda") | ||
with torch.no_grad(): | ||
amplify = self.model( | ||
a, | ||
output_attentions=False, | ||
output_hidden_states=True, | ||
) | ||
pass | ||
|
||
|
||
ENCODER_MAP = { | ||
"Ankh": Ankh, | ||
"ESM": ESM, | ||
"ProtBert": ProtBERT, | ||
"ProstT5": ProstT5, | ||
"AMPLIFY": AMPLIFY, | ||
} | ||
|
||
EMBED_SIZES = { | ||
"Ankh": 768, | ||
"ESM": 1280, | ||
"ProtBert": 1024, | ||
"ProstT5": 1024, | ||
"AMPLIFY": ..., | ||
} | ||
|
||
MAX_LAYERS = { | ||
"Ankh": 49, | ||
"ESM": 34, | ||
"ProtBert": 31, | ||
"ProstT5": 25, | ||
"AMPLIFY": ..., | ||
} |
Oops, something went wrong.