
Commit

Code for MLSB poster graphic
Old-Shatterhand committed Dec 3, 2024
1 parent 4bf8bc4 commit 524fc99
Showing 8 changed files with 164 additions and 19 deletions.
8 changes: 4 additions & 4 deletions configs/lgi/all.yaml
@@ -12,10 +12,10 @@ model:
hidden_dim: 1024
num_layers: 8
pooling: global_mean
- name: sweetnet
feat_dim: 128
hidden_dim: 1024
num_layers: 16
#- name: sweetnet
# feat_dim: 128
# hidden_dim: 1024
# num_layers: 16
lectin_encoder:
- name: ESM
layer_num: 33
32 changes: 32 additions & 0 deletions configs/lgi/sweetnet.yaml
@@ -0,0 +1,32 @@
seed: 42
#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
model:
glycan_encoder:
- name: gifflar
feat_dim: 128
hidden_dim: 1024
num_layers: 8
pooling: global_mean
#- name: sweetnet
# feat_dim: 128
# hidden_dim: 1024
# num_layers: 16
lectin_encoder:
- name: ESM
layer_num: 33
- name: Ankh
layer_num: 48
- name: ProtBert
layer_num: 30
- name: ProstT5
layer_num: 24
batch_size: 256
epochs: 100
learning_rate: 0.001
optimizer: Adam
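
For orientation, a minimal sketch of how a config like this could be consumed — assuming PyYAML and a hypothetical driver loop that pairs every glycan encoder with every lectin encoder; neither the loop nor the loading code is part of this commit:

```python
import itertools
import yaml

# Load the experiment config (path as committed above).
with open("configs/lgi/sweetnet.yaml") as f:
    cfg = yaml.safe_load(f)

# Hypothetical sweep: pair every glycan encoder with every lectin encoder.
for g_enc, l_enc in itertools.product(cfg["model"]["glycan_encoder"],
                                      cfg["model"]["lectin_encoder"]):
    print(f"train {g_enc['name']} x {l_enc['name']} (layer {l_enc['layer_num']})")
```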

104 changes: 104 additions & 0 deletions experiments/lectinoracle.py
@@ -0,0 +1,104 @@
from tqdm import tqdm
import torch
from torch_geometric.data import Batch
from torch_geometric.data.data import Data
from torch_geometric.loader import DataLoader
from glycowork.ml.model_training import train_model, SAM
from glycowork.ml.models import prep_model

from gifflar.data.modules import LGI_GDM
from gifflar.data.datasets import GlycanOnDiskDataset
from experiments.lgi_model import LectinStorage

le = LectinStorage("ESM", 33)

class LGI_OnDiskDataset(GlycanOnDiskDataset):
@property
def processed_file_names(self):
"""Return the list of processed file names."""
return [split + ".db" for split in ["train", "val", "test"]]


def get_ds(dl, split_idx: int):
ds = LGI_OnDiskDataset(root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=split_idx)
data = []
for x in tqdm(dl):
data.append(Data(
labels=x["sweetnet_x"],
y=x["y"],
edge_index=x["sweetnet_edge_index"],
aa_seq=x["aa_seq"][0],
))
if len(data) == 100:
ds.extend(data)
del data
data = []
if len(data) != 0:
ds.extend(data)
del data


def collate_lgi(data):
for d in data:
d["train_idx"] = le.query(d["aa_seq"])

offset = 0
labels, edges, y, train_idx, batch = [], [], [], [], []
for i, d in enumerate(data):
labels.append(d["labels"])
edges.append(torch.stack([
d["edge_index"][0] + offset,
d["edge_index"][1] + offset,
]))
offset += len(d["labels"])
y.append(d["y"])
train_idx.append(le.query(d["aa_seq"]))
batch += [i for _ in range(len(d["labels"]))]

labels = torch.cat(labels, dim=0)
edges = torch.cat(edges, dim=1)
y = torch.stack(y)
train_idx = torch.stack(train_idx)
batch = torch.tensor(batch)

return Batch(
labels=labels,
edge_index=edges,
y=y,
train_idx=train_idx,
batch=batch,
)

datamodule = LGI_GDM(
root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data", filename="/home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl", hash_code="8b34af2a",
batch_size=1, transform=None, pre_transform={"GIFFLARTransform": "", "SweetNetTransform": ""},
)

#get_ds(datamodule.train_dataloader(), 0)
#get_ds(datamodule.val_dataloader(), 1)
#get_ds(datamodule.test_dataloader(), 2)

train_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=0)
val_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=1)

model = prep_model("LectinOracle", num_classes=1)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

m = train_model(
model=model,
dataloaders={"train": torch.utils.data.DataLoader(train_set, batch_size=128, collate_fn=collate_lgi),
"val": torch.utils.data.DataLoader(val_set, batch_size=128, collate_fn=collate_lgi)},
criterion=torch.nn.MSELoss(),
optimizer=optimizer,
scheduler=scheduler,
return_metrics=True,
mode="regression",
num_epochs=100,
patience=100,
)

import pickle

with open("lectinoracle_metrics.pkl", "wb") as f:
pickle.dump(m, f)
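
To make the node-index offsetting in collate_lgi above concrete, here is a hand-worked toy example (illustrative only, not part of the commit) with two graphs of 3 and 2 nodes:

```python
import torch

# Graph 1: 3 nodes, edges 0->1 and 1->2; graph 2: 2 nodes, edge 0->1.
e1 = torch.tensor([[0, 1], [1, 2]])
e2 = torch.tensor([[0], [1]])

# Graph 2's node indices are shifted by graph 1's node count before merging,
# exactly what the running `offset` does in collate_lgi.
merged = torch.cat([e1, e2 + 3], dim=1)
print(merged)                           # tensor([[0, 1, 3], [1, 2, 4]])

# The batch vector assigns each node to its graph, as built in the loop.
batch = torch.tensor([0, 0, 0, 1, 1])
```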
6 changes: 5 additions & 1 deletion experiments/lgi_model.py
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Any, Literal
import copy

import torch
from pytorch_lightning import LightningModule
@@ -31,7 +32,8 @@ def query(self, aa_seq: str) -> torch.Tensor:
if aa_seq not in self.data:
try:
self.data[aa_seq] = self.encoder(aa_seq)
except:
except Exception as e:
print(e)
self.data[aa_seq] = None

return self.data[aa_seq]
@@ -59,6 +61,7 @@ def __init__(
kwargs: Additional arguments
"""
super().__init__()

self.glycan_encoder = glycan_encoder
self.glycan_pooling = GIFFLARPooling("global_mean")
self.lectin_encoder = lectin_encoder
@@ -140,6 +143,7 @@ def on_train_epoch_end(self) -> None:
def on_validation_epoch_end(self) -> None:
"""Compute the end of the validation"""
self.shared_end("val")
self.lectin_embeddings.close()

def on_test_epoch_end(self) -> None:
"""Compute the end of the testing"""
18 changes: 9 additions & 9 deletions experiments/protein_encoding.py
@@ -26,7 +26,6 @@ def forward(self, seq: str) -> torch.Tensor:
outputs = self.tokenizer.batch_encode_plus(
[list(seq)],
add_special_tokens=True,
padding=True,
is_split_into_words=True,
return_tensors="pt",
)
@@ -37,7 +36,7 @@ def forward(self, seq: str) -> torch.Tensor:
output_attentions=False,
output_hidden_states=True,
)
return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]
return ankh.hidden_states[self.layer_num][:, 1:].mean(dim=1)[0]


class ESM(PLMEncoder):
@@ -68,15 +67,17 @@ def forward(self, seq: str) -> torch.Tensor:
sequence_w_spaces = ' '.join(seq)
encoded_input = self.tokenizer(
sequence_w_spaces,
return_tensors='pt'
return_tensors='pt',
)
with torch.no_grad():
protbert = self.model(
**encoded_input,
input_ids=encoded_input["input_ids"].to(self.device),
attention_mask=encoded_input["attention_mask"].to(self.device),
token_type_ids=encoded_input["token_type_ids"].to(self.device),
output_attentions=False,
output_hidden_states=True,
)
return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
return protbert.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


class ProstT5(PLMEncoder):
@@ -87,12 +88,11 @@ def __init__(self, layer_num: int):

def forward(self, seq: str) -> torch.Tensor:
seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]]
seq = ["<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in seq]
seq = ["<AA2fold>" + " " + s.upper() for s in seq]
ids = self.tokenizer.batch_encode_plus(
seq,
add_special_tokens=True,
padding="longest",
return_tensors='pt'
return_tensors='pt',
)
with torch.no_grad():
prostt5 = self.model(
@@ -101,7 +101,7 @@ def forward(self, seq: str) -> torch.Tensor:
output_attentions=False,
output_hidden_states=True,
)
return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
return prostt5.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


class AMPLIFY(PLMEncoder):
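
The dim fixes in ProtBert and ProstT5 above (mean over dim=2 → dim=1) follow from the layout of transformer hidden states; a small sketch with illustrative shapes, not tied to any specific checkpoint:

```python
import torch

# HuggingFace-style hidden states have shape (batch, seq_len, hidden_dim).
hidden = torch.randn(1, 120, 1024)          # 1 sequence, 120 tokens, 1024 features
per_protein = hidden[:, 1:-1].mean(dim=1)   # drop special tokens, average over residues
print(per_protein.shape)                    # torch.Size([1, 1024])

# mean(dim=2) would collapse the feature axis instead, yielding shape (1, 118):
# one scalar per residue rather than a fixed-size protein embedding.
```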
2 changes: 1 addition & 1 deletion gifflar/data/datasets.py
@@ -98,7 +98,7 @@ def process_(self, data: list[HeteroData], path_idx: Path | str):
torch.save((data, self.dataset_args), self.processed_paths[path_idx])


class GlycanDataset(GlycanOnDeskDataset):
class GlycanDataset(GlycanOnDiskDataset):
def __init__(
self,
root: str | Path,
12 changes: 8 additions & 4 deletions gifflar/model/baselines/sweetnet.py
@@ -3,7 +3,7 @@

import torch
from torch import nn
from torch_geometric.nn import global_mean_pool, GraphConv
from torch_geometric.nn import global_mean_pool, GraphConv, BatchNorm
from glycowork.glycan_data.loader import lib

from gifflar.data.hetero import HeteroDataBatch
@@ -37,7 +37,11 @@ def __init__(

# Load the untrained model from glycowork
self.item_embedding = nn.Embedding(len(lib), hidden_dim)
self.layers = nn.Sequential(OrderedDict([(f"layer{l + 1}", GraphConv(hidden_dim, hidden_dim)) for l in range(num_layers)]))
layers = []
for l in range(num_layers):
layers.append((f"layer_{l + 1}_gc", GraphConv(hidden_dim, hidden_dim)))
layers.append((f"layer_{l + 1}_bn", BatchNorm(hidden_dim)))
self.layers = nn.Sequential(OrderedDict(layers))

if self.task is not None:
del self.head
@@ -71,8 +75,8 @@ def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
x = self.item_embedding(x)
x = x.squeeze(1)

for layer in self.layers:
x = layer(x, edge_index)
for l in range(0, len(self.layers), 2):
x = self.layers[l + 1](self.layers[l](x, edge_index))

graph_embed = global_mean_pool(x, batch_ids)
pred = None
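
A note on the paired iteration in the SweetNet forward pass above: the nn.Sequential container cannot be called directly because GraphConv needs edge_index while BatchNorm does not, so the layers are walked in (conv, norm) pairs. A self-contained sketch of the same pattern with toy sizes (not part of the commit):

```python
from collections import OrderedDict

import torch
from torch import nn
from torch_geometric.nn import BatchNorm, GraphConv

hidden_dim, num_layers = 8, 2
layers = []
for l in range(num_layers):
    layers.append((f"layer_{l + 1}_gc", GraphConv(hidden_dim, hidden_dim)))
    layers.append((f"layer_{l + 1}_bn", BatchNorm(hidden_dim)))
seq = nn.Sequential(OrderedDict(layers))

x = torch.randn(5, hidden_dim)                     # 5 nodes
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # toy edges
# Step through (GraphConv, BatchNorm) pairs, mirroring the committed forward loop.
for l in range(0, len(seq), 2):
    x = seq[l + 1](seq[l](x, edge_index))
print(x.shape)                                     # torch.Size([5, 8])
```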
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ torchmetrics
transformers
sentencepiece
xformers==0.0.28.post1
protobuf
