-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathroot_sampler.py
100 lines (80 loc) · 2.7 KB
/
root_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
from graph_helpers import pagerank_scores
from graph_tool.topology import shortest_distance
from tqdm import tqdm
def build_root_sampler_by_pagerank_score(g, obs, c, weights=None, eps=0.0):
    """Build a root sampler that draws indices according to PageRank scores.

    The scores are obtained once from
    ``pagerank_scores(g, obs, eps, weights)`` and then reused, unchanged, as
    the probability vector of every draw.

    Args:
        g: graph handed to ``pagerank_scores``.
        obs: observed nodes handed to ``pagerank_scores``.
        c: accepted for signature compatibility; not used by this builder.
        weights: optional edge weights handed to ``pagerank_scores``.
        eps: value forwarded to ``pagerank_scores`` as its third argument.

    Returns:
        A zero-argument callable; each call returns one randomly drawn index
        in ``range(len(scores))``.
    """
    scores = pagerank_scores(g, obs, eps, weights)
    # One candidate index per score entry (presumably one per vertex --
    # confirm against ``pagerank_scores``).
    candidates = np.arange(len(scores))

    def draw_root():
        picked = np.random.choice(candidates, size=1, p=scores)
        return picked[0]

    return draw_root
def build_true_root_sampler(c):
    """Build a sampler that always returns the first index where ``c == 0``.

    Args:
        c: array compared element-wise against 0 (presumably infection
            times, with the true source at time 0 -- confirm with callers).

    Returns:
        A zero-argument callable returning that index on every call.

    Raises:
        IndexError: if no entry of ``c`` equals 0.
    """
    # Indices (along the first axis) of all entries equal to zero.
    zero_positions = np.nonzero(c == 0)[0]
    true_root = zero_positions[0]

    def constant_root():
        return true_root

    return constant_root
def build_min_dist_sampler(g, obs, weights=None, log=False):
    """Build a sampler that always returns the minimum-distance root.

    Every vertex of ``g`` that is not in ``obs`` is a candidate. For each
    candidate the shortest distances to all nodes in ``obs`` are summed, and
    the candidate with the smallest sum is chosen (the first one wins ties).

    Args:
        g: graph exposing ``vertices()`` and ``num_vertices()`` and accepted
            by ``shortest_distance``.
        obs: iterable of observed vertex indices; these are excluded from
            the candidates and used as distance targets.
        weights: optional edge weights forwarded to ``shortest_distance``.
        log: when true, wrap the vertex iteration in a ``tqdm`` progress bar.

    Returns:
        A zero-argument callable that returns the selected vertex index.

    Raises:
        AssertionError: if no candidate could be selected (for example when
            every vertex is observed).
    """
    observed = set(obs)
    # Start from infinity instead of a magic constant: distance sums can
    # legitimately exceed any fixed number (unreachable targets are reported
    # with very large values), which previously left every candidate
    # rejected.
    min_dist = float('inf')
    best_root = None

    iters = g.vertices()
    if log:
        iters = tqdm(iters, total=g.num_vertices())

    for v in iters:
        v = int(v)
        if v in observed:
            continue
        dist = shortest_distance(
            g,
            source=v,
            target=obs,
            weights=weights,
            pred_map=False)
        total = dist.sum()  # computed once, reused for compare and update
        if total < min_dist:
            min_dist = total
            best_root = v

    # Raised explicitly (same exception type as the former ``assert``) so the
    # check is not stripped when Python runs with -O.
    if best_root is None:
        raise AssertionError('no candidate root found outside `obs`')

    def aux():
        return best_root

    return aux
def build_out_degree_root_sampler(g, power=2):
    """Build a sampler that draws vertices proportionally to out-degree**power.

    Args:
        g: graph exposing ``degree_property_map('out').a`` (array of
            out-degrees) and ``num_vertices()``.
        power: exponent applied to each out-degree before normalisation.
            Larger values favour high-degree vertices more strongly; ``0``
            gives a uniform distribution.

    Returns:
        A zero-argument callable; each call returns one randomly drawn
        vertex index in ``range(g.num_vertices())``.

    Note:
        If the powered degrees sum to zero (e.g. a graph without edges) the
        probability vector is invalid and sampling will fail.
    """
    # Work in float so fractional or negative exponents are well defined.
    out_deg = np.asarray(g.degree_property_map('out').a, dtype=float)
    # Honour ``power``; the exponent used to be hard-coded to 2.
    powered = np.power(out_deg, power)
    out_deg_norm = powered / powered.sum()

    def aux():
        return np.random.choice(g.num_vertices(), p=out_deg_norm)

    return aux
def get_value_or_raise(d, key):
    """Remove ``key`` from ``d`` and return its value.

    Args:
        d: dictionary to take the value from; it is mutated in place.
        key: key that must be present in ``d``.

    Returns:
        The value that was stored under ``key``.

    Raises:
        KeyError: if ``key`` is not in ``d``; the message names the missing
            key.
    """
    if key not in d:
        # Report the missing key (the message used to embed the whole dict).
        raise KeyError('`{}` is not there'.format(key))
    return d.pop(key)
def get_root_sampler_by_name(name, **kwargs):
    """Return the root sampler registered under ``name``.

    Supported names and the keyword arguments each one requires:

    * ``None``: no sampler; returns ``None``.
    * ``'true_root'``: requires ``c``.
    * ``'pagerank'``: requires ``g``, ``obs``, ``c`` and ``weights``; any
      remaining keyword arguments (e.g. ``eps``) are forwarded to
      ``build_root_sampler_by_pagerank_score``.
    * ``'min_dist'``: requires ``g``, ``obs`` and ``weights``.

    Args:
        name: sampler name, or ``None``.
        **kwargs: inputs for the selected builder, as listed above.

    Returns:
        A zero-argument sampler callable, or ``None`` when ``name`` is
        ``None``.

    Raises:
        KeyError: if a required keyword argument is missing.
        AssertionError: if ``name == 'true_root'`` and ``c`` is ``None``.
        ValueError: if ``name`` is not one of the supported names.
    """
    if name is None:
        return None

    if name == 'true_root':
        c = get_value_or_raise(kwargs, 'c')
        assert c is not None, 'cascade `c` should be given'
        return build_true_root_sampler(c)

    if name == 'pagerank':
        g = get_value_or_raise(kwargs, 'g')
        obs = get_value_or_raise(kwargs, 'obs')
        c = get_value_or_raise(kwargs, 'c')
        weights = get_value_or_raise(kwargs, 'weights')
        # The required keys were popped above; whatever is left is passed on.
        return build_root_sampler_by_pagerank_score(
            g, obs, c, weights=weights,
            **kwargs)

    if name == 'min_dist':
        g = get_value_or_raise(kwargs, 'g')
        obs = get_value_or_raise(kwargs, 'obs')
        weights = get_value_or_raise(kwargs, 'weights')
        return build_min_dist_sampler(g, obs, weights=weights)

    # The message used to read "valid name" and was passed as a tuple.
    raise ValueError('invalid name: {!r}'.format(name))