IndexError in generate protein notebook #152
Comments
Hi, thanks for your interest in our models! :)
Hi @mheinzinger, thanks for your help. I'm interested in generating protein sequences; perhaps you could help me navigate my options. I started with the provided notebook because that's where users tend to go when they are trying to get something up and running. I didn't use ProtT5 because it is not among the provided examples. Since you highly recommend ProtT5, I propose we remove the XLNet generate script and replace it with one that both works and uses a recommended model. If you could help me modify the notebook, I'd be happy to make a PR. Does this sound like a plan to you? If so, here's what I've done so far:

import torch
-from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
+from transformers import T5Tokenizer, pipeline, T5ForConditionalGeneration
import re
import os
import requests
from tqdm.auto import tqdm
-tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)
+tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
-model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")
+# TypeError: T5ForConditionalGeneration.__init__() got an unexpected keyword argument 'do_lower_case'
+model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()
sequences_Example = "A E T C Z A O"
sequences_Example = re.sub(r"[UZOB]", "<unk>", sequences_Example)
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)
input_ids = torch.tensor(ids).unsqueeze(0).to(device)
max_length = 100
temperature = 1.0
k = 0
p = 0.9
repetition_penalty = 1.0
num_return_sequences = 3
output_ids = model.generate(
input_ids=input_ids,
max_length=max_length,
temperature=temperature,
top_k=k,
top_p=p,
repetition_penalty=repetition_penalty,
do_sample=True,
num_return_sequences=num_return_sequences,
)
output_sequences = [" ".join(" ".join(tokenizer.decode(output_id)).split()) for output_id in output_ids]
print('Generated Sequences\n')
for output_sequence in output_sequences:
print(output_sequence)
+# Output:
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C R A < / s >

Since the T5 model supports conditional generation, I think it would be nice to see a few examples for how that works.
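For reference, here is one way conditional generation with a T5-style checkpoint is often sketched: mask spans of the input with sentinel tokens (`<extra_id_0>`, `<extra_id_1>`, ...) and let the decoder fill the gaps. This is only a sketch under the assumption that the checkpoint follows T5's usual span-denoising setup; the `mask_spans` helper below is hypothetical (not part of ProtTrans), and the model call is shown commented out because it requires downloading the checkpoint.

```python
import re

def mask_spans(sequence: str, spans):
    """Replace the given (start, end) residue ranges with T5 sentinel tokens.

    `sequence` is a space-separated residue string (ProtT5-style input);
    `spans` are half-open index ranges over the residues.
    """
    residues = sequence.split()
    out, cursor = [], 0
    for i, (start, end) in enumerate(sorted(spans)):
        out.extend(residues[cursor:start])      # keep residues before the span
        out.append(f"<extra_id_{i}>")           # sentinel marks the hidden span
        cursor = end
    out.extend(residues[cursor:])               # keep the tail
    return " ".join(out)

seq = "A E T C Z A O"
seq = re.sub(r"[UZOB]", "X", seq)               # map rare residues to X
masked = mask_spans(seq, [(2, 4)])              # hide residues at positions 2..3
print(masked)                                   # -> A E <extra_id_0> X A X

# With the tokenizer/model from the snippet above, one would then try e.g.:
# input_ids = tokenizer(masked, return_tensors="pt").input_ids.to(device)
# filled = model.generate(input_ids, max_length=20, do_sample=True, top_p=0.9)
# print(tokenizer.decode(filled[0], skip_special_tokens=False))
```

The decoder's output should then contain the sentinels followed by the residues the model proposes for each hidden span.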
You are perfectly right; unfortunately, I have not yet found the time to update the examples.
Hello,
Thanks for the great tool. I'm excited to use ProtTrans for generating protein sequences, but I'm getting an IndexError in the example notebook (https://github.com/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb).
The error occurs when running cell 12:
Here's the full traceback:
This has occurred on two different sets of hardware.
Thanks for taking a look.
Evan