-
Notifications
You must be signed in to change notification settings - Fork 54
/
dagmm.py
239 lines (193 loc) · 8.03 KB
/
dagmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet
from dagmm.gmm import GMM
from os import makedirs
from os.path import exists, join
class DAGMM:
    """ Deep Autoencoding Gaussian Mixture Model.

    This implementation is based on the paper:
    Bo Zong+ (2018) Deep Autoencoding Gaussian Mixture Model
    for Unsupervised Anomaly Detection, ICLR 2018
    (this is UNOFFICIAL implementation)
    """

    # Base file names used by save() / restore() inside the model directory.
    MODEL_FILENAME = "DAGMM_model"
    SCALER_FILENAME = "DAGMM_scaler"

    def __init__(self, comp_hiddens, comp_activation,
                 est_hiddens, est_activation, est_dropout_ratio=0.5,
                 minibatch_size=1024, epoch_size=100,
                 learning_rate=0.0001, lambda1=0.1, lambda2=0.0001,
                 normalize=True, random_seed=123):
        """
        Parameters
        ----------
        comp_hiddens : list of int
            sizes of hidden layers of compression network
            For example, if the sizes are [n1, n2],
            structure of compression network is:
            input_size -> n1 -> n2 -> n1 -> input_size
        comp_activation : function
            activation function of compression network
        est_hiddens : list of int
            sizes of hidden layers of estimation network.
            The last element of this list is assigned as n_comp.
            For example, if the sizes are [n1, n2],
            structure of estimation network is:
            input_size -> n1 -> n2 (= n_comp)
        est_activation : function
            activation function of estimation network
        est_dropout_ratio : float (optional)
            dropout ratio of estimation network applied during training
            if 0 or None, dropout is not applied.
        minibatch_size: int (optional)
            mini batch size during training
        epoch_size : int (optional)
            epoch size during training
        learning_rate : float (optional)
            learning rate during training
        lambda1 : float (optional)
            a parameter of loss function (for energy term)
        lambda2 : float (optional)
            a parameter of loss function
            (for sum of diagonal elements of covariance)
        normalize : bool (optional)
            specify whether input data need to be normalized.
            by default, input data is normalized.
        random_seed : int (optional)
            random seed used when fit() is called.
        """
        # Set the session-related attributes first so that __del__ stays
        # safe even if building one of the sub-networks below raises.
        self.graph = None
        self.sess = None
        self.scaler = None

        self.comp_net = CompressionNet(comp_hiddens, comp_activation)
        self.est_net = EstimationNet(est_hiddens, est_activation)
        self.est_dropout_ratio = est_dropout_ratio

        # The width of the last estimation layer is the number of
        # mixture components handed to the GMM.
        n_comp = est_hiddens[-1]
        self.gmm = GMM(n_comp)

        self.minibatch_size = minibatch_size
        self.epoch_size = epoch_size
        self.learning_rate = learning_rate
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        self.normalize = normalize
        self.seed = random_seed

    def __del__(self):
        # Release the TensorFlow session (if one was ever created).
        self._close_session()

    def _close_session(self):
        """ Close the currently held session, if any, and forget it.

        Safe to call on a partially constructed instance and safe to
        call repeatedly.
        """
        sess = getattr(self, "sess", None)
        if sess is not None:
            sess.close()
            self.sess = None

    def fit(self, x):
        """ Fit the DAGMM model according to the given data.

        Calling fit() again discards the previously trained session and
        trains a new model in a fresh graph.

        Parameters
        ----------
        x : array-like, shape (n_samples, n_features)
            Training data.
        """
        # Accept any array-like (list, DataFrame, ...): the code below
        # relies on ``.shape`` and on integer-array row indexing.
        x = np.asarray(x)
        n_samples, n_features = x.shape

        if self.normalize:
            self.scaler = scaler = StandardScaler()
            x = scaler.fit_transform(x)

        # The docstring allows None to mean "no dropout"; feed 0.0 in that
        # case so the scalar float placeholder always receives a number.
        dropout_ratio = self.est_dropout_ratio or 0.0

        with tf.Graph().as_default() as graph:
            self.graph = graph
            tf.set_random_seed(self.seed)
            np.random.seed(seed=self.seed)

            # Create Placeholder
            self.input = x_input = tf.placeholder(
                dtype=tf.float32, shape=[None, n_features])
            self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])

            # Build graph
            z, x_dash = self.comp_net.inference(x_input)
            gamma = self.est_net.inference(z, drop)
            self.gmm.fit(z, gamma)
            energy = self.gmm.energy(z)

            self.x_dash = x_dash

            # Loss function
            loss = (self.comp_net.reconstruction_error(x_input, x_dash) +
                    self.lambda1 * tf.reduce_mean(energy) +
                    self.lambda2 * self.gmm.cov_diag_loss())

            # Minimizer
            minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)

            # Number of batch (ceiling division)
            n_batch = (n_samples - 1) // self.minibatch_size + 1

            # Create tensorflow session and initialize.
            # Close a session left over from an earlier fit()/restore()
            # first, otherwise it would be leaked when overwritten.
            init = tf.global_variables_initializer()
            self._close_session()
            self.sess = tf.Session(graph=graph)
            self.sess.run(init)

            # Training
            # NOTE: the sample order is shuffled once, before the first
            # epoch, and the same order is reused for every epoch.
            idx = np.arange(x.shape[0])
            np.random.shuffle(idx)

            for epoch in range(self.epoch_size):
                for batch in range(n_batch):
                    i_start = batch * self.minibatch_size
                    i_end = (batch + 1) * self.minibatch_size
                    x_batch = x[idx[i_start:i_end]]

                    self.sess.run(minimizer, feed_dict={
                        x_input: x_batch, drop: dropout_ratio})

                # Report the loss on the full training set every 100 epochs.
                if (epoch + 1) % 100 == 0:
                    loss_val = self.sess.run(
                        loss, feed_dict={x_input: x, drop: 0})
                    print(" epoch {}/{} : loss = {:.3f}".format(epoch + 1, self.epoch_size, loss_val))

            # Fix GMM parameter (evaluated on the whole training set,
            # with the dropout placeholder fed as 0)
            fix = self.gmm.fix_op()
            self.sess.run(fix, feed_dict={x_input: x, drop: 0})
            self.energy = self.gmm.energy(z)

            # Register the tensors restore() needs to look up again.
            tf.add_to_collection("save", self.input)
            tf.add_to_collection("save", self.energy)

            self.saver = tf.train.Saver()

    def predict(self, x):
        """ Calculate anomaly scores (sample energy) on samples in X.

        Parameters
        ----------
        x : array-like, shape (n_samples, n_features)
            Data for which anomaly scores are calculated.
            n_features must be equal to n_features of the fitted data.

        Returns
        -------
        energies : array-like, shape (n_samples)
            Calculated sample energies.

        Raises
        ------
        Exception
            If neither fit() nor restore() has produced a session yet.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if self.normalize:
            # Apply the same scaling that was fitted on the training data.
            x = self.scaler.transform(x)

        energies = self.sess.run(self.energy, feed_dict={self.input: x})
        return energies

    def save(self, fdir):
        """ Save trained model to designated directory.

        This method have to be called after training.
        (If not, throw an exception)

        Parameters
        ----------
        fdir : str
            Path of directory trained model is saved.
            If not exists, it is created automatically.

        Raises
        ------
        Exception
            If no trained model (session) exists.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if not exists(fdir):
            makedirs(fdir)

        model_path = join(fdir, self.MODEL_FILENAME)
        self.saver.save(self.sess, model_path)

        # The scaler only exists (and is only needed) when normalizing.
        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            joblib.dump(self.scaler, scaler_path)

    def restore(self, fdir):
        """ Restore trained model from designated directory.

        Any session currently held by this instance is closed and
        replaced by the restored one.

        Parameters
        ----------
        fdir : str
            Path of directory trained model is saved.

        Raises
        ------
        Exception
            If the directory does not exist.
        """
        if not exists(fdir):
            raise Exception("Model directory does not exist.")

        model_path = join(fdir, self.MODEL_FILENAME)
        meta_path = model_path + ".meta"

        with tf.Graph().as_default() as graph:
            self.graph = graph
            # Do not leak a session from an earlier fit()/restore().
            self._close_session()
            self.sess = tf.Session(graph=graph)
            self.saver = tf.train.import_meta_graph(meta_path)
            self.saver.restore(self.sess, model_path)

            # fit() stored exactly two tensors in the "save" collection,
            # in this order: input placeholder, energy tensor.
            self.input, self.energy = tf.get_collection("save")

        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            self.scaler = joblib.load(scaler_path)