diff --git a/configs/lgi/full.yaml b/configs/lgi/full.yaml
new file mode 100644
index 0000000..0c4dc7d
--- /dev/null
+++ b/configs/lgi/full.yaml
@@ -0,0 +1,19 @@
+seed: 42
+root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
+logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
+origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
+model:
+  glycan_encoder:
+    name: gifflar
+    feat_dim: 128
+    hidden_dim: 1024
+    num_layers: 8
+    pooling: global_mean
+  lectin_encoder:
+    name: ESM
+    layer_num: 33
+  batch_size: 256
+  epochs: 100
+  learning_rate: 0.001
+  optimizer: Adam
+
diff --git a/configs/lgi/test.yaml b/configs/lgi/test.yaml
index ac9ad40..9f3f06d 100644
--- a/configs/lgi/test.yaml
+++ b/configs/lgi/test.yaml
@@ -1,18 +1,19 @@
 seed: 42
-root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
-logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
-origin: path/to/dataset
+root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
+logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
+origin: /home/daniel/Desktop/GIFFLAR/lgi_data.pkl
 model:
   glycan_encoder:
     name: gifflar
     feat_dim: 128
     hidden_dim: 1024
-    batch_size: 256
     num_layers: 8
-    pooling: global_pool
+    pooling: global_mean
   lectin_encoder:
-    name: esm
-    layer_num: 11
-  epochs: 100
+    name: ESM
+    layer_num: 33
+  batch_size: 256
+  epochs: 5
   learning_rate: 0.001
   optimizer: Adam
+
diff --git a/experiments/aquire_lgi_dataset.py b/experiments/aquire_lgi_dataset.py
index 9773696..22af1e7 100644
--- a/experiments/aquire_lgi_dataset.py
+++ b/experiments/aquire_lgi_dataset.py
@@ -1,5 +1,6 @@
 import pickle
 
+from tqdm import tqdm
 import numpy as np
 from glycowork.glycan_data.loader import glycan_binding as lgi
 
@@ -9,19 +10,18 @@
 lgi.drop(columns=["target", "protein"], inplace=True)
 s = lgi.stack()
 
-glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns[:-2])}
-glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns[:-2])})
+glycans = {f"Gly{i:04d}": iupac for i, iupac in enumerate(lgi.columns)}
+glycans.update({iupac: f"Gly{i:04d}" for i, iupac in enumerate(lgi.columns)})
 lectins = {f"Lec{i:04d}": aa_seq for i, aa_seq in enumerate(lgi.index)}
 lectins.update({aa_seq: f"Lec{i:04d}" for i, aa_seq in enumerate(lgi.index)})
 
 # Convert the Series to a list of triplets (row, col, val)
 data = []
-splits = np.random.choice(s.index, len(s))
-for i, ((aa_seq, iupac), val) in enumerate(s.items()):
+splits = np.random.choice(["train", "val", "test"], len(s), p=[0.7, 0.2, 0.1])
+for i, ((aa_seq, iupac), val) in tqdm(enumerate(s.items())):
     data.append((lectins[aa_seq], glycans[iupac], val, splits[i]))
-    if i == 1000:
-        break
 
-with open("lgi_data.pkl", "wb") as f:
+with open("lgi_data_full.pkl", "wb") as f:
     pickle.dump((data, lectins, glycans), f)
+
diff --git a/experiments/lgi_model.py b/experiments/lgi_model.py
index 96918ff..102b14e 100644
--- a/experiments/lgi_model.py
+++ b/experiments/lgi_model.py
@@ -22,6 +22,7 @@ def __init__(self, lectin_encoder: str, le_layer_num: int, path: str | None = No
                 otherwise, such file will be created.
         """
         self.path = Path(path or "data") / f"{lectin_encoder}_{le_layer_num}.pkl"
+        print("Path:", self.path.resolve())
         self.encoder = ENCODER_MAP[lectin_encoder](le_layer_num)
         self.data = self._load()
 
@@ -31,8 +32,13 @@ def query(self, aa_seq: str) -> torch.Tensor:
                 self.data[aa_seq] = self.encoder(aa_seq)
             except:
                 self.data[aa_seq] = None
+
         return self.data[aa_seq]
 
+    def batch_query(self, aa_seqs) -> torch.Tensor:
+        # print([self.query(aa_seq) for aa_seq in aa_seqs])
+        return torch.stack([self.query(aa_seq) for aa_seq in aa_seqs])
+
 
 class LGI_Model(LightningModule):
     def __init__(
@@ -62,18 +68,26 @@ def __init__(
         self.head, self.loss, self.metrics = get_prediction_head(self.combined_dim, 1, "regression")
 
+    def to(self, device: torch.device):
+        super(LGI_Model, self).to(device)
+        self.glycan_encoder.to(device)
+        self.glycan_pooling.to(device)
+        self.head.to(device)
+        for split, metric in self.metrics.items():
+            self.metrics[split] = metric.to(device)
+
     def forward(self, data: HeteroData) -> dict[str, torch.Tensor]:
         glycan_node_embed = self.glycan_encoder(data)
         glycan_graph_embed = self.glycan_pooling(glycan_node_embed, data.batch_dict)
-        lectin_embed = self.lectin_embeddings.query(data["aa_seq"])
+        lectin_embed = self.lectin_embeddings.batch_query(data["aa_seq"])
 
         combined = torch.cat([glycan_graph_embed, lectin_embed], dim=-1)
         pred = self.head(combined)
 
         return {
-            "glycan_node_embed": glycan_node_embed,
-            "glycan_graph_embed": glycan_graph_embed,
-            "lectin_embed": lectin_embed,
-            "pred": pred,
+            "glycan_node_embeds": glycan_node_embed,
+            "glycan_graph_embeds": glycan_graph_embed,
+            "lectin_embeds": lectin_embed,
+            "preds": pred,
         }
 
     def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
@@ -88,10 +102,11 @@ def shared_step(self, batch: HeteroData, stage: str) -> dict[str, torch.Tensor]:
             A dictionary containing the loss and the metrics
         """
         fwd_dict = self(batch)
-        fwd_dict["labels"] = batch["y"]
-        fwd_dict["loss"] = self.loss(fwd_dict["pred"], fwd_dict["label"])
-        self.metrics[stage].update(fwd_dict["pred"], fwd_dict["label"])
-        self.log(f"{stage}/loss", fwd_dict["loss"])
+        fwd_dict["labels"] = batch["y"]  # .reshape(-1)
+        fwd_dict["preds"] = fwd_dict["preds"].reshape(-1)
+        fwd_dict["loss"] = self.loss(fwd_dict["preds"], fwd_dict["labels"])
+        self.metrics[stage].update(fwd_dict["preds"], fwd_dict["labels"])
+        self.log(f"{stage}/loss", fwd_dict["loss"], batch_size=len(fwd_dict["preds"]))
 
         return fwd_dict
diff --git a/experiments/protein_encoding.py b/experiments/protein_encoding.py
index 18362f3..87236e5 100644
--- a/experiments/protein_encoding.py
+++ b/experiments/protein_encoding.py
@@ -12,7 +12,7 @@ def forward(self, seq: str) -> torch.Tensor:
         pass
 
     def __call__(self, *args, **kwargs):
-        self.forward(*args, **kwargs)
+        return self.forward(*args, **kwargs)
 
 
 class Ankh(PLMEncoder):
@@ -36,25 +36,27 @@ def forward(self, seq: str) -> torch.Tensor:
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=2)[0]
+        return ankh.hidden_states[self.layer_num][:, :-1].mean(dim=1)[0]
 
 
 class ESM(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+        self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)
 
     def forward(self, seq: str) -> torch.Tensor:
         a = self.tokenizer(seq)
        with torch.no_grad():
             esm = self.model(
-                torch.Tensor(a["input_ids"]).long().reshape(1, -1),
-                torch.Tensor(a["attention_mask"]).long().reshape(1, -1),
+                torch.Tensor(a["input_ids"]).long().reshape(1, -1).to(self.device),
+                torch.Tensor(a["attention_mask"]).long().reshape(1, -1).to(self.device),
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=2)[0]
+        print(seq)
+        return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]
 
 
 class ProtBERT(PLMEncoder):
@@ -142,4 +144,4 @@ def forward(self, seq: str) -> torch.Tensor:
     "ProtBert": 31,
     "ProstT5": 25,
     "AMPLIFY": ...,
-}
\ No newline at end of file
+}
diff --git a/experiments/train_lgi.py b/experiments/train_lgi.py
index 6969387..1ca23c6 100644
--- a/experiments/train_lgi.py
+++ b/experiments/train_lgi.py
@@ -1,3 +1,6 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
 from argparse import ArgumentParser
 import time
 
@@ -36,21 +39,24 @@ def main(config):
     )
 
     # set up the logger
-    logger = CSVLogger(kwargs["logs_dir"], name=kwargs["model"]["name"] + (kwargs["model"].get("suffix", None) or ""))
+    glycan_model_name = kwargs["model"]["glycan_encoder"]["name"] + (kwargs["model"]["glycan_encoder"].get("suffix", None) or "")
+    lectin_model_name = kwargs["model"]["lectin_encoder"]["name"] + (kwargs["model"]["lectin_encoder"].get("suffix", None) or "")
+    logger = CSVLogger(kwargs["logs_dir"], name="LGI_" + glycan_model_name + lectin_model_name)
     logger.log_hyperparams(kwargs)
 
     glycan_encoder = GLYCAN_ENCODERS[kwargs["model"]["glycan_encoder"]["name"]](**kwargs["model"]["glycan_encoder"])
     model = LGI_Model(
         glycan_encoder,
         kwargs["model"]["lectin_encoder"]["name"],
-        kwargs["model"]["lectin_encoder"]["le_layer_num"],
+        kwargs["model"]["lectin_encoder"]["layer_num"],
     )
-
+    model.to("cuda")
+
     trainer = Trainer(
         callbacks=[RichProgressBar(), RichModelSummary()],
         logger=logger,
-        max_epochs=kwargs["model"]["max_epochs"],
-        accelerator="cpu",
+        max_epochs=kwargs["model"]["epochs"],
+        accelerator="gpu",
     )
     start = time.time()
     trainer.fit(model, datamodule)
diff --git a/gifflar/data/datasets.py b/gifflar/data/datasets.py
index 8573653..c388dfa 100644
--- a/gifflar/data/datasets.py
+++ b/gifflar/data/datasets.py
@@ -214,7 +214,7 @@ def process(self) -> None:
         """Process the data and store it."""
         print("Start processing")
         data = {k: [] for k in self.splits}
-        with open(self.filename, "r") as f:
+        with open(self.filename, "rb") as f:
            inter, lectin_map, glycan_map = pickle.load(f)
 
         # Load the glycan storage to speed up the preprocessing
@@ -224,7 +224,7 @@ def process(self) -> None:
             if d is None:
                 continue
             d["aa_seq"] = lectin_map[lectin_id]
-            d["y"] = torch.tensor(value)
+            d["y"] = torch.tensor([value])
             d["ID"] = i
             data[split].append(d)
 
diff --git a/gifflar/data/hetero.py b/gifflar/data/hetero.py
index b891d2d..f656b4c 100644
--- a/gifflar/data/hetero.py
+++ b/gifflar/data/hetero.py
@@ -84,14 +84,18 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
     # Include data for the baselines and other kwargs for house-keeping
     baselines = {"gnngly", "sweetnet", "rgcn"}
     kwargs = {key: [] for key in dict(data[0]) if all(b not in key for b in baselines)}
+    # print([d["y"] for d in data])
 
     # Store the node counts to offset edge indices when collating
     node_counts = {node_type: [0] for node_type in node_types}
     for d in data:
         for key in kwargs:  # Collect all length-queryable fields
-            if not hasattr(d[key], "__len__") or len(d[key]) != 0:
-                kwargs[key].append(d[key])
+            try:
+                if not hasattr(d[key], "__len__") or len(d[key]) != 0:
+                    kwargs[key].append(d[key])
+            except:
+                pass
 
     # Compute the offsets for each node type for sample identification after batching
     for node_type in node_types:
@@ -129,6 +133,8 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
 
     # For each baseline, collate its node features and edge indices as well
     for b in baselines:
+        if not f"{b}_num_nodes" in data[0]:
+            continue
         kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
         edges = []
         batch = []
@@ -145,18 +151,12 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
         kwargs[f"{b}_edge_index"] = torch.cat(edges, dim=1)
         kwargs[f"{b}_batch"] = torch.cat(batch, dim=0)
         if b == "rgcn":
-            #for d in data:
-            #    print(d["rgcn_x"].shape)
-            #    print(d["rgcn_edge_type"].shape)
             kwargs["rgcn_edge_type"] = torch.tensor(e_types)
             kwargs["rgcn_node_type"] = n_types
             if hasattr(data[0], "rgcn_rw_pe"):
                 kwargs["rgcn_rw_pe"] = torch.cat([d["rgcn_rw_pe"] for d in data], dim=0)
             if hasattr(data[0], "rgcn_lap_pe"):
                 kwargs["rgcn_lap_pe"] = torch.cat([d["rgcn_lap_pe"] for d in data], dim=0)
-        #print(kwargs["rgcn_x"].shape)
-        #print(kwargs["rgcn_edge_type"].shape)
-        #print(len(kwargs["rgcn_node_type"]))
 
     # Remove all incompletely given data and concat lists of tensors into single tensors
     num_nodes = {node_type: x_dict[node_type].shape[0] for node_type in node_types}
diff --git a/gifflar/model/base.py b/gifflar/model/base.py
index 0be0188..07df722 100644
--- a/gifflar/model/base.py
+++ b/gifflar/model/base.py
@@ -1,4 +1,4 @@
-from typing import Literal, Optional
+from typing import Literal, Optional, Any
 
 import torch
 from glycowork.glycan_data.loader import lib
@@ -67,7 +67,7 @@ def __init__(self, feat_dim: int, hidden_dim: int, num_layers: int, batch_size:
                 ("monosacchs", "boundary", "monosacchs")
             ]
         }))
-
+        self.hidden_dim = hidden_dim
         self.batch_size = batch_size
 
     def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
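
Illustrative sketch (not part of the diff): how the pickle written by experiments/aquire_lgi_dataset.py is expected to be read back, mirroring the binary-mode read in gifflar/data/datasets.py. The file name and the structure of the three objects follow the dump call in the script above; the example values in the comments are assumptions, not verified output.

    import pickle

    # The acquisition script dumps a 3-tuple: (data, lectins, glycans)
    #   data:    list of (lectin_id, glycan_id, value, split) tuples,
    #            e.g. ("Lec0001", "Gly0042", value, "train")
    #   lectins: two-way dict mapping "LecXXXX" ids <-> amino-acid sequences
    #   glycans: two-way dict mapping "GlyXXXX" ids <-> IUPAC strings
    with open("lgi_data_full.pkl", "rb") as f:
        inter, lectin_map, glycan_map = pickle.load(f)

    print(len(inter), inter[0])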