-
Notifications
You must be signed in to change notification settings - Fork 4
/
plot.py
103 lines (83 loc) · 3.85 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob2
import seaborn as sns
def find_pareto_points(obtained_scores, threshold=0.02):
    """Return the approximate Pareto frontier of *obtained_scores*, ordered for plotting.

    A row is dropped when some other row beats it in every column by more than
    ``threshold`` times that column's range (max - min); larger values are
    treated as better. Only clearly dominated points are therefore removed.
    The surviving rows are sorted by the first column. They are then re-ordered
    by a greedy nearest-neighbour walk starting at the row with the smallest
    first-column value, so that drawing them as a line gives a continuous curve.

    Args:
        obtained_scores: Array-like of shape ``(n, d)``; lists are accepted.
        threshold: Fraction of each column's range by which a point must be
            beaten in every column to count as dominated.

    Returns:
        The frontier rows as an array in walk order. Inputs with at most one
        row (including empty input) are returned unchanged.
    """
    # np.max/np.min below have no identity for zero-size input, and a single
    # row is trivially its own frontier.
    if len(obtained_scores) <= 1:
        return obtained_scores
    scores = np.asarray(obtained_scores)

    margin = threshold * (np.max(scores, axis=0) - np.min(scores, axis=0))
    # gaps[i, j] = scores[j] - scores[i] - margin. Row i is dominated when
    # some row j has every gap strictly positive.
    gaps = scores[np.newaxis, :, :] - scores[:, np.newaxis, :] - margin
    dominated = np.all(gaps > 0.0, axis=2).any(axis=1)
    points = scores[~dominated]
    if len(points) == 0:
        # Only reachable with a negative threshold; nothing to order.
        return points

    points = points[np.argsort(points[:, 0])]
    print(points)

    # Greedy nearest-neighbour walk: always step to the closest unvisited
    # point (squared Euclidean distance, first index wins ties).
    order = [0]
    remaining = np.ones(len(points), dtype=bool)
    remaining[0] = False
    current = 0
    while remaining.any():
        candidates = np.flatnonzero(remaining)
        dist = ((points[candidates] - points[current]) ** 2).sum(axis=1)
        current = candidates[np.argmin(dist)]
        order.append(current)
        remaining[current] = False
    return points[np.array(order)]
# Module-level plotting state used by plot_points: `colors` is seaborn's
# 'Paired' palette, and `index` selects the palette entry for the next
# series (plot_points advances it after drawing each series).
colors = sns.color_palette('Paired')
index = 1
def _parse_pref(path):
    """Return the float after 'eval_data_pref' and before the next '_' in *path*, or None if absent."""
    if 'eval_data_pref' not in path:
        return None
    return float(path.split('eval_data_pref')[-1].strip().split('_')[0])


def plot_points(dir, label, style='-*', color='b', shift=None, txt_color='black', normalize_path=None, reverse=True):
    """Scatter every evaluation result under *dir* and draw its Pareto frontier.

    Each ``*.csv`` file directly in *dir*, or one directory level below it,
    contributes one point: the mean of its ``obtained_score1`` and
    ``obtained_score2`` columns. Points that come from an
    ``eval_data_pref<w>_...`` file are annotated with ``w`` rounded to one
    decimal. The frontier from ``find_pareto_points`` is drawn as a labelled
    line. Advances the module-level ``index`` by 2 so the next series gets a
    different palette colour.

    Args:
        dir: Directory searched for result CSV files.
        label: Legend label for the frontier line.
        style: Matplotlib format string for the frontier line; its last
            character is also used as the scatter marker.
        color: Unused; kept for backward compatibility (colours come from
            the module-level ``colors`` palette).
        shift: ``[dx, dy]`` offset applied to the preference annotations.
            Defaults to ``[0, 0]``.
        txt_color: Colour of the preference annotations.
        normalize_path: Optional file readable by ``np.load`` holding four
            numbers, reshaped to ``(2, 2)``. Score column ``i`` is mapped to
            ``(score - row[i][0]) / row[i][1]``.
        reverse: Unused; kept for backward compatibility.

    Raises:
        FileNotFoundError: If no CSV file is found under *dir*.
    """
    global index
    if shift is None:
        shift = [0, 0]

    paths = [os.path.abspath(p) for p in glob2.glob(os.path.join(dir, '*.csv'))]
    paths += [os.path.abspath(p) for p in glob2.glob(os.path.join(dir, '*', '*.csv'))]
    if not paths:
        # Without this, the empty score array fails later with an opaque error.
        raise FileNotFoundError('no CSV files found under {!r}'.format(dir))

    threshold = 0.01
    obtained_scores = []
    prefs = []  # one entry per score row; None when the file name has no preference
    for path in paths:
        data = pd.read_csv(path)
        # MORLHF has fewer points; a larger threshold makes its frontier look better.
        # NOTE(review): 'ppo' is matched against the absolute path, so any
        # directory whose name contains 'ppo' also triggers this; confirm intended.
        if 'ppo' in path and len(paths) <= 5:
            threshold = 0.5
        obtained_scores.append([np.mean(data['obtained_score1']), np.mean(data['obtained_score2'])])
        prefs.append(_parse_pref(path))
    print([p for p in prefs if p is not None])

    obtained_scores = np.array(obtained_scores)
    if normalize_path is not None:
        norm_info = np.array(np.load(normalize_path)).reshape(2, 2)
        obtained_scores = (obtained_scores - norm_info[:, 0]) / norm_info[:, 1]

    # Wrap around so more series than palette entries do not raise IndexError.
    series_color = colors[index % len(colors)]
    markersize = 10 if ('*' in style or 'o' in style) else 9
    pareto_points = find_pareto_points(obtained_scores, threshold)
    plt.scatter(obtained_scores[:, 0], obtained_scores[:, 1], marker=style[-1], color=series_color, s=markersize + 60)
    # Annotate each point with its own preference; rows without one stay unlabelled.
    for (x, y), pref in zip(obtained_scores, prefs):
        if pref is not None:
            plt.annotate('{}'.format(round(pref, 1)), (x + shift[0], y + shift[1]), size=4, color=txt_color)
    plt.plot(pareto_points[:, 0], pareto_points[:, 1], style, c=series_color, markersize=markersize, label=label)
    index += 2
plt.figure(figsize=(5, 4))
name1 = 'harmless'
name2 = 'helpful'
# Replace the evaluation directories below with your own paths. Each call
# adds one method's scattered scores and frontier line to the current figure.
plot_points('./logs_trl/eval_pretrained', 'Llama 2 base', '*')
plot_points('./logs_trl/eval_sft_alldata', 'SFT', '*')
plot_points('./eval_ppo_pref/', 'MORLHF', '--D', shift=[-0.012, -0.022])
plot_points('./logs_ppo/eval_pposoups_llamma2_klreg0.2', 'Rewarded Soups', style='--s', shift=[-0.012, -0.022])
plot_points('./logs_trl/evalnew_onlinefix_helpful_harmlesshelpful_iter2', 'RiC', style='-o', shift=[-0.012, -0.022], txt_color='white')
plt.xlabel('$R_1$ ({})'.format(name1), fontsize=12)
plt.ylabel('$R_2$ ({})'.format(name2), fontsize=12)
plt.legend(fontsize=11, loc='lower left')
plt.tight_layout()
plt.savefig('ric_assistant_{}_{}.pdf'.format(name1, name2))