Add eval and tables (#64)
* add 2qa

* save

* change prompt

* eval v2

* add tables

* add reviewer, prompt

* add reviews

* rename

* tables

* new line

* update

* update

* rename
infwinston authored Mar 29, 2023
1 parent b63fdd6 commit bb558d8
Showing 16 changed files with 1,016 additions and 2 deletions.
35 changes: 33 additions & 2 deletions chatserver/conversation.py
@@ -101,6 +101,37 @@ def dict(self):
sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
@@ -112,9 +143,9 @@ def dict(self):
)


-default_conversation = conv_v1
+default_conversation = conv_v1_2
conv_templates = {
-    "v1": conv_v1,
+    "v1": conv_v1_2,
"bair_v1": conv_bair_v1,
}

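For context, a sketch of how a template like conv_v1_2 is flattened into a prompt string. This is illustrative only (the real logic lives in Conversation.get_prompt); it assumes the usual SeparatorStyle.SINGLE layout, where the system message and each "role: message" turn are joined by sep:

# Illustrative sketch, not part of the commit: approximates the prompt that
# conv_v1_2.get_prompt() yields after eval.py (below) appends a new question.
system = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions.")
messages = [
    ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
    ("Assistant", "Renewable energy sources are those that can be replenished naturally ..."),
    ("Human", "Which source is cheapest to operate?"),  # hypothetical question appended at eval time
]

sep = "###"
prompt = system + sep
for role, message in messages:
    prompt += f"{role}: {message}{sep}"
print(prompt)

The model is then expected to continue this prompt with "Assistant: <answer>###", which is exactly what the answer-extraction slice in eval.py below relies on.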
67 changes: 67 additions & 0 deletions chatserver/eval/eval.py
@@ -0,0 +1,67 @@
import argparse
import json
import os

import ray
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from chatserver.conversation import default_conversation
from chatserver.utils import disable_torch_init


@ray.remote(num_gpus=1)
@torch.inference_mode()
def eval_model(model_name, questions_file, answers_file):
    # Load the model and tokenizer onto the GPU that Ray assigned to this task.
    disable_torch_init()
    model_name = os.path.expanduser(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name,
        torch_dtype=torch.float16).cuda()

    qa_file = open(os.path.expanduser(questions_file), "r")
    ans_file = open(os.path.expanduser(answers_file), "w")
    for line in tqdm(qa_file):
        # Parse each JSONL record once instead of three times.
        record = json.loads(line)
        idx = record["id"]
        qs = record["question"]
        cat = record["category"]
        conv = default_conversation.copy()
        conv.append_message(conv.roles[0], qs)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])
        output_ids = model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024)
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        try:
            index = outputs.index(conv.sep, len(prompt))
        except ValueError:
            # The model may stop before emitting the separator; append one so
            # the slice below always finds an end marker.
            outputs += conv.sep
            index = outputs.index(conv.sep, len(prompt))

        # Strip the prompt and the "Assistant: " role prefix from the decoded text.
        outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
        ans_file.write(json.dumps({"id": idx, "answer": outputs, "category": cat}) + "\n")
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--questions-file", type=str, default="mini_evals/qa.jsonl")
    parser.add_argument("--answers-file", type=str, default="answers.jsonl")
    args = parser.parse_args()

    ray.init()
    handle = []
    for i in range(1, 5):
        # str.replace returns a new string; the original code discarded the result.
        model_name = args.model_name.replace('~/', '')
        print(model_name)
        question_file = f'mini_evals/qa_v2-{i}.jsonl'
        answers_file = f'answers/v4/answers-v2-{i}.jsonl'
        handle.append(eval_model.remote(model_name, question_file, answers_file))

    results = ray.get(handle)
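For reference, the JSONL record shapes that eval.py reads and writes, inferred from the field accesses in eval_model; the sample values are illustrative:

import json

# One line of a questions file such as mini_evals/qa_v2-1.jsonl
# ("id", "question", and "category" are the keys eval_model reads).
question_record = {
    "id": 1,
    "question": "What are the key differences between renewable and non-renewable energy sources?",
    "category": "generic",  # hypothetical category name
}

# The corresponding line eval_model writes to the answers file.
answer_record = {
    "id": 1,
    "answer": "Renewable energy sources can be replenished naturally ...",
    "category": "generic",
}

print(json.dumps(question_record))
print(json.dumps(answer_record))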
105 changes: 105 additions & 0 deletions chatserver/eval/eval_gpt_review.py
@@ -0,0 +1,105 @@
import argparse
import json
import os
import time

import openai
import ray


@ray.remote(num_cpus=4)
def get_eval(content: str, max_tokens: int):
    try:
        response = openai.ChatCompletion.create(
            model='gpt-4',
            messages=[{
                'role': 'system',
                'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
            }, {
                'role': 'user',
                'content': content,
            }],
            temperature=0.2,  # TODO: figure out which temperature is best for evaluation
            max_tokens=max_tokens,
        )
    except Exception as e:
        print(e)
        return 'error'
    return response['choices'][0]['message']['content']


def parse_score(review):
    try:
        # The reviewer is expected to put two scores on the first line,
        # separated by a space or a comma.
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        sp = score_pair.split(' ')
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        print(e)
        print('error', review)
        return [-1, -1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    ray.init()

    f_q = open(os.path.expanduser(args.question))
    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

    review_file = open(f'{args.output}', 'w')

    js_list = []
    handles = []
    idx = 0
    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
        ques = json.loads(ques_js)['question']
        ans1 = json.loads(ans1_js)['answer']
        ans2 = json.loads(ans2_js)['answer']

        category = json.loads(ques_js)['category']
        if category in rule_dict:
            rule = rule_dict[category]
        else:
            rule = rule_dict['default']
        prompt = rule['prompt']
        role = rule['role']
        content = (f'[Question]\n{ques}\n\n'
                   f'[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n'
                   f'[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n'
                   f'[System]\n{prompt}\n\n')
        js_list.append({
            'id': idx + 1,
            'question': ques,
            'answer1': ans1,
            'answer2': ans2,
            'category': category})
        idx += 1
        handles.append(get_eval.remote(content, args.max_tokens))
        # To avoid the rate limit set by OpenAI
        time.sleep(10)

    reviews = ray.get(handles)
    for idx, review in enumerate(reviews):
        scores = parse_score(review)
        js_list[idx]['content'] = review
        js_list[idx]['tuple'] = scores
        review_file.write(json.dumps(js_list[idx]) + '\n')
    review_file.close()
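parse_score assumes the reviewer places its two scores on the first line of the review, separated by a space or a comma; any other shape falls back to [-1, -1]. A quick illustration, assuming parse_score from the script above is in scope:

# Expected first-line format: one score for each of the two answers.
print(parse_score("8 9\nAssistant 2's answer is more detailed ..."))   # [8.0, 9.0]
print(parse_score("8,9\nBoth answers are accurate and helpful ..."))   # [8.0, 9.0]
print(parse_score("The first answer is better."))                      # [-1, -1] (no score pair)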