-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathestimate_survey.py
176 lines (149 loc) · 5.3 KB
/
estimate_survey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import argparse
import json
import os
import typing
import numpy as np
import pandas as pd
from tqdm import tqdm
from lm_survey.samplers import AutoSampler, BaseSampler
from lm_survey.survey import Survey
def estimate_survey_costs(
    sampler: BaseSampler,
    survey_name: str,
    *,
    n_samples_per_dependent_variable: typing.Optional[int] = None,
    n_top_mutual_info_dvs: typing.Optional[int] = None,
) -> typing.Dict[str, typing.Any]:
    """Estimate the cost of sending every prompt of one survey to ``sampler``.

    ``survey_name`` is used directly as the survey's directory path. That
    directory is read for ``independent-variables.json``,
    ``dependent-variables.json``, ``responses.csv`` and ``config.json``.

    Args:
        sampler: Prices prompts. ``batch_estimate_prompt_cost`` is used when
            the sampler has it; otherwise ``estimate_prompt_cost`` is called
            once per prompt.
        survey_name: Name of the survey; also the path of its directory.
        n_samples_per_dependent_variable: Forwarded to ``Survey.iterate``.
        n_top_mutual_info_dvs: If given, keep only the first this-many
            dependent variables from the survey's mutual-information stats.
            Those stats are cached in ``cached_mutual_info_stats.csv`` inside
            the survey directory and reused on later runs.

    Returns:
        A dict with ``"prompt_count"`` (number of prompts generated) and
        ``"cost"`` (``np.sum`` of the per-prompt estimates, in whatever unit
        the sampler reports).
    """
    # TODO(vinhowe): fix this -- the survey name doubles as its directory path.
    survey_directory = survey_name

    # JSON is UTF-8 by specification; don't rely on the platform's default
    # text encoding (e.g. cp1252 on Windows).
    with open(
        os.path.join(survey_directory, "independent-variables.json"),
        "r",
        encoding="utf-8",
    ) as file:
        independent_variable_names = json.load(file)
    with open(
        os.path.join(survey_directory, "dependent-variables.json"),
        "r",
        encoding="utf-8",
    ) as file:
        dependent_variable_names = json.load(file)

    data_filename = os.path.join(survey_directory, "responses.csv")
    config_filename = os.path.join(survey_directory, "config.json")

    def make_survey(dv_names):
        """Build a Survey over this directory restricted to ``dv_names``."""
        return Survey(
            name=survey_name,
            data_filename=data_filename,
            config_filename=config_filename,
            independent_variable_names=independent_variable_names,
            dependent_variable_names=dv_names,
        )

    if n_top_mutual_info_dvs is not None:
        cached_mutual_info_stats_filename = os.path.join(
            survey_directory, "cached_mutual_info_stats.csv"
        )
        if os.path.exists(cached_mutual_info_stats_filename):
            mutual_info_stats = pd.read_csv(
                cached_mutual_info_stats_filename, index_col=0
            )
        else:
            # Only build the full-DV survey when the stats must be computed.
            mutual_info_stats = make_survey(
                dependent_variable_names
            ).mutual_info_stats()
            mutual_info_stats.to_csv(cached_mutual_info_stats_filename)
        # The stats are assumed to be sorted already (most informative first),
        # so the leading index entries are the top DVs. Convert the pandas
        # Index to a plain list: an Index raises on truthiness checks and is
        # not the type json.load produces for the unrestricted case.
        dependent_variable_names = mutual_info_stats.index[
            :n_top_mutual_info_dvs
        ].tolist()

    survey = make_survey(dependent_variable_names)

    dependent_variable_samples = list(
        survey.iterate(
            n_samples_per_dependent_variable=n_samples_per_dependent_variable
        )
    )
    prompts = [sample.prompt for sample in dependent_variable_samples]

    if hasattr(sampler, "batch_estimate_prompt_cost"):
        completion_costs = sampler.batch_estimate_prompt_cost(prompts)
    else:
        completion_costs = [
            sampler.estimate_prompt_cost(prompt) for prompt in tqdm(prompts)
        ]

    return {
        "prompt_count": len(dependent_variable_samples),
        "cost": np.sum(completion_costs),
    }
def main(
    model_name: str,
    survey_names: typing.List[str],
    n_samples_per_dependent_variable: typing.Optional[int] = None,
    n_top_mutual_info_dvs: typing.Optional[int] = None,
) -> None:
    """Estimate and print what running ``model_name`` over each survey costs.

    One ``AutoSampler`` is created and shared by all surveys. When more than
    one survey name is given, a per-survey breakdown is printed before the
    overall total. Printed amounts are the estimated cost divided by 100
    (presumably cents to dollars -- confirm against the sampler's units).
    """
    sampler = AutoSampler(model_name=model_name)

    per_survey = {
        name: estimate_survey_costs(
            sampler=sampler,
            survey_name=name,
            n_samples_per_dependent_variable=n_samples_per_dependent_variable,
            n_top_mutual_info_dvs=n_top_mutual_info_dvs,
        )
        for name in tqdm(survey_names)
    }

    total_cost = sum(entry["cost"] for entry in per_survey.values())
    total_prompt_count = sum(
        entry["prompt_count"] for entry in per_survey.values()
    )

    if len(survey_names) > 1:
        print("Cost per survey:")
        for name, entry in per_survey.items():
            dollars = entry["cost"] / 100
            print(f"{name}: ${dollars:.2f} ({entry['prompt_count']} prompts)")

    print(f"Total cost: ${(total_cost / 100):.2f} ({total_prompt_count} prompts)")
if __name__ == "__main__":
    # Command-line entry point: price one or more survey directories for a model.
    cli = argparse.ArgumentParser()
    cli.add_argument("-m", "--model_name", type=str, required=True)
    cli.add_argument("-n", "--n_samples_per_dependent_variable", type=int)
    cli.add_argument("--n_top_mutual_info_dvs", type=int)
    # One or more survey directories, given positionally.
    cli.add_argument("survey_name", nargs="+", type=str)

    parsed = cli.parse_args()
    main(
        model_name=parsed.model_name,
        survey_names=parsed.survey_name,
        n_samples_per_dependent_variable=parsed.n_samples_per_dependent_variable,
        n_top_mutual_info_dvs=parsed.n_top_mutual_info_dvs,
    )