evaluate.py
import argparse
from pathlib import Path
from textwrap import dedent
import pandas as pd
import numpy as np
from pprint import pprint
from utils.metrics import compute_metrics
from utils.results import load_results

parser = argparse.ArgumentParser(
    description="Evaluate model results on the validation set",
    formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
    "results",
    type=Path,
    help=dedent(
        """\
        Path to a results file with either a .pt suffix (loadable via `torch.load`) or
        with a .pkl suffix (loadable via `pickle.load(open(path, 'rb'))`). The
        loaded data should be in one of the following formats:
        [
            {
                'interaction_output': np.ndarray of float32, shape [44],
                'annotation_id': str, e.g. 'P01_101_1'
            }, ...  # repeated entries
        ]
        or
        {
            'interaction_output': np.ndarray of float32, shape [N, 44],
            'annotation_id': np.ndarray of str, shape [N,]
        }
        """
    ),
)
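# For reference, a results file in the second (stacked) format could be written
# roughly as follows. This is an illustrative sketch only; `outputs` and `ids`
# are hypothetical names for the per-clip model outputs and annotation ids:
#
#     torch.save(
#         {
#             "interaction_output": np.stack(outputs).astype(np.float32),  # [N, 44]
#             "annotation_id": np.array(ids),                              # [N,]
#         },
#         "results.pt",
#     )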
parser.add_argument("labels", type=Path, help="Labels (pickled dataframe)")
parser.add_argument("--redact_background", action="store_true", help="Whether to redact background labels in the evaluation")
def collate(results):
return {k: [r[k] for r in results] for k in results[0].keys()}


def main(args):
    labels: pd.DataFrame = pd.read_pickle(args.labels)
    if "annotation_id" in labels.columns:
        labels.set_index("annotation_id", inplace=True)
    if args.redact_background:
        redacted_ids = labels[labels["class"] == "background"].index.values
    labels = labels["class_id"]

    results = load_results(args.results)
    if results["interaction_output"].ndim == 3:  # Support multi-view evaluation
        results["interaction_output"] = np.sum(results["interaction_output"], axis=1)
    interaction_output = results["interaction_output"]
    annotation_ids = results["annotation_id"]
    if args.redact_background:
        keep_indices = [i for i, a_id in enumerate(annotation_ids) if a_id not in redacted_ids]
        interaction_output = results["interaction_output"][keep_indices]
        annotation_ids = results["annotation_id"][keep_indices]

    scores = {
        "interaction": interaction_output,
    }
    metrics = compute_metrics(
        labels.loc[annotation_ids],
        scores,
    )

    display_metrics = dict()
    # The first two task accuracies are reported as top-1 and top-5 accuracy.
    task_accuracies = metrics["accuracies"]["interaction"]
    for k, task_accuracy in zip((1, 5), task_accuracies):
        display_metrics[f"all_interaction_accuracy_at_{k}"] = task_accuracy
    display_metrics["all_interaction_mCA"] = metrics["mCA"]
    display_metrics["all_interaction_mAP"] = metrics["mAP"]
    # Accuracies, mCA and mAP are displayed as percentages; mAUC is left in [0, 1].
    display_metrics = {metric: value * 100 for metric, value in display_metrics.items()}
    display_metrics["all_interaction_mAUC"] = metrics["mAUC"]
    pprint(display_metrics)


if __name__ == "__main__":
    main(parser.parse_args())
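
# Example invocation (file names are illustrative):
#   python evaluate.py results.pt labels.pkl --redact_background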