-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf-explain-visualize.py
77 lines (67 loc) · 2.58 KB
/
tf-explain-visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import random
from argparse import ArgumentParser
from os import path
import numpy as np
import pandas as pd
import tf_explain
from cv2 import imread, cvtColor, resize, COLOR_BGR2RGB
from progress.bar import Bar
from tensorflow.keras.models import load_model
from utils.constants import CLASSES, IMG_DIMENSIONS
# Command-line interface. Class indices are positions in utils.constants.CLASSES.
parser = ArgumentParser(description="Render a Grad-CAM grid for a trained model on dataset images.")
parser.add_argument("-m", '--model', required=True, help="path to the saved Keras model")
parser.add_argument("-c", "--classindex", default=0, type=int,
                    help="index into CLASSES of the class to explain (default: 0)")
parser.add_argument("-f", "--classfilter", default=-1, type=int,
                    help="only use images labelled CLASSES[i]; -1 disables filtering (default: -1)")
parser.add_argument("-l", "--limit", default=25, type=int,
                    help="number of images to visualize (default: 25)")
parser.add_argument('-r', '--random', action='store_true',
                    help="pick the images at random instead of taking the first matches")
parser.add_argument('-d', '--dataset', default="dataset",
                    help="directory holding metadata.csv and images/ (default: dataset)")
args = vars(parser.parse_args())
# Fail fast on values that would otherwise only blow up (IndexError on CLASSES,
# empty/oversized samples) after images are loaded and the model has run.
# Negative values keep their original meaning and are not rejected here.
if args['classindex'] >= len(CLASSES):
    parser.error(f"--classindex must be less than {len(CLASSES)}")
if args['classfilter'] >= len(CLASSES):
    parser.error(f"--classfilter must be less than {len(CLASSES)}")
if args['limit'] < 1:
    parser.error("--limit must be at least 1")
# Read only the columns needed to locate each image and derive its label.
# Builtin `str`/`bool` are used because the `np.str`/`np.bool` aliases were
# deprecated in NumPy 1.20 and removed in 1.24 (AttributeError there); the
# aliases always referred to these same builtins, so parsing is unchanged.
metadata = pd.read_csv(
    path.join(args['dataset'], 'metadata.csv'),
    usecols=['File', 'No Finding', 'Covid'],
    dtype={'File': str, 'No Finding': bool, 'Covid': bool},
)
# Load the images that pass the optional --classfilter, converted from OpenCV's
# BGR channel order to RGB and resized to IMG_DIMENSIONS. Without --random,
# loading stops once --limit images have been collected.
data = []
wanted_label = CLASSES[args['classfilter']] if args['classfilter'] > -1 else None
with Bar('Loading images', max=len(metadata)) as bar:
    # Columns are selected by name: read_csv(usecols=...) keeps the CSV's own
    # column order, so unpacking each row positionally could swap the flags.
    for file, no_finding, covid in zip(metadata['File'], metadata['No Finding'], metadata['Covid']):
        bar.next()
        if covid:
            label = CLASSES[0]  # covid
        elif no_finding:
            label = CLASSES[1]  # healthy
        else:
            label = CLASSES[2]  # other
        if wanted_label is not None and label != wanted_label:
            continue
        image_path = path.join(args['dataset'], 'images', file)
        image = imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; stop with a message naming the file.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cvtColor(image, COLOR_BGR2RGB)
        image = resize(image, IMG_DIMENSIONS)
        data.append(image)
        if not args['random'] and len(data) >= args['limit']:
            break
# With --random, every matching image was loaded; keep a random subset of
# --limit of them. random.sample raises ValueError when asked for more items
# than exist, so the sample size is capped at the number actually loaded.
if args['random']:
    sample_size = min(args['limit'], len(data))
    data = [data[i] for i in random.sample(range(len(data)), sample_size)]
# Nothing to explain if the filters matched no images; say so instead of
# failing deep inside tf_explain on an empty array.
if len(data) == 0:
    raise SystemExit("No images matched the requested filters; nothing to visualize.")
# Scale pixel values to [0, 1].
# NOTE(review): assumes the model was trained on inputs scaled the same way — confirm.
data = np.array(data) / 255.0
explainer = tf_explain.core.grad_cam.GradCAM()
# Passed as an (images, labels) pair with the label slot left as None.
data = (data, None)
model = load_model(args['model'])
# Grad-CAM target: the last top-level layer whose name contains "conv"
# (case-insensitive). Without one, there is no layer to explain against.
conv_layer_names = [layer.name for layer in model.layers if "conv" in layer.name.lower()]
if not conv_layer_names:
    raise SystemExit(f"No layer with 'conv' in its name found in model {args['model']}")
gradcam_layer = conv_layer_names[-1]
# Start explainer
grid = explainer.explain(data, model, class_index=args['classindex'], layer_name=gradcam_layer)
from datetime import datetime

# Output goes to a "visualized" directory next to the model file, named
# <HH_MM_SS>_class_<explained class>[_only_<filter class>].jpg.
current_time = datetime.now().strftime("%H_%M_%S")
saveloc = path.join(path.dirname(args['model']), 'visualized')
filename = f"{current_time}_class_{CLASSES[args['classindex']]}"
if args['classfilter'] > -1:
    filename += f"_only_{CLASSES[args['classfilter']]}"
filename += ".jpg"
try:
    explainer.save(grid, saveloc, filename)
except Exception as err:
    # Report why saving failed and exit non-zero rather than swallowing the
    # error; KeyboardInterrupt/SystemExit are deliberately not caught.
    raise SystemExit(
        f"Failed to save GradCam image to {path.join(saveloc, filename)}: {err}"
    ) from err
print(f"GradCam image created and saved to {saveloc}/{filename}")