-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
161 lines (127 loc) · 4.88 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import numpy as np
import pandas as pd
import pickle as pkl
from graph_tool.all import load_graph
from glob import glob
from tqdm import tqdm
from utils import edges2graph
from sklearn.metrics import matthews_corrcoef
from feasibility import is_arborescence
def edge_order_accuracy(pred_edges, infection_times, return_count=False):
    """Measure how many predicted edges respect the infection-time order.

    An edge ``(u, v)`` counts as correct when
    ``infection_times[u] <= infection_times[v]``.

    Args:
        pred_edges: iterable of ``(u, v)`` pairs; any iterable is accepted,
            including generators, because it is consumed in a single pass.
        infection_times: object indexable by the edge endpoints.
        return_count: when true, return the raw counts instead of a ratio.

    Returns:
        ``(n_correct_edges, n_edges)`` if ``return_count`` is true, otherwise
        the fraction ``n_correct_edges / n_edges``.  An empty ``pred_edges``
        yields ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    n_edges = 0
    n_correct_edges = 0
    for u, v in pred_edges:
        n_edges += 1
        if infection_times[u] <= infection_times[v]:
            n_correct_edges += 1

    if return_count:
        return n_correct_edges, n_edges
    if n_edges == 0:
        # nothing to score: same fallback the caller uses for "no edges"
        return 0.0
    return n_correct_edges / n_edges
def matthew_cc(true_set, pred_set, n):
    """Return the Matthews correlation coefficient between two index sets.

    Each set is turned into a length-``n`` 0/1 indicator vector (1 at the
    positions listed in the set) and the two vectors are handed to
    ``matthews_corrcoef``.
    """
    def as_indicator(members):
        # 1.0 at every index contained in `members`, 0.0 elsewhere
        vec = np.zeros(n)
        vec[list(members)] = 1
        return vec

    truth = as_indicator(true_set)
    prediction = as_indicator(pred_set)
    return matthews_corrcoef(truth, prediction)
# @profile
def evaluate_performance(g, root, source, pred_edges, obs_nodes, infection_times,
                         true_edges):
    """Score a set of predicted edges against the true edges.

    Node precision and recall are computed on the endpoints of
    ``pred_edges`` / ``true_edges`` with the observed nodes removed.  Order
    accuracy is computed on the predicted edges whose two endpoints both
    appear in the true and the predicted node sets (observed nodes included).

    Args:
        g: graph handed to ``edges2graph`` to build the predicted tree.
        root: unused; kept so existing callers keep working.
        source: unused; kept so existing callers keep working.
        pred_edges: sized, re-iterable collection of ``(u, v)`` pairs.
        obs_nodes: iterable of observed nodes.
        infection_times: object indexable by node, used for order accuracy.
        true_edges: iterable of ``(u, v)`` pairs.

    Returns:
        Tuple ``(n_prec, n_rec, obj, order_accuracy)`` where ``obj`` is
        ``len(pred_edges)``.  Each ratio falls back to ``0.0`` when its
        denominator is empty.

    Raises:
        AssertionError: if ``is_arborescence`` rejects the graph built from
            ``pred_edges``.
    """
    obs_set = set(obs_nodes)

    true_nodes = {i for e in true_edges for i in e}
    pred_nodes = {i for e in pred_edges for i in e}
    # taken before the observations are stripped, so observed nodes still
    # take part in the order-accuracy computation below
    common_nodes = true_nodes.intersection(pred_nodes)

    # remove observations
    true_nodes -= obs_set
    pred_nodes -= obs_set

    correct_nodes = true_nodes.intersection(pred_nodes)

    # both ratios are guarded the same way: an empty denominator scores 0
    n_prec = len(correct_nodes) / len(pred_nodes) if pred_nodes else 0.0
    n_rec = len(correct_nodes) / len(true_nodes) if true_nodes else 0.0

    obj = len(pred_edges)

    pred_tree = edges2graph(g, pred_edges)
    # explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept because callers catch exactly that type
    if not is_arborescence(pred_tree):
        raise AssertionError('predicted edges do not form an arborescence')

    # order accuracy on edge
    edges = [e for e in pred_edges
             if e[0] in common_nodes and e[1] in common_nodes]
    if len(edges) > 0:
        order_accuracy = edge_order_accuracy(edges, infection_times)
    else:
        order_accuracy = 0.0

    return (n_prec, n_rec, obj, order_accuracy)
def evaluate_from_result_dir(g, result_dir, qs):
    """Evaluate every pickled result under ``result_dir`` for each ``q``.

    For each ``q`` in ``qs`` the files matching ``<result_dir>/<q>/*.pkl``
    are loaded and scored with ``evaluate_performance``.

    Args:
        g: graph forwarded to ``evaluate_performance``.
        result_dir: directory holding one sub-directory per ``q``.
        qs: iterable of ``q`` values; each is formatted into the path.

    Yields:
        One item per ``q``: ``(path, df)`` where ``path`` is
        ``<result_dir>/<q>.pkl`` and ``df`` holds one row of scores per
        result file, or ``None`` when no result file was found.  In the
        ``None`` case an existing file at ``path`` is deleted.

    Raises:
        AssertionError: re-raised from ``evaluate_performance`` after the
            offending file path has been printed.
    """
    for q in tqdm(qs):
        rows = []
        for p in glob(result_dir + "/{}/*.pkl".format(q)):
            # NOTE(security): unpickling can execute arbitrary code; only
            # point this at result files you produced yourself.
            # `with` closes the handle promptly instead of leaking it.
            with open(p, 'rb') as f:
                infection_times, source, obs_nodes, true_edges, pred_edges = pkl.load(f)

            # TODO: add root
            root = None
            try:
                scores = evaluate_performance(g, root, source, pred_edges, obs_nodes,
                                              infection_times, true_edges)
            except AssertionError as exc:
                # report which file failed before propagating the error
                print(p)
                print(type(exc))
                raise
            rows.append(scores)

        path = result_dir + "/{}.pkl".format(q)
        if rows:
            df = pd.DataFrame(rows, columns=['n.prec', 'n.rec',
                                             'obj',
                                             'order accuracy'])
            yield (path, df)
        else:
            # no results for this q: drop any summary left at this path
            if os.path.exists(path):
                os.remove(path)
            yield None
if __name__ == '__main__':
    # Command-line entry point: evaluate the results of one
    # (graph type, model, method) combination for the given q values and
    # pickle a `describe()` summary per q.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gtype', required=True)
    parser.add_argument('-l', '--model', required=True)
    parser.add_argument('-m', '--method', required=True)
    # required: without it `qs` would be None and the evaluation loop would
    # fail with a TypeError instead of a usage message
    parser.add_argument('-q', '--qs', type=float, nargs="+", required=True)
    parser.add_argument('-o', '--output_dir', default='outputs/paper_experiment')

    args = parser.parse_args()

    gtype = args.gtype
    qs = args.qs
    method = args.method
    model = args.model
    output_dir = args.output_dir

    print("""graph: {}
model: {}
qs: {}
method: {}""".format(gtype, model, qs, method))

    result_dir = "{output_dir}/{gtype}/{model}/{method}/qs".format(
        output_dir=output_dir,
        gtype=gtype,
        model=model,
        method=method)

    g = load_graph('data/{}/graph.gt'.format(gtype))

    for r in evaluate_from_result_dir(g, result_dir, qs):
        if r:
            path, df = r
            print('writing to {}'.format(path))
            # only the summary statistics are stored, not the raw rows
            df.describe().to_pickle(path)
        else:
            print('not result.')