-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_helpers.py
58 lines (49 loc) · 1.84 KB
/
test_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from graph_helpers import has_vertex
def check_tree_samples(qs, c, trees, every=1):
    """Assert that every tree sample agrees with the observed queries.

    Every ``every``-th query of ``qs`` is checked against each entry of
    ``trees``; the final query in ``qs`` is deliberately skipped. A query
    with ``c[q] >= 0`` must be contained in the tree, and any other query
    must be absent from it. Trees that are ``set`` instances are tested
    with ``in``; anything else is tested via ``has_vertex``.

    Raises:
        AssertionError: if a tree disagrees with a checked query.
    """
    for pos, query in enumerate(qs[:-1]):
        if pos % every != 0:
            continue
        for tree in trees:
            should_contain = c[query] >= 0
            if isinstance(tree, set):
                present = query in tree
            else:
                present = has_vertex(tree, query)
            if should_contain:
                assert present
            else:
                assert not present
def check_error_esitmator(qs, c, est, every=1):
    """Assert that the error estimator's matrix matches the observed queries.

    For every ``every``-th query ``q`` of ``qs`` (the final query is
    skipped), row ``q`` of ``est._m`` must sum to ``est.n_col`` when
    ``c[q] >= 0`` (infected) and to ``0`` otherwise (uninfected). Finally,
    ``est._m`` must have shape ``(est.n_row, est.n_col)``.

    Raises:
        AssertionError: if a checked row or the matrix shape is wrong.
    """
    for pos, query in enumerate(qs[:-1]):
        if pos % every:
            continue
        infected = c[query] >= 0
        row_total = est._m[query, :].sum()
        assert row_total == (est.n_col if infected else 0)
    assert (est.n_row, est.n_col) == est._m.shape
def check_samples_so_far(g, sampler, estimator, obs_inf, obs_uninf):
    """Assert that the sampler and estimator agree with all observations.

    Checks that:

    * ``sampler.samples`` holds exactly ``sampler.n_samples`` entries;
    * ``estimator._m`` has shape ``(g.num_vertices(), sampler.n_samples)``;
    * every node ``v`` in ``obs_inf`` is a member of every sample, and row
      ``v`` of ``estimator._m`` sums to ``estimator.n_col``;
    * every node ``v`` in ``obs_uninf`` is absent from every sample, and
      row ``v`` of ``estimator._m`` sums to ``0``.

    Samples visited by the membership checks must be ``set`` instances.

    Raises:
        AssertionError: if any of the above does not hold.
    """
    assert len(sampler.samples) == sampler.n_samples
    # Checked once and unconditionally, so it still runs when obs_inf is empty.
    assert estimator._m.shape == (g.num_vertices(), sampler.n_samples)

    for v in obs_inf:
        for t in sampler.samples:
            assert isinstance(t, set), 'should be set'
            assert v in t, 'should be in sample'
        assert estimator._m[v, :].sum() == estimator.n_col

    for v in obs_uninf:
        for t in sampler.samples:
            assert isinstance(t, set), 'should be set'
            # The message must describe the negated check.
            assert v not in t, 'should not be in sample'
        assert estimator._m[v, :].sum() == 0
def check_probas_so_far(probas, inf, uninf):
    """Assert that observed nodes carry degenerate probabilities.

    ``probas[v]`` must equal exactly ``1.0`` for every ``v`` in ``inf`` and
    exactly ``0.0`` for every ``v`` in ``uninf``. All of ``inf`` is checked
    before any of ``uninf``.

    Raises:
        AssertionError: on the first node whose probability differs.
    """
    for expected, nodes in ((1.0, inf), (0.0, uninf)):
        for node in nodes:
            assert probas[node] == expected