real_cascade_experiment.py
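"""Cascade-reconstruction experiment on real Digg cascades.

For a given cascade and reporting probability q, repeatedly sample a set
of observed nodes, reconstruct the infection tree via ``get_tree``, and
dump each run's result to disk; with ``--evaluate``, load the per-run
results and aggregate node- and edge-accuracy counts.
"""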
import os
import numpy as np
import pandas as pd
import pickle as pkl

from tqdm import tqdm
from graph_tool import load_graph
from graph_tool.topology import label_largest_component
from joblib import Parallel, delayed

from paper_experiment import get_tree
from cascade import observe_cascade
from gt_utils import extract_edges
from utils import cascade_size, get_last_modified_date
from evaluate import edge_order_accuracy

def one_run(g, q, infection_times, method, result_dir, i, verbose):
    """Run one trial: sample observed nodes with report probability q,
    reconstruct the tree, and dump (obs, pred_edges) to {result_dir}/{i}.pkl."""
    result_path = result_dir + '/{}.pkl'.format(i)
    try:
        dt = get_last_modified_date(result_path)
        if dt.year == 2018:
            # result already computed; we don't recalculate
            print('skipping {}'.format(result_path))
            return
    except FileNotFoundError:
        pass

    obs = observe_cascade(infection_times, source=None, q=q)
    tree = get_tree(g, infection_times, source=None, obs_nodes=obs,
                    method=method, verbose=verbose)
    pred_edges = extract_edges(tree)
    with open(result_path, 'wb') as f:
        pkl.dump((obs, pred_edges), f)

def run_k_runs(g, q, infection_times, method,
               k, result_dir,
               verbose=False):
    """Run k independent trials in parallel (one process per core)."""
    Parallel(n_jobs=-1)(
        delayed(one_run)(g, q, infection_times, method, result_dir, i, verbose)
        for i in tqdm(range(k), total=k))

def evaluate(pred_edges, infection_times, obs):
    """Count correctly recovered hidden nodes and correctly ordered edges."""
    pred_nodes = {i for e in pred_edges for i in e} - set(obs)
    true_nodes = set(np.nonzero(infection_times >= 0)[0]) - set(obs)
    correct_nodes = pred_nodes.intersection(true_nodes)
    # prec = len(correct_nodes) / len(pred_nodes)
    # rec = len(correct_nodes) / len(true_nodes)
    n_correct_edges, n_pred_edges = edge_order_accuracy(
        pred_edges, infection_times, return_count=True)
    return (len(correct_nodes), len(pred_nodes), len(true_nodes),
            n_correct_edges, n_pred_edges)

def evaluate_from_result_dir(result_dir, infection_times, k):
    """Score the k per-run results; return (summary_path, DataFrame),
    or None if no result file was found."""
    rows = []
    paths = [result_dir + "/{}.pkl".format(i) for i in range(k)]
    for p in paths:
        print(p)
        # TODO: add root
        try:
            with open(p, 'rb') as f:
                obs, pred_edges = pkl.load(f)
            scores = evaluate(pred_edges, infection_times, obs)
            rows.append(scores)
        except FileNotFoundError:
            print(p, 'not found')

    path = result_dir + ".pkl"  # {q}.pkl
    if rows:
        df = pd.DataFrame(rows, columns=['n.correct_nodes',
                                         'n.pred_nodes',
                                         'n.true_nodes',
                                         'n.correct_edges',
                                         'n.pred_edges'])
        return (path, df)
    else:
        # remove a stale summary if previous runs left one behind
        if os.path.exists(path):
            os.remove(path)
        return None

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--cascade_id', required=True)
    parser.add_argument('-m', '--method', required=True)
    parser.add_argument('-q', '--report_proba', type=float, default=0.1)
    parser.add_argument('-k', '--repeat_times', type=int, default=100)
    parser.add_argument('-o', '--output_dir', default='output/real_cascade')
    parser.add_argument('-s', '--small_cascade', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    # type=bool would treat any non-empty string as True; use a flag instead
    parser.add_argument('--evaluate', action='store_true')
    args = parser.parse_args()

    g = load_graph('data/digg/graph.gt')

    if args.small_cascade:
        cascade_path = 'data/digg/small_cascade_{}.pkl'.format(args.cascade_id)
    else:
        cascade_path = 'data/digg/cascade_{}.pkl'.format(args.cascade_id)
    print('cascade_path:', cascade_path)

    with open(cascade_path, 'rb') as f:
        infection_times = pkl.load(f)
    print('cascade size:', cascade_size(infection_times))
    q = args.report_proba
    k = args.repeat_times
    method = args.method
    output_dir = args.output_dir

    result_dir = os.path.join(output_dir, method, "{}".format(q))
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    if not args.evaluate:
        print('run experiment... q={}, method={}, cascade: {}, cascade size: {}'.format(
            q, method, args.cascade_id, cascade_size(infection_times)))
        print(g)
        # size of the largest connected component
        print(sum(label_largest_component(g).a))
        run_k_runs(g, q, infection_times, method, k, result_dir,
                   verbose=args.verbose)
    else:
        print('evaluate...')
        result = evaluate_from_result_dir(result_dir,
                                          infection_times=infection_times,
                                          k=k)
        if result is None:
            print('nothing to evaluate in {}'.format(result_dir))
        else:
            path, df = result
            print('writing to {}'.format(path))
            if args.small_cascade:
                df.to_pickle(path)
            else:
                summary = df.describe()
                print(summary)
                summary.to_pickle(path)
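
# Example invocation (the cascade id is illustrative; valid method names
# depend on what paper_experiment.get_tree accepts):
#   python real_cascade_experiment.py -i 1234 -m <method> -q 0.1 -k 100
#   python real_cascade_experiment.py -i 1234 -m <method> --evaluate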