alert.py
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from matplotlib import pyplot as plt
from PIL import Image
from scipy.special import expit
from sklearn.metrics import (auc, confusion_matrix, f1_score, precision_score,
                             recall_score, roc_auc_score, roc_curve)
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils


class GanAlert(object):
    def __init__(self, discriminator, args, CONFIG=None, generator=None):
        self.args = args

        self.scores = []
        self.labels = []

        # kept for visualisation
        self.imgs = []
        self.targets = []

        self.discriminator = discriminator
        self.generator = generator
        self.CONFIG = CONFIG

        self.early_stop = CONFIG.early_stop if CONFIG is not None else 200

        # training set with batch size 1, used to estimate score mean/std;
        # only built when a CONFIG with a train_dataset is provided
        self.train_loader = DataLoader(
            CONFIG.train_dataset, batch_size=1, shuffle=False,
            num_workers=0, drop_last=False) if CONFIG is not None else None
    def collect(self, dataloader):
        # run the generator over (part of) the training set and return the
        # mean/std of the discriminator scores on the reconstructions
        assert self.generator is not None
        self.generator.eval()
        self.discriminator.eval()

        all_disc = []
        with torch.no_grad():
            for i, (img, label) in enumerate(dataloader):
                img = img.to(self.CONFIG.device)
                label = label.to(self.CONFIG.device)

                out = self.generator(img)
                fake_v = self.discriminator(out['recon'])

                all_disc += list(fake_v.detach().cpu().numpy())

                if i >= self.early_stop:
                    break

        # statistics used to normalise the anomaly scores
        return np.mean(all_disc), np.std(all_disc)
    def evaluate(self, scores, labels, collect=True):
        # optionally estimate mean/std of the discriminator scores on the training set
        if collect:
            mean, std = self.collect(self.train_loader)
        else:
            mean = 0.
            std = 1.

        results = self.alert(scores, labels, mean, std, print_info=False)
        return results
    def alert(self, scores, labels, mean=0., std=1., print_info=True):
        scores = np.array(scores)
        labels = np.array(labels)

        # standardise with the training-set statistics, then map to (0, 1);
        # higher score means more anomalous (label 1 is anomaly)
        scores = (scores - mean) / std
        scores = 1. - expit(scores)

        best_acc = -1
        best_t = 0

        fpr, tpr, thres = roc_curve(labels, scores)
        auc_score = auc(fpr, tpr) * 100.

        # sweep every ROC threshold and keep the one with the best accuracy
        for t in thres:
            prediction = np.zeros_like(scores)
            prediction[scores >= t] = 1

            # metrics
            f1 = f1_score(labels, prediction) * 100.
            acc = np.average(prediction == labels) * 100.
            recall = recall_score(labels, prediction) * 100.
            precision = precision_score(labels, prediction, labels=np.unique(prediction)) * 100.
            tn, fp, fn, tp = confusion_matrix(labels, prediction).ravel()
            specificity = (tn / (tn + fp)) * 100.

            if acc > best_acc:
                best_t = t
                best_acc = acc
                results = dict(threshold=t, auc=auc_score, acc=acc, f1=f1, recall=recall,
                               precision=precision, specificity=specificity)

            if print_info:
                print('threshold: %.2f, auc: %.2f, acc: %.2f, f1: %.2f, recall(sens): %.2f, prec: %.2f, spec: %.2f'
                      % (t, auc_score, acc, f1, recall, precision, specificity))

        if print_info:
            print('[BEST] threshold: %.2f, auc: %.2f, acc: %.2f, f1: %.2f, recall(sens): %.2f, prec: %.2f, spec: %.2f'
                  % (results['threshold'], results['auc'], results['acc'], results['f1'],
                     results['recall'], results['precision'], results['specificity']))

        return results
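

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original pipeline). It
# exercises just the thresholding logic in `GanAlert.alert` with synthetic
# scores/labels, so no generator, discriminator or CONFIG is required; the
# synthetic data and the `None` arguments below are assumptions made for
# this sketch, not the repository's documented entry point.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.RandomState(0)

    # 200 samples, roughly half anomalous (label 1); anomalies get *lower*
    # raw scores so that 1 - expit(score) is high for them after mapping
    labels = rng.randint(0, 2, size=200)
    raw_scores = -1.5 * labels + rng.randn(200)

    alerter = GanAlert(discriminator=None, args=None)  # CONFIG/generator omitted
    results = alerter.alert(list(raw_scores), list(labels), mean=0., std=1., print_info=False)
    print(results)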