-
Notifications
You must be signed in to change notification settings - Fork 2
/
bayes_net.py
175 lines (134 loc) · 4.64 KB
/
bayes_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Classes to represent Bayesian networks.
"""
from toposort import toposort_flatten as flatten
from misc import dict_to_string
from scipy import rand, prod
class BayesNet(object):
    """
    Object to hold and evaluate probabilities
    for BayesNets described by cpts.

    Attributes:
        nodes: node names in topological order, without the None
            placeholder that marks the parents of root nodes.
        graph: dictionary of form {child: {parent1, parent2}, ...}.
        cpt: dictionary holding the (conditional) probability tables.
    """

    def __init__(self, graph, cpt):
        """
        graph: dictionary of form {child:{parent1, parent2},}.
            Expects {None} for the parents of root nodes.
        cpt: dictionary, holds conditional probabilities for the
            nodes of the network. In general, the form is expected to be:
            cpt = {child: prior_probs, child:{parent+value: probs,
                                              parent+value: probs}}
            Example:
                cpt = {"A":[0.2,0.8], B:{"A0":..., "A1":...}}
            The values next to the node name correspond to the values
            of the parent node.
        """
        # Root nodes list {None} as their parents, so the topological order
        # also contains the None placeholder. Drop it by identity instead of
        # assuming it always sits at position 0.
        self.nodes = [node for node in flatten(graph) if node is not None]
        self.graph = graph
        self.cpt = cpt

    def joint_prob(self, node_values):
        """
        Calculate the joint probability of an instantiation
        of the graph.

        Input:
            node_values: dictionary, assumed to be of form {node1:value1,
                node2:value2, ...}

        Returns the product of the prior of every root node and the
        conditional probability of every other node in node_values.
        """
        result = 1.0
        for node in node_values:
            if self.is_root_node(node):
                result *= self.prior(node, node_values[node])
            else:
                result *= self.cond_prob(node, node_values[node], node_values)
        return result

    def cond_prob(self, child, state, all_vals):
        """
        Evaluates the conditional probability
        P(child = state | all_vals)
        by looking up the values from the CPT table.

        Arguments:
            child: name of the child node
            state: value of the child node, used to index the table row
            all_vals: dictionary {node: value}; must contain a value for
                every parent of child (other entries are ignored)
        """
        # Parent values are cast to int so that True/False and 1/0 address
        # the same row; the row key itself is built by misc.dict_to_string.
        parents = {key: int(all_vals[key]) for key in self.graph[child]}
        key = dict_to_string(parents)
        return self.cpt[child][key][state]

    def prior(self, node, value):
        """
        Returns the prior P(node = value) of a root node.

        Returns None when node is not a root node.
        """
        if self.is_root_node(node):
            return self.cpt[node][value]
        return None

    def sample(self, set_nodes=None):
        """
        Generate single sample from BN.
        This only assumes binary variables.

        Arguments:
            set_nodes: optional dictionary {node: value} of nodes whose
                values are fixed in advance; it is copied, never modified.

        Returns a dictionary {node: value} holding the fixed nodes plus a
        0/1 draw for every other node in self.nodes.
        """
        # scipy.rand was only a deprecated alias of numpy.random.rand, so
        # import the real function here; NumPy seeding keeps working.
        from numpy.random import rand

        preset = {} if set_nodes is None else set_nodes
        sample = preset.copy()
        # self.nodes is topologically ordered, so every parent is already in
        # the sample by the time one of its children is drawn.
        for node in self.nodes:
            if node in preset:
                continue
            if self.is_root_node(node):
                p = self.prior(node, True)
            else:
                p = self.cond_prob(node, True, sample)
            sample[node] = int(rand() < p)
        return sample

    def msample(self, num_of_samples=100):
        """
        Generate multiple samples.

        Returns a list with num_of_samples independent calls to sample().
        """
        return [self.sample() for _ in range(num_of_samples)]

    def is_root_node(self, node):
        """
        Returns True when the parents of node are given as {None}.
        """
        return self.graph[node] == {None}
class BNNoisyORLeaky(BayesNet):
    """
    Bayesian network whose non-root nodes follow a leaky noisy-OR model.

    Instead of a full conditional probability table, every child node has
    one lambda per parent plus one for the leak node. The probability that
    the child is off is the leak lambda multiplied by the lambdas of all
    parents that are on.
    """

    def __init__(self, graph, lambdas, prior, clipper=False):
        """
        graph: dictionary of form {child:{parent1, parent2},}.
            Expects {None} for the parents of root nodes.
        lambdas: dictionary, contains the lambdas from each node.
            Format assumed to be {node:{leak_node:value, parent_1:value}}
        prior: dictionary {root_node: probs}, where probs[0] is the prior
            of the node being off and probs[1] of it being on.
        clipper: stored on the instance as given; not used by the methods
            of this class.
        """
        # Root nodes list {None} as their parents, so the topological order
        # also contains the None placeholder. Drop it by identity instead of
        # assuming it always sits at position 0.
        self.nodes = [node for node in flatten(graph) if node is not None]
        self.graph = graph
        self.lambdas = lambdas
        self.prior_dict = prior
        self.clipper = clipper

    def prior(self, node, node_value):
        """
        Returns the prior P(node = node_value) of a root node.

        node_value may be a bool or the ints 0/1.
        """
        # Compare by value, not identity: sampled states are stored as the
        # ints 0/1, and `1 is True` is False, which used to select the
        # "off" prior for a node that is on.
        index = 1 if node_value == 1 else 0
        return self.prior_dict[node][index]

    def cond_prob(self, child, value, all_vals):
        """
        Evaluates the conditional probability
        P(child=value|all_vals)
        for a noisy-or model.

        Arguments:
            child: name of child nodes
            value: boole (or 0/1), value of child node
            all_vals: dict, {node1:boole1, node2:boole2, ...}; must
                contain every parent listed in the lambdas of child
        """
        rel_lambdas = self.lambdas[child]
        # Start with P(child is off): the product over all active parents ...
        result = 1
        for key in rel_lambdas:
            if key != "leak_node":  # TODO need to make sure this is parsed correctly by the cpt
                if all_vals[key] == 1:
                    result *= rel_lambdas[key]
        # ... multiplied by the leak node, which is always active.
        result *= rel_lambdas["leak_node"]
        if value == 1:
            # P(child is on) is the complement.
            result = 1 - result
        return result