-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
76 lines (60 loc) · 2.67 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import joblib
def summarize_trial(agents):
    """Partition agents into ``(correct, incorrect)`` lists.

    Each agent's ``is_correct()`` is evaluated exactly once. Every agent whose
    result is falsy goes into ``incorrect`` — including agents that have not
    finished — so that all non-correct agents are written to the trial log.
    Input order is preserved within each list.
    """
    buckets = {True: [], False: []}
    for agent in agents:
        buckets[bool(agent.is_correct())].append(agent)
    return buckets[True], buckets[False]
def remove_fewshot(prompt: str) -> str:
    """Return *prompt* with its few-shot example section removed.

    The examples are expected to sit between the start marker
    ``'下面是一些示例:'`` ("here are some examples:") and the end marker
    ``'(示例结束)'`` ("end of examples"), matching the Chinese prompts in
    promptsmed.py. The text before the start marker and the text after the
    end marker are each stripped of surrounding whitespace and joined with a
    single newline.

    Args:
        prompt: Full prompt text.

    Returns:
        The prompt without the example section. If either marker is missing
        (e.g. a zero-shot prompt), the prompt is returned unchanged instead of
        raising ``IndexError``.
    """
    head, start_found, rest = prompt.partition('下面是一些示例:')
    if not start_found:
        return prompt
    # Look for the end marker only after the start marker, and keep everything
    # after its first occurrence (split(...)[1] dropped text after a second one).
    _, end_found, tail = rest.partition('(示例结束)')
    if not end_found:
        return prompt
    # str.strip() already removes newlines along with other whitespace.
    return head.strip() + '\n' + tail.strip()
def log_trial(agents, trial_n):
    """Build a plain-text log of one trial, split into correct/incorrect sections.

    Agents are classified by ``summarize_trial``, which places every agent
    that is not correct into the INCORRECT section. Each entry is the agent's
    prompt with the few-shot examples removed, followed by ``agent.key``
    labelled as the correct answer.

    Args:
        agents: Agents exposing ``is_correct()``, ``_build_agent_prompt()``
            and ``key``.
        trial_n: Trial number shown in the header.

    Returns:
        The assembled log text.
    """
    correct, incorrect = summarize_trial(agents)

    def _entry(agent):
        # Drop the few-shot block so the log shows only the task-specific part.
        return remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n'

    header = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}
#######################################
"""
    # Collect the pieces and join once instead of growing a string with +=.
    parts = [header]
    for title, group in (
        ('------------- BEGIN CORRECT AGENTS -------------\n\n', correct),
        ('------------- BEGIN INCORRECT AGENTS -----------\n\n', incorrect),
    ):
        parts.append(title)
        parts.extend(_entry(agent) for agent in group)
    return ''.join(parts)
def summarize_react_trial(agents):
    """Classify ReAct agents into ``(correct, incorrect, halted)`` lists.

    - correct: ``is_correct()`` is truthy.
    - halted: ``is_halted()`` is truthy.
    - incorrect: ``is_finished()`` is truthy and ``is_correct()`` is falsy.

    Each category is computed in its own pass over *agents*. An agent matching
    none of the predicates is left out of all three lists.
    """
    def _select(keep):
        return [agent for agent in agents if keep(agent)]

    correct = _select(lambda agent: agent.is_correct())
    halted = _select(lambda agent: agent.is_halted())
    incorrect = _select(lambda agent: agent.is_finished() and not agent.is_correct())
    return correct, incorrect, halted
def log_react_trial(agents, trial_n):
    """Build a plain-text log of one ReAct trial, grouped by outcome.

    Agents are split by ``summarize_react_trial`` into correct, incorrect and
    halted lists, and each list gets its own section. Agents that land in none
    of those lists do not appear in the log. Each entry is the agent's prompt
    with the few-shot examples removed, followed by ``agent.key`` labelled as
    the correct answer.

    Args:
        agents: Agents exposing the methods used by ``summarize_react_trial``
            plus ``_build_agent_prompt()`` and ``key``.
        trial_n: Trial number shown in the header.

    Returns:
        The assembled log text.
    """
    correct, incorrect, halted = summarize_react_trial(agents)

    def _entry(agent):
        # Drop the few-shot block so the log shows only the task-specific part.
        return remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n'

    header = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}
#######################################
"""
    # Collect the pieces and join once instead of growing a string with +=.
    parts = [header]
    for title, group in (
        ('------------- BEGIN CORRECT AGENTS -------------\n\n', correct),
        ('------------- BEGIN INCORRECT AGENTS -----------\n\n', incorrect),
        ('------------- BEGIN HALTED AGENTS -----------\n\n', halted),
    ):
        parts.append(title)
        parts.extend(_entry(agent) for agent in group)
    return ''.join(parts)
def save_agents(agents, dir: str):
    """Serialize each agent with ``joblib.dump`` into *dir*, one file per agent.

    The directory (and any missing parents) is created first. Files are named
    after the agent's position in *agents*: ``0.joblib``, ``1.joblib``, ...

    The parameter keeps the name ``dir`` (shadowing the builtin) so existing
    keyword callers keep working.
    """
    os.makedirs(dir, exist_ok=True)
    for position, agent in enumerate(agents):
        target_path = os.path.join(dir, f'{position}.joblib')
        joblib.dump(agent, target_path)