-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
54 lines (45 loc) · 1.69 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import os
import torch.nn.functional as F
from components.model import GPT, GPTConfig, CPU_Unpickler
from components.tokenizer import Tokenizer
from argparse import ArgumentParser
# --- CLI arguments --------------------------------------------------------
parser = ArgumentParser()
# Name of the training-run directory under ./logs that holds log.pkl.
parser.add_argument("--weight", type=str, default="wikidata_ct1280")
# BUG FIX: the original used `type=bool`, but argparse passes the raw string
# to the type callable, so any non-empty value (even "--chat False") became
# True.  `store_true` gives a real boolean flag with the same default
# (False), so `args.chat` keeps its name and type for any caller.
parser.add_argument("--chat", action="store_true", default=False)
args, leftovers = parser.parse_known_args()

# --- Model / tokenizer setup ----------------------------------------------
model = GPT(GPTConfig)
tokenizer = Tokenizer()
# Checkpoint is a pickled training log; "best_weight" holds the state dict.
savepath = os.path.join(".", "logs", args.weight, "log.pkl")
with open(savepath, "rb") as filehandler:
    # CPU_Unpickler remaps tensors saved on GPU onto the CPU — presumably;
    # verify against components.model.
    model.load_state_dict(CPU_Unpickler(filehandler).load()["best_weight"])
model.eval()  # inference mode: disable dropout etc.
# --- Interactive chat loop -------------------------------------------------
# Note: the inner loop never breaks, so the outer loop's reset of `conv`
# runs only once; it is kept for the (disabled) system-prompt variant below.
while True:
    # Optional system prompt, kept for reference:
    # conv = (
    #     [76]
    #     + tokenizer.encode(
    #         "You are an AI assistant. You will be given a task. You must generate a detailed and long answer."
    #     )
    #     + [77]
    #     + tokenizer.encode("\n")
    # )
    conv = []  # running token-id history of the whole conversation
    while True:
        print("Input: ", end="")
        ques = input()
        ques = tokenizer.encode(ques)
        # 78/79 bracket the user turn and 80 opens the assistant turn —
        # assumed special token ids; confirm against components.tokenizer.
        ques = [78] + ques + [79] + tokenizer.encode("\n") + [80]
        # Append the question to the prior conversation so the model sees
        # the full history.
        conv = conv + ques
        conv_len = len(conv)
        ques = torch.tensor(conv, dtype=torch.long).unsqueeze(0)
        # FIX: run generation under no_grad.  The original built an autograd
        # graph for every generated token and only detached the final output,
        # so memory grew with each turn for no benefit during inference.
        with torch.no_grad():
            ans = model.generate(
                ques, max_new_tokens=256, temperature=0.5
            )[0].tolist()
        # Print only the newly generated tokens, not the echoed prompt.
        decoded_ans = "".join(tokenizer.decode(ans[conv_len:]))
        print(decoded_ans)
        # Record the answer (81 closes the assistant turn) as new context.
        conv = ans + [81] + tokenizer.encode("\n")