-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
92 lines (64 loc) · 2.81 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import esm
import random
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pytorch_lightning as pl
from glob import glob
import numpy as np
import time
import pickle
import pandas as pd
from os import walk
from torch.utils.data import Dataset
class protein_dataset(Dataset):
    """Map-style Dataset backed by a CSV file; each item is one pandas row."""

    def __init__(self, path):
        # Load the entire CSV into memory once; rows are served by index.
        self.df = pd.read_csv(path)

    def __len__(self):
        # Number of rows in the underlying dataframe.
        return self.df.shape[0]

    def __getitem__(self, idx):
        # Return the raw pandas Series for row `idx`; the collator is
        # responsible for turning it into tensors.
        return self.df.iloc[idx]
class dataset_collator:
    """Collates a DataLoader batch (one dataframe row) into model inputs.

    Built for ``batch_size=1``: only ``raw_batch[0]`` is read.  Produces a
    dict with the tokenised peptide, per-residue binary labels, scaled
    per-residue energies, and the sample id.
    """

    def __init__(self):
        # ESM-1b alphabet supplies the tokenizer for peptide sequences.
        alphabet = esm.Alphabet.from_architecture('ESM-1b')
        self.batch_converter = alphabet.get_batch_converter()

    def __call__(self, raw_batch):
        row = raw_batch[0]
        batch = {}
        # Tokenise the peptide sequence; only the token tensor is kept.
        _, _, pep_token = self.batch_converter([(0, row['peptide_derived_sequence'])])
        batch['peptide_source'] = pep_token
        # 'mean_score' is a stringified array like "[-1.2  0.3 ...]":
        # strip the surrounding brackets and embedded newlines, then parse.
        # str.split() with no argument absorbs runs of whitespace, which the
        # original join/split round-trip was emulating.
        score_str = row['mean_score'][1:-1].replace('\n', '')
        energy_list = [float(escore) for escore in score_str.split()]
        # Positive (unfavourable) energies are clamped to zero.
        energy_list = [min(ele, 0.0) for ele in energy_list]
        scores = torch.FloatTensor(energy_list).unsqueeze(dim=1)
        # Binary label per residue: 1 where the energy is <= -1.
        batch['peptide_scores'] = (scores <= -1).type(torch.LongTensor)
        # Scaled raw energies, used as a regression target.
        batch['peptide_scores_value'] = scores / 100.0
        batch['id'] = row['new_id']
        return batch
class ProteinPeptideDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/val/test CSV splits."""

    def __init__(self, dataset_path):
        super().__init__()
        # Directory prefix (expected to end with a separator) that the
        # *_dataset_mean.csv filenames are appended to.
        self.dataset_path = dataset_path

    def setup(self, stage):
        # One dataset per split; a single collator instance is shared
        # across all three dataloaders.
        self.train_ds = protein_dataset(self.dataset_path + 'training_dataset_mean.csv')
        self.val_ds = protein_dataset(self.dataset_path + 'validation_dataset_mean.csv')
        self.test_ds = protein_dataset(self.dataset_path + 'testing_dataset_mean.csv')
        self.collator = dataset_collator()

    def _loader(self, ds, shuffle=False):
        # Shared DataLoader configuration; batch_size is fixed at 1
        # because the collator only reads raw_batch[0].
        return DataLoader(ds, batch_size=1, collate_fn=self.collator,
                          num_workers=8, drop_last=False,
                          pin_memory=True, shuffle=shuffle)

    def train_dataloader(self):
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_ds)

    def test_dataloader(self):
        return self._loader(self.test_ds)