Skip to content

Commit

Permalink
Code improvements to run on subsets of the original dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Oct 23, 2024
1 parent e00d74b commit 6e85dee
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 13 deletions.
32 changes: 32 additions & 0 deletions configs/lgi/all.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
seed: 42
#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
model:
glycan_encoder:
- name: gifflar
feat_dim: 128
hidden_dim: 1024
num_layers: 8
pooling: global_mean
- name: sweetnet
feat_dim: 128
hidden_dim: 1024
num_layers: 16
lectin_encoder:
- name: ESM
layer_num: 33
- name: Ankh
layer_num: 48
- name: ProtBert
layer_num: 30
- name: ProstT5
layer_num: 24
batch_size: 256
epochs: 100
learning_rate: 0.001
optimizer: Adam

9 changes: 6 additions & 3 deletions configs/lgi/full.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
seed: 42
root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
model:
glycan_encoder:
name: gifflar
Expand Down
9 changes: 6 additions & 3 deletions configs/lgi/test.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
seed: 42
root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
origin: /home/daniel/Desktop/GIFFLAR/lgi_data.pkl
#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data.pkl
model:
glycan_encoder:
name: gifflar
Expand Down
5 changes: 4 additions & 1 deletion experiments/aquire_lgi_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
import random

from tqdm import tqdm
import numpy as np
Expand All @@ -22,6 +23,8 @@
for i, ((aa_seq, iupac), val) in tqdm(enumerate(s.items())):
data.append((lectins[aa_seq], glycans[iupac], val, splits[i]))

with open("lgi_data_full.pkl", "wb") as f:
data = random.sample(data, int(len(data) * 0.20))

with open("lgi_data_20.pkl", "wb") as f:
pickle.dump((data, lectins, glycans), f)

34 changes: 31 additions & 3 deletions experiments/train_lgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from argparse import ArgumentParser
import time
Expand All @@ -19,15 +19,37 @@
from gifflar.train import setup
from gifflar.utils import read_yaml_config, hash_dict

import torch
# Startup diagnostic: report whether torch can see a CUDA device.
print(torch.cuda.is_available())

# Registry mapping a glycan-encoder name to the model class that implements it.
# NOTE(review): the keys match the `name` values used under `glycan_encoder` in
# the YAML configs; presumably the lookup happens inside `train` -- confirm.
GLYCAN_ENCODERS = {
"gifflar": GlycanGIN,
"sweetnet": SweetNetLightning,
}


def main(config):
kwargs = read_yaml_config(config)
def unfold_config(config: dict):
    """Expand a config with several encoder options into single-run configs.

    ``config["model"]["glycan_encoder"]`` and
    ``config["model"]["lectin_encoder"]`` may each be either one dict or a
    list of dicts. One configuration is produced for every
    (lectin encoder, glycan encoder) pair, iterating lectin encoders in the
    outer loop and glycan encoders in the inner loop.

    The input ``config`` is left untouched, and every yielded configuration
    is an independent deep copy, so a consumer may modify one run's config
    without affecting the caller's dict or any other run.

    Args:
        config: Parsed configuration containing a ``"model"`` dict with the
            keys ``"glycan_encoder"`` and ``"lectin_encoder"``.

    Yields:
        A copy of ``config`` in which both encoder entries hold exactly one
        encoder dict each.

    Raises:
        KeyError: If ``"model"`` or one of the encoder keys is missing
            (raised when the generator is first advanced).
    """
    # Imported locally so the function does not depend on a module-level
    # import that may not be present in the file.
    import copy

    # Work on a private copy: removing the encoder entries below must not
    # alter the dict owned by the caller.
    template = copy.deepcopy(config)
    model = template["model"]

    glycan_encoders = model.pop("glycan_encoder")
    if isinstance(glycan_encoders, dict):
        glycan_encoders = [glycan_encoders]

    lectin_encoders = model.pop("lectin_encoder")
    if isinstance(lectin_encoders, dict):
        lectin_encoders = [lectin_encoders]

    for lectin_encoder in lectin_encoders:
        for glycan_encoder in glycan_encoders:
            run_config = copy.deepcopy(template)
            # Copy the encoder dicts as well, so that runs do not share
            # (and cannot cross-contaminate) the same encoder settings object.
            run_config["model"]["lectin_encoder"] = copy.deepcopy(lectin_encoder)
            run_config["model"]["glycan_encoder"] = copy.deepcopy(glycan_encoder)
            yield run_config


def train(**kwargs):
kwargs["pre-transforms"] = {"GIFFLARTransform": "", "SweetNetTransform": ""}
kwargs["hash"] = hash_dict(kwargs["pre-transforms"])
seed_everything(kwargs["seed"])
Expand Down Expand Up @@ -63,6 +85,12 @@ def main(config):
print("Training took", time.time() - start, "s")


def main(config):
    """Launch one training run per encoder combination found in a YAML config.

    Args:
        config: Path to the YAML config file; it is handed to
            ``read_yaml_config`` for parsing.
    """
    # Every configuration produced by ``unfold_config`` is forwarded to
    # ``train`` as keyword arguments, one run after the other.
    for run_config in unfold_config(read_yaml_config(config)):
        train(**run_config)


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("config", type=str, help="Path to YAML config file")
Expand Down
2 changes: 1 addition & 1 deletion gifflar/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from pathlib import Path
from typing import Any
Expand Down
7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ rdkit>=2022
scikit-learn
numpy
pandas
glyles
git+https://github.com/BojarLab/glycowork.git@1123e1cee9f189ea5e82c6c1cda9764749b6da36
glyles>=1.0.0
git+https://github.com/BojarLab/glycowork.git@145877da82cf0200b062945dfb85584b9d9ef30d
jsonargparse
rich
pytorch-lightning
pytest
pyyaml
networkx
torchmetrics
transformers
sentencepiece
xformers==0.0.28.post1

0 comments on commit 6e85dee

Please sign in to comment.