
Minor run fixes
Old-Shatterhand committed Oct 24, 2024
1 parent 6e85dee commit b2dba11
Showing 3 changed files with 17 additions and 17 deletions.
12 changes: 6 additions & 6 deletions configs/lgi/all.yaml
@@ -12,13 +12,13 @@ model:
   hidden_dim: 1024
   num_layers: 8
   pooling: global_mean
-- name: sweetnet
-  feat_dim: 128
-  hidden_dim: 1024
-  num_layers: 16
+#- name: sweetnet
+#  feat_dim: 128
+#  hidden_dim: 1024
+#  num_layers: 16
 lectin_encoder:
-- name: ESM
-  layer_num: 33
+#- name: ESM
+#  layer_num: 33
 - name: Ankh
   layer_num: 48
 - name: ProtBert
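The config change above comments out the SweetNet glycan encoder and the ESM lectin encoder for this run; only the entries that stay uncommented (Ankh, ProtBert, and any further entries below the visible hunk) remain active. As a rough illustration of how such a config could drive the encoder classes changed in the next file, here is a hypothetical loading sketch; the name-to-class mapping, the import path, and the nesting of lectin_encoder under model are assumptions, not the repository's actual train_lgi.py logic.

# Hypothetical sketch: build the lectin encoders listed in configs/lgi/all.yaml.
# Assumes lectin_encoder sits under model and that experiments/ is importable;
# neither is confirmed by the diff.
import yaml

from experiments.protein_encoding import AMPLIFY, Ankh, ESM, ProstT5, ProtBERT

NAME_TO_CLASS = {
    "ESM": ESM,
    "Ankh": Ankh,
    "ProtBert": ProtBERT,  # config name and class name are spelled differently
    "ProstT5": ProstT5,
    "AMPLIFY": AMPLIFY,
}

with open("configs/lgi/all.yaml") as f:
    config = yaml.safe_load(f)

# Commented-out entries (e.g. "#- name: ESM") vanish when the YAML is parsed,
# so only the active encoders are instantiated.
lectin_encoders = [
    NAME_TO_CLASS[entry["name"]](layer_num=entry["layer_num"])
    for entry in config["model"]["lectin_encoder"]
]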
21 changes: 10 additions & 11 deletions experiments/protein_encoding.py
@@ -7,6 +7,7 @@
 class PLMEncoder:
     def __init__(self, layer_num: int):
         self.layer_num = layer_num
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"

     def forward(self, seq: str) -> torch.Tensor:
         pass
@@ -19,7 +20,7 @@ class Ankh(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
         self.tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base")
-        self.model = T5EncoderModel.from_pretrained("ElnaggarLab/ankh-base")
+        self.model = T5EncoderModel.from_pretrained("ElnaggarLab/ankh-base").to(self.device)

     def forward(self, seq: str) -> torch.Tensor:
         outputs = self.tokenizer.batch_encode_plus(
@@ -31,8 +32,8 @@ def forward(self, seq: str) -> torch.Tensor:
         )
         with torch.no_grad():
             ankh = self.model(
-                input_ids=outputs["input_ids"],
-                attention_mask=outputs["attention_mask"],
+                input_ids=outputs["input_ids"].to(self.device),
+                attention_mask=outputs["attention_mask"].to(self.device),
                 output_attentions=False,
                 output_hidden_states=True,
             )
@@ -42,7 +43,6 @@ class ESM(PLMEncoder):
 class ESM(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
         self.model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)

@@ -55,15 +55,14 @@ def forward(self, seq: str) -> torch.Tensor:
                 output_attentions=False,
                 output_hidden_states=True,
             )
-        print(seq)
         return esm.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]


 class ProtBERT(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
         self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
-        self.model = AutoModel.from_pretrained("Rostlab/prot_bert")
+        self.model = AutoModel.from_pretrained("Rostlab/prot_bert").to(self.device)

     def forward(self, seq: str) -> torch.Tensor:
         sequence_w_spaces = ' '.join(seq)
@@ -84,7 +83,7 @@ class ProstT5(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
         self.tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5")
-        self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5")
+        self.model = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(self.device)

     def forward(self, seq: str) -> torch.Tensor:
         seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in [seq]]
@@ -97,8 +96,8 @@ def forward(self, seq: str) -> torch.Tensor:
         )
         with torch.no_grad():
             prostt5 = self.model(
-                ids.input_ids,
-                attention_mask=ids.attention_mask,
+                ids.input_ids.to(self.device),
+                attention_mask=ids.attention_mask.to(self.device),
                 output_attentions=False,
                 output_hidden_states=True,
             )
@@ -109,10 +108,10 @@ class AMPLIFY(PLMEncoder):
     def __init__(self, layer_num: int):
         super().__init__(layer_num)
         self.tokenizer = AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True)
-        self.model = AutoModel.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True).to("cuda")
+        self.model = AutoModel.from_pretrained("chandar-lab/AMPLIFY_350M", trust_remote_code=True).to(self.device)

     def forward(self, seq: str) -> torch.Tensor:
-        a = self.tokenizer.encode(seq, return_tensors="pt").to("cuda")
+        a = self.tokenizer.encode(seq, return_tensors="pt").to(self.device)
         with torch.no_grad():
             amplify = self.model(
                 a,
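Every change in this file follows one pattern: the device is selected once in the PLMEncoder base class, and each subclass moves its pretrained model (at load time) and its tokenized inputs (in forward) onto that device instead of hard-coding "cuda". A minimal sketch of that pattern, using a generic Hugging Face encoder rather than any exact class from the file:

# Minimal sketch of the device-handling pattern applied in this commit.
# The model name and class layout are illustrative, not the repository's exact code.
import torch
from transformers import AutoModel, AutoTokenizer


class PLMEncoder:
    def __init__(self, layer_num: int):
        self.layer_num = layer_num
        # Choose the device once; subclasses reuse it for weights and inputs.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, seq: str) -> torch.Tensor:
        raise NotImplementedError


class GenericEncoder(PLMEncoder):
    def __init__(self, layer_num: int, model_name: str = "facebook/esm2_t33_650M_UR50D"):
        super().__init__(layer_num)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Move the weights to the shared device instead of a hard-coded "cuda".
        self.model = AutoModel.from_pretrained(model_name).to(self.device)

    def forward(self, seq: str) -> torch.Tensor:
        inputs = self.tokenizer(seq, return_tensors="pt")
        with torch.no_grad():
            out = self.model(
                input_ids=inputs["input_ids"].to(self.device),
                attention_mask=inputs["attention_mask"].to(self.device),
                output_hidden_states=True,
            )
        # Mean-pool the requested hidden layer over residues, dropping BOS/EOS,
        # mirroring the pooling used by the encoders in this file.
        return out.hidden_states[self.layer_num][:, 1:-1].mean(dim=1)[0]

Keeping the device decision in the base class means a new encoder only needs the two .to(self.device) calls shown here, and the same code runs unmodified on a CPU-only machine.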
1 change: 1 addition & 0 deletions experiments/train_lgi.py
@@ -3,6 +3,7 @@

 from argparse import ArgumentParser
 import time
+import copy

 from numpy.f2py.cfuncs import callbacks
 from pytorch_lightning import Trainer
