Commit

Bug fixes for LGI predictions

GlycanConnector committed Oct 23, 2024
1 parent 35cf483 commit e00d74b
Showing 9 changed files with 91 additions and 48 deletions.
19 changes: 19 additions & 0 deletions configs/lgi/full.yaml
@@ -0,0 +1,19 @@
+seed: 42
+root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
+logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
+origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
+model:
+  glycan_encoder:
+    name: gifflar
+    feat_dim: 128
+    hidden_dim: 1024
+    num_layers: 8
+    pooling: global_mean
+  lectin_encoder:
+    name: ESM
+    layer_num: 33
+  batch_size: 256
+  epochs: 100
+  learning_rate: 0.001
+  optimizer: Adam
+

17 changes: 9 additions & 8 deletions configs/lgi/test.yaml
@@ -1,18 +1,19 @@
 seed: 42
-root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
-logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
-origin: path/to/dataset
+root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
+logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
+origin: /home/daniel/Desktop/GIFFLAR/lgi_data.pkl
 model:
   glycan_encoder:
     name: gifflar
     feat_dim: 128
     hidden_dim: 1024
-    batch_size: 256
     num_layers: 8
-    pooling: global_pool
+    pooling: global_mean
   lectin_encoder:
-    name: esm
-    layer_num: 11
-  epochs: 100
+    name: ESM
+    layer_num: 33
+  batch_size: 256
+  epochs: 5
   learning_rate: 0.001
   optimizer: Adam
+
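A minimal sketch of how such a config might be consumed, assuming it is read into the kwargs dict that experiments/train_lgi.py below indexes into; the file path here is only illustrative:

import yaml

# Illustrative path; either of the configs above would parse the same way.
with open("configs/lgi/test.yaml") as f:
    kwargs = yaml.safe_load(f)

# The nested keys mirror the accesses made in train_lgi.py.
print(kwargs["model"]["glycan_encoder"]["hidden_dim"])  # 1024
print(kwargs["model"]["lectin_encoder"]["layer_num"])   # 33
print(kwargs["model"]["epochs"])                        # 5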

14 changes: 7 additions & 7 deletions experiments/aquire_lgi_dataset.py
@@ -1,5 +1,6 @@
 import pickle
 
+from tqdm import tqdm
 import numpy as np
 from glycowork.glycan_data.loader import glycan_binding as lgi

@@ -9,19 +10,18 @@
 lgi.drop(columns=["target", "protein"], inplace=True)
 s = lgi.stack()
 
-glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns[:-2])}
-glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns[:-2])})
+glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns)}
+glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns)})
 
 lectins = {f"Lec{i:04d}": aa_seq for i, aa_seq in enumerate(lgi.index)}
 lectins.update({aa_seq: f"Lec{i:04d}" for i, aa_seq in enumerate(lgi.index)})
 
 # Convert the Series to a list of triplets (row, col, val)
 data = []
-splits = np.random.choice(s.index, len(s))
-for i, ((aa_seq, iupac), val) in enumerate(s.items()):
+splits = np.random.choice(["train", "val", "test"], len(s), p=[0.7, 0.2, 0.1])
+for i, ((aa_seq, iupac), val) in tqdm(enumerate(s.items())):
     data.append((lectins[aa_seq], glycans[iupac], val, splits[i]))
 
-with open("lgi_data.pkl", "wb") as f:
+with open("lgi_data_full.pkl", "wb") as f:
     pickle.dump((data, lectins, glycans), f)
+
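As an aside, a minimal sketch of the split assignment introduced above: each (lectin, glycan, value) triplet independently draws one of three split labels with the given probabilities.

import numpy as np

# With p=[0.7, 0.2, 0.1], roughly 70% of triplets land in train,
# 20% in val, and 10% in test.
splits = np.random.choice(["train", "val", "test"], 10, p=[0.7, 0.2, 0.1])
print(splits)  # e.g. ['train' 'train' 'val' 'train' 'test' ...]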

33 changes: 24 additions & 9 deletions experiments/lgi_model.py
@@ -22,6 +22,7 @@ def __init__(self, lectin_encoder: str, le_layer_num: int, path: str | None = No
         otherwise, such file will be created.
         """
         self.path = Path(path or "data") / f"{lectin_encoder}_{le_layer_num}.pkl"
+        print("Path:", self.path.resolve())
         self.encoder = ENCODER_MAP[lectin_encoder](le_layer_num)
         self.data = self._load()

@@ -31,8 +32,13 @@ def query(self, aa_seq: str) -> torch.Tensor:
                 self.data[aa_seq] = self.encoder(aa_seq)
             except:
                 self.data[aa_seq] = None
+
         return self.data[aa_seq]
 
+    def batch_query(self, aa_seqs) -> torch.Tensor:
+        # print([self.query(aa_seq) for aa_seq in aa_seqs])
+        return torch.stack([self.query(aa_seq) for aa_seq in aa_seqs])
+
 
 class LGI_Model(LightningModule):
     def __init__(
@@ -62,18 +68,26 @@ def __init__(

         self.head, self.loss, self.metrics = get_prediction_head(self.combined_dim, 1, "regression")
 
+    def to(self, device: torch.device):
+        super(LGI_Model, self).to(device)
+        self.glycan_encoder.to(device)
+        self.glycan_pooling.to(device)
+        self.head.to(device)
+        for split, metric in self.metrics.items():
+            self.metrics[split] = metric.to(device)
+
     def forward(self, data: HeteroData) -> dict[str, torch.Tensor]:
         glycan_node_embed = self.glycan_encoder(data)
         glycan_graph_embed = self.glycan_pooling(glycan_node_embed, data.batch_dict)
-        lectin_embed = self.lectin_embeddings.query(data["aa_seq"])
+        lectin_embed = self.lectin_embeddings.batch_query(data["aa_seq"])
         combined = torch.cat([glycan_graph_embed, lectin_embed], dim=-1)
         pred = self.head(combined)
 
         return {
-            "glycan_node_embed": glycan_node_embed,
-            "glycan_graph_embed": glycan_graph_embed,
-            "lectin_embed": lectin_embed,
-            "pred": pred,
+            "glycan_node_embeds": glycan_node_embed,
+            "glycan_graph_embeds": glycan_graph_embed,
+            "lectin_embeds": lectin_embed,
+            "preds": pred,
         }
 
     def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
@@ -88,10 +102,11 @@ def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
         A dictionary containing the loss and the metrics
         """
         fwd_dict = self(batch)
-        fwd_dict["labels"] = batch["y"]
-        fwd_dict["loss"] = self.loss(fwd_dict["pred"], fwd_dict["label"])
-        self.metrics[stage].update(fwd_dict["pred"], fwd_dict["label"])
-        self.log(f"{stage}/loss", fwd_dict["loss"])
+        fwd_dict["labels"] = batch["y"]# .reshape(-1)
+        fwd_dict["preds"] = fwd_dict["preds"].reshape(-1)
+        fwd_dict["loss"] = self.loss(fwd_dict["preds"], fwd_dict["labels"])
+        self.metrics[stage].update(fwd_dict["preds"], fwd_dict["labels"])
+        self.log(f"{stage}/loss", fwd_dict["loss"], batch_size=len(fwd_dict["preds"]))
 
         return fwd_dict

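A self-contained sketch of the caching pattern behind query/batch_query above; the embedding size and sequences are stand-ins, not the real encoder output:

import torch

cache: dict[str, torch.Tensor] = {}

def query(aa_seq: str) -> torch.Tensor:
    # Encode each sequence once, then reuse the cached result.
    if aa_seq not in cache:
        cache[aa_seq] = torch.randn(1280)  # stand-in for a real PLM embedding
    return cache[aa_seq]

def batch_query(aa_seqs: list[str]) -> torch.Tensor:
    # torch.stack requires same-shaped tensors, so sequences whose encoding
    # failed (stored as None above) would need to be filtered out first.
    return torch.stack([query(s) for s in aa_seqs])

print(batch_query(["MKTAYIAK", "GSHMVKL", "MKTAYIAK"]).shape)  # torch.Size([3, 1280])
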
16 changes: 9 additions & 7 deletions experiments/protein_encoding.py
@@ -12,7 +12,7 @@ def forward(self, seq: str) -> torch.Tensor:
         pass
 
     def __call__(self, *args, **kwargs):
-        self.forward(*args, **kwargs)
+        return self.forward(*args, **kwargs)
 
 
 class Ankh(PLMEncoder):
@@ -36,25 +36,27 @@ def forward(self, seq: str) -> torch.Tensor:
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=2)[0]
+        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]
 
 
 class ESM(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)
 
     def forward(self, seq: str) -> torch.Tensor:
         a = self.tokenizer(seq)
         with torch.no_grad():
             esm = self.model(
-                torch.Tensor(a["input_ids"]).long().reshape(1, -1),
-                torch.Tensor(a["attention_mask"]).long().reshape(1, -1),
+                torch.Tensor(a["input_ids"]).long().reshape(1, -1).to(self.device),
+                torch.Tensor(a["attention_mask"]).long().reshape(1, -1).to(self.device),
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
+        print(seq)
+        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
 
 
 class ProtBERT(PLMEncoder):
@@ -142,4 +144,4 @@ def forward(self, seq: str) -> torch.Tensor:
"ProtBert": 31,
"ProstT5": 25,
"AMPLIFY": ...,
}
}
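
The dim change above is the key fix: per-layer hidden states come out as [batch, seq_len, hidden], so the sequence axis is dim=1. A quick sketch with a dummy tensor:

import torch

h = torch.randn(1, 10, 1280)       # [batch, seq_len, hidden], as in hidden_states
per_token = h[:, 1:-1]             # drop BOS/EOS tokens, as in ESM.forward above
pooled = per_token.mean(dim=1)[0]  # average over tokens -> one [hidden] vector
print(pooled.shape)                # torch.Size([1280])
# mean(dim=2) would instead average over the hidden axis, leaving one value
# per token, which is the bug this commit fixes.
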
16 changes: 11 additions & 5 deletions experiments/train_lgi.py
@@ -1,3 +1,6 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
 from argparse import ArgumentParser
 import time

@@ -36,21 +39,24 @@ def main(config):
     )
 
     # set up the logger
-    logger = CSVLogger(kwargs["logs_dir"], name=kwargs["model"]["name"] + (kwargs["model"].get("suffix", None) or ""))
+    glycan_model_name = kwargs["model"]["glycan_encoder"]["name"] + (kwargs["model"]["glycan_encoder"].get("suffix", None) or "")
+    lectin_model_name = kwargs["model"]["lectin_encoder"]["name"] + (kwargs["model"]["lectin_encoder"].get("suffix", None) or "")
+    logger = CSVLogger(kwargs["logs_dir"], name="LGI_" + glycan_model_name + lectin_model_name)
     logger.log_hyperparams(kwargs)
 
     glycan_encoder = GLYCAN_ENCODERS[kwargs["model"]["glycan_encoder"]["name"]](**kwargs["model"]["glycan_encoder"])
     model = LGI_Model(
         glycan_encoder,
         kwargs["model"]["lectin_encoder"]["name"],
-        kwargs["model"]["lectin_encoder"]["le_layer_num"],
+        kwargs["model"]["lectin_encoder"]["layer_num"],
     )
-
+    model.to("cuda")
+
     trainer = Trainer(
         callbacks=[RichProgressBar(), RichModelSummary()],
         logger=logger,
-        max_epochs=kwargs["model"]["max_epochs"],
-        accelerator="cpu",
+        max_epochs=kwargs["model"]["epochs"],
+        accelerator="gpu",
     )
     start = time.time()
     trainer.fit(model, datamodule)
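
For illustration, a minimal sketch of the run-name composition the new logger code produces, with config values assumed to match configs/lgi/test.yaml:

model_cfg = {
    "glycan_encoder": {"name": "gifflar"},
    "lectin_encoder": {"name": "ESM"},
}
glycan_model_name = model_cfg["glycan_encoder"]["name"] + (model_cfg["glycan_encoder"].get("suffix", None) or "")
lectin_model_name = model_cfg["lectin_encoder"]["name"] + (model_cfg["lectin_encoder"].get("suffix", None) or "")
print("LGI_" + glycan_model_name + lectin_model_name)  # LGI_gifflarESM
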
4 changes: 2 additions & 2 deletions gifflar/data/datasets.py
@@ -214,7 +214,7 @@ def process(self) -> None:
"""Process the data and store it."""
print("Start processing")
data = {k: [] for k in self.splits}
with open(self.filename, "r") as f:
with open(self.filename, "rb") as f:
inter, lectin_map, glycan_map = pickle.load(f)

         # Load the glycan storage to speed up the preprocessing
@@ -224,7 +224,7 @@
             if d is None:
                 continue
             d["aa_seq"] = lectin_map[lectin_id]
-            d["y"] = torch.tensor(value)
+            d["y"] = torch.tensor([value])
             d["ID"] = i
             data[split].append(d)

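Why wrapping the label in a list matters, as a short sketch: torch.tensor(value) builds a 0-d scalar that cannot be concatenated along a batch dimension, while torch.tensor([value]) is 1-d and collates cleanly.

import torch

print(torch.tensor(0.7).shape)    # torch.Size([]) -- a 0-d scalar
print(torch.tensor([0.7]).shape)  # torch.Size([1]) -- 1-d, length 1
print(torch.cat([torch.tensor([0.7]), torch.tensor([0.3])]))  # tensor([0.7000, 0.3000])
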
16 changes: 8 additions & 8 deletions gifflar/data/hetero.py
@@ -84,14 +84,18 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
     # Include data for the baselines and other kwargs for house-keeping
     baselines = {"gnngly", "sweetnet", "rgcn"}
     kwargs = {key: [] for key in dict(data[0]) if all(b not in key for b in baselines)}
+    # print([d["y"] for d in data])
 
     # Store the node counts to offset edge indices when collating
     node_counts = {node_type: [0] for node_type in node_types}
     for d in data:
         for key in kwargs:
             # Collect all length-queryable fields
-            if not hasattr(d[key], "__len__") or len(d[key]) != 0:
-                kwargs[key].append(d[key])
+            try:
+                if not hasattr(d[key], "__len__") or len(d[key]) != 0:
+                    kwargs[key].append(d[key])
+            except:
+                pass
 
         # Compute the offsets for each node type for sample identification after batching
         for node_type in node_types:
@@ -129,6 +133,8 @@

     # For each baseline, collate its node features and edge indices as well
     for b in baselines:
+        if not f"{b}_num_nodes" in data[0]:
+            continue
         kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
         edges = []
         batch = []
@@ -145,18 +151,12 @@
kwargs[f"{b}_edge_index"] = torch.cat(edges, dim=1)
kwargs[f"{b}_batch"] = torch.cat(batch, dim=0)
if b == "rgcn":
#for d in data:
# print(d["rgcn_x"].shape)
# print(d["rgcn_edge_type"].shape)
kwargs["rgcn_edge_type"] = torch.tensor(e_types)
kwargs["rgcn_node_type"] = n_types
if hasattr(data[0], "rgcn_rw_pe"):
kwargs["rgcn_rw_pe"] = torch.cat([d["rgcn_rw_pe"] for d in data], dim=0)
if hasattr(data[0], "rgcn_lap_pe"):
kwargs["rgcn_lap_pe"] = torch.cat([d["rgcn_lap_pe"] for d in data], dim=0)
#print(kwargs["rgcn_x"].shape)
#print(kwargs["rgcn_edge_type"].shape)
#print(len(kwargs["rgcn_node_type"]))

# Remove all incompletely given data and concat lists of tensors into single tensors
num_nodes = {node_type: x_dict[node_type].shape[0] for node_type in node_types}
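
A minimal sketch of the edge-index offsetting that these node counts enable: when two graphs are collated, node indices of the second are shifted by the node count of the first.

import torch

edge_a = torch.tensor([[0, 1], [1, 2]])  # graph A: 3 nodes, edges 0->1 and 1->2
edge_b = torch.tensor([[0, 2], [2, 1]])  # graph B: 3 nodes, edges 0->2 and 2->1
offset = 3                               # nodes already placed by graph A
batched = torch.cat([edge_a, edge_b + offset], dim=1)
print(batched)  # tensor([[0, 1, 3, 5], [1, 2, 5, 4]])
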
4 changes: 2 additions & 2 deletions gifflar/model/base.py
@@ -1,4 +1,4 @@
-from typing import Literal, Optional
+from typing import Literal, Optional, Any
 
 import torch
 from glycowork.glycan_data.loader import lib
@@ -67,7 +67,7 @@ def __init__(self, feat_dim: int, hidden_dim: int, num_layers: int, batch_size:
("monosacchs", "boundary", "monosacchs")
]
}))

self.hidden_dim = hidden_dim
self.batch_size = batch_size

def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
