-
Notifications
You must be signed in to change notification settings - Fork 111
/
evaluate_results.py
62 lines (50 loc) · 1.79 KB
/
evaluate_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import json
import numpy as np
from src.utils.evaluation_utils import compare_answer
# Evaluate a JSON results file and print accuracy, timing and token statistics.
#
# The file must contain a JSON object that maps an example id to a record dict
# with the following keys:
#   - "label": the correct answer (unnormalized)
#   - "answer": the answer given by the model (unnormalized)
#   - "time": time spent on the example (records without it are left out of
#     the timing statistics)
#   - "stats" (optional): when present, stats["total"]["input_tokens"] and
#     stats["total"]["output_tokens"] are averaged over the records that have it
# Answers are normalized and compared by ``compare_answer``.


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``file``, ``detail`` and ``k``.
    """
    argparser = argparse.ArgumentParser(description="Evaluate a JSON results file.")
    # --file has no usable default (open(None) raises TypeError), so let
    # argparse report a proper usage error when it is missing.
    argparser.add_argument(
        "--file", type=str, required=True, help="results file to evaluate"
    )
    argparser.add_argument("--detail", action="store_true", help="print wrong answers")
    argparser.add_argument(
        "--k", type=int, default=None, help="only evaluate the first K examples"
    )
    args = argparser.parse_args(argv)
    if args.k is not None and args.k < 0:
        argparser.error("--k must be non-negative")
    return args


def select_examples(results, k=None):
    """Return the first ``k`` ``(id, record)`` pairs of ``results`` (all if ``k`` is None).

    The slice is clamped to the number of available records, so the length of
    the returned list is the number of examples actually evaluated.
    """
    items = list(results.items())
    return items if k is None else items[:k]


def score_answers(examples, detail=False):
    """Return one correctness flag per example, as decided by ``compare_answer``.

    When ``detail`` is true, each wrong answer is printed as
    ``<index> <answer> <> <label>``.
    """
    is_corrects = []
    for i, (_, record) in enumerate(examples):
        is_correct = compare_answer(record["answer"], record["label"])
        if detail and not is_correct:
            print(i, record["answer"], "<>", record["label"])
        is_corrects.append(is_correct)
    return is_corrects


def collect_times(examples):
    """Return the ``"time"`` values of the examples that record one."""
    return [record["time"] for _, record in examples if "time" in record]


def average_token_usage(examples):
    """Average input/output token counts over the examples that carry ``"stats"``.

    Returns:
        ``(avg_input_tokens, avg_output_tokens)``, or ``None`` when no example
        has a ``"stats"`` entry.
    """
    # Skip (rather than stop at) records without stats, so one missing entry
    # does not silently drop every later record from the average.
    totals = [
        record["stats"]["total"] for _, record in examples if "stats" in record
    ]
    if not totals:
        return None
    count = len(totals)
    input_tokens = sum(total["input_tokens"] for total in totals)
    output_tokens = sum(total["output_tokens"] for total in totals)
    return input_tokens / count, output_tokens / count


def main(argv=None):
    """Run the evaluation described by the command-line options in ``argv``."""
    args = parse_args(argv)
    with open(args.file, "r", encoding="utf-8") as f:
        results = json.load(f)
    examples = select_examples(results, args.k)

    num_correct = sum(score_answers(examples, args.detail))
    # Divide by the number of examples actually evaluated: --k may exceed the
    # number of records in the file, and --k 0 / an empty file must not crash.
    num_examples = len(examples)
    print("Results")
    if num_examples:
        print(f"Raw: {num_correct} / {num_examples} = {num_correct / num_examples}")
    else:
        print("Raw: no examples to evaluate")

    # compute mean and std of times (skipped when empty: numpy would warn and
    # print nan)
    all_times = collect_times(examples)
    if all_times:
        print(f"Mean time: {np.mean(all_times)}")
        print(f"Std time: {np.std(all_times)}")

    averages = average_token_usage(examples)
    if averages is not None:
        avg_input, avg_output = averages
        print(f"Average input tokens: {avg_input}")
        print(f"Average output tokens: {avg_output}")


if __name__ == "__main__":
    main()