-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmicrosoft_gpt2.py
141 lines (118 loc) · 4.45 KB
/
microsoft_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
from pprint import pprint
from typing import List
from tqdm import tqdm
from transformers import (
GPT2LMHeadModel,
GPT2Tokenizer,
)
import torch
from util import data_io, util_methods
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEFAULT_DEVICE = "cuda"
else:
    DEFAULT_DEVICE = "cpu"
print("USING: {}".format(DEFAULT_DEVICE))
def build_gpt2_input(utterances: List[str]):
    """Encode a dialogue history into a single DialoGPT input-id tensor.

    Every utterance is terminated with the tokenizer's EOS token and encoded
    on its own (``return_tensors="pt"``). The resulting id tensors are then
    joined along the last dimension and moved to ``DEFAULT_DEVICE``.

    Relies on the module-level ``tokenizer`` that is created in the
    ``__main__`` block of this script.

    Args:
        utterances: the dialogue turns, oldest first.

    Returns:
        The concatenated input ids on ``DEFAULT_DEVICE``.
    """
    eos = tokenizer.eos_token
    per_turn_ids = [
        tokenizer.encode(turn + eos, return_tensors="pt") for turn in utterances
    ]
    joined = torch.cat(per_turn_ids, dim=-1)
    return joined.to(DEFAULT_DEVICE)
def _strip_target_markers(target):
    """Remove the leading ``_go`` and trailing ``_eos`` markers from a target line.

    ``str.strip("_go")`` treats its argument as a *set of characters*. That
    mangles replies: "hello" becomes "hell", and "... go_eos" loses its "go".
    Here each marker is removed only as a whole prefix/suffix. Surrounding
    whitespace is trimmed first so that it cannot hide a marker; the text
    between the markers is returned unchanged.
    """
    text = target.strip()
    if text.startswith("_go"):
        text = text[len("_go"):]
    if text.endswith("_eos"):
        text = text[: -len("_eos")]
    return text


def topicalchat(
    file_name="train",
    data_path=None,
    limit=None,
):
    """Yield the last (up to) three turns of each processed Topical-Chat dialogue.

    Reads the parallel files ``<file_name>.fct``, ``<file_name>.src`` and
    ``<file_name>.tgt`` from *data_path* and walks them in lockstep. ``zip``
    stops at the shortest of the three. The dialogue context (``.src``) is
    split on ``_eos``, and the piece after the final ``_eos`` is discarded.
    The target response (``.tgt``), with its ``_go``/``_eos`` markers
    removed, is appended as the final turn.

    Args:
        file_name: basename shared by the three files, e.g. "train" or "valid_freq".
        data_path: directory containing the files. Defaults to
            ``$HOME/data/QA/topical-chat/processed_output``. It is resolved
            at call time, so importing this module does not require HOME.
        limit: forwarded to ``data_io.read_lines`` for each file.

    Yields:
        list[str]: the final three turns (fewer if the dialogue is shorter),
        the last of which is the target response.
    """
    if data_path is None:
        data_path = os.environ["HOME"] + "/data/QA/topical-chat/processed_output"
    base = os.path.join(data_path, file_name)
    # The .fct (background facts) file is read only so that it takes part in
    # the lockstep zip; its contents are not used in the yielded turns.
    backgrounds = data_io.read_lines(base + ".fct", limit=limit)
    dialogs = data_io.read_lines(base + ".src", limit=limit)
    targets = data_io.read_lines(base + ".tgt", limit=limit)
    for _background, dialog, target in tqdm(zip(backgrounds, dialogs, targets)):
        turns = dialog.split("_eos")[:-1] + [_strip_target_markers(target)]
        yield turns[-3:]
def dialogue_test():
    """Run a scripted three-turn chat against the loaded model and print each reply.

    Uses the module-level ``tokenizer`` and ``model`` created in the
    ``__main__`` block. After the first turn, each new user input is appended
    to the previous ``generate`` output (prompt plus reply) before the next
    reply is generated.

    Example transcript:
        User Does money buy happiness?
        Bot  Depends how much money you spend on it .
        User What is the best way to buy happiness ?
        Bot  You just have to be a millionaire by your early 20s, then you can be happy .
        User This is so difficult !
        Bot  You have no idea how hard it is to be a millionaire and happy . There is a reason the rich have a lot of money
    """
    user_inputs = [
        "Does money buy happiness?",
        "What is the best way to buy happiness ?",
        "This is so difficult !",
    ]
    chat_history_ids = None
    for user_input in user_inputs:
        # Encode the user turn terminated by EOS. Move it to the model's device:
        # the model lives on DEFAULT_DEVICE, and mixing CPU and CUDA tensors in
        # torch.cat / generate would fail.
        new_user_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token, return_tensors="pt"
        ).to(DEFAULT_DEVICE)
        # Append the new user turn to the running chat history (if any).
        if chat_history_ids is None:
            bot_input_ids = new_user_input_ids
        else:
            bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
        # Generate a reply; max_length caps prompt plus reply at 1000 tokens.
        with torch.no_grad():
            chat_history_ids = model.generate(
                bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
            )
        print(chat_history_ids.shape)
        # Everything after the prompt tokens is the bot's reply.
        reply_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
        print(
            "DialoGPT: {}".format(
                tokenizer.decode(reply_ids, skip_special_tokens=False)
            )
        )
"""
install transformers from 28th of April, cause microsofts GPT was committed this day
pip install git+https://github.com/huggingface/transformers.git@d714dfeaa8f019a634f2d565fc161f9b17fe85fb
"""
if __name__ == "__main__":
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-large")
model = (
GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-large")
.to(DEFAULT_DEVICE)
.eval()
)
def answer(input):
# print(input.shape)
with torch.no_grad():
chat_history_ids = model.generate(
input, max_length=1000, pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
temperature=0.7,
# num_beams=3
)
output = tokenizer.decode(chat_history_ids[:, input.shape[-1]:][0],
skip_special_tokens=True, )
# print("OUTPUT: %s"%output)
return output
file_name = "valid_freq"
dialogues_g = topicalchat(
file_name=file_name,
data_path=os.environ["HOME"]
+ "/Response-Generation-Baselines/processed_output",
limit=None
)
g = (answer(build_gpt2_input(utts)) for utts in dialogues_g)
data_io.write_lines("microsoft-gpt2-%s.pred"%file_name, g)
# dialogue_test()