forked from anhquan0412/animation-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gradcam.py
171 lines (144 loc) · 6.61 KB
/
gradcam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from fastai.vision import *
from fastai.callbacks.hooks import *
import scipy.ndimage
class GradCam():
    '''
    Container for one input image plus the Grad-CAM heatmap and the
    guided-backprop map of one or two labels, with a helper to plot them.

    Build instances through `from_interp` (an item of an interpretation
    object's dataset) or `from_one_img` (a single image).
    The attributes `label2`, `prob2`, `hmap2` and `xb_grad2` only exist when
    a second label was supplied; `plot` relies on `hasattr` to detect that.
    '''

    @classmethod
    def from_interp(cls, learn, interp, img_idx, ds_type=DatasetType.Valid, include_label=False):
        '''
        Build a GradCam for item `img_idx` of the validation or test set of `interp`.

        learn: object whose model is handed to `get_grad_heatmap`
        interp: interpretation object providing `data`, `preds` and `pred_class`
        img_idx: index of the item inside the chosen dataset
        ds_type: DatasetType.Valid or DatasetType.Test; any other value returns None
        include_label: also produce maps for the actual label when it differs
            from the predicted one (always disabled for the test set)
        '''
        if ds_type == DatasetType.Valid:
            ds = interp.data.valid_ds
        elif ds_type == DatasetType.Test:
            ds = interp.data.test_ds
            # the actual label is never used for the test set
            include_label = False
        else:
            return None

        data = interp.data
        xb, _ = data.one_item(ds.x[img_idx])
        xb_img = Image(data.denorm(xb)[0])
        size = xb_img.shape[-1]

        probs = interp.preds[img_idx].numpy()
        # class idx of the predicted label
        pred_idx = interp.pred_class[img_idx].item()
        hmap_pred, xb_grad_pred = get_grad_heatmap(learn, xb, pred_idx, size=size)

        actual_args = None
        if include_label:
            # class idx of the actual label
            actual_idx = ds.y.items[img_idx]
            if actual_idx != pred_idx:
                hmap_actual, xb_grad_actual = get_grad_heatmap(learn, xb, actual_idx, size=size)
                actual_args = [data.classes[actual_idx], probs[actual_idx],
                               hmap_actual, xb_grad_actual]

        return cls(xb_img, data.classes[pred_idx], probs[pred_idx],
                   hmap_pred, xb_grad_pred, actual_args)

    @classmethod
    def from_one_img(cls, learn, x_img, label1=None, label2=None):
        '''
        Build a GradCam for a single image.

        learn: fastai's Learner
        x_img: fastai.vision.image.Image
        label1: label to explain; when falsy, the class returned by
            `learn.predict` is used
        label2: optional second label to explain as well
        '''
        pred_class, _, probs = learn.predict(x_img)
        if not label1:
            label1 = str(pred_class)

        xb, _ = learn.data.one_item(x_img)
        xb_img = Image(learn.data.denorm(xb)[0])
        size = xb_img.shape[-1]
        probs = probs.numpy()
        classes = learn.data.classes

        idx1 = classes.index(label1)
        hmap1, xb_grad1 = get_grad_heatmap(learn, xb, idx1, size=size)

        label2_args = None
        if label2:
            idx2 = classes.index(label2)
            hmap2, xb_grad2 = get_grad_heatmap(learn, xb, idx2, size=size)
            label2_args = [label2, probs[idx2], hmap2, xb_grad2]

        return cls(xb_img, label1, probs[idx1], hmap1, xb_grad1, label2_args)

    def __init__(self, xb_img, label1, prob1, hmap1, xb_grad1, label2_args=None):
        '''
        xb_img: image shown in the first panel of `plot`
        label1, prob1, hmap1, xb_grad1: name, probability, heatmap and
            guided-backprop map of the first label
        label2_args: optional [label, prob, hmap, xb_grad] for a second label
        '''
        self.xb_img = xb_img
        self.label1 = label1
        self.prob1 = prob1
        self.hmap1 = hmap1
        self.xb_grad1 = xb_grad1
        if label2_args:
            self.label2, self.prob2, self.hmap2, self.xb_grad2 = label2_args

    def _draw_label(self, axes, col, title, hmap, xb_grad, size, plot_hm, plot_gbp):
        '''Draw the requested panels of one label from column `col` on; return the next free column.'''
        if plot_hm:
            show_heatmap(hmap, self.xb_img, size, axes[col])
            axes[col].set_title(title)
            col += 1
        if plot_gbp:
            axes[col].imshow(xb_grad)
            axes[col].set_axis_off()
            axes[col].set_title(title)
            col += 1
        return col

    def plot(self, plot_hm=True, plot_gbp=True):
        '''
        Plot the image followed by the heatmap and/or guided-backprop panel of each label.

        plot_hm: draw the heatmap overlay panels
        plot_gbp: draw the guided-backprop panels
        When both flags are False the heatmap panels are drawn anyway.
        '''
        if not (plot_hm or plot_gbp):
            plot_hm = True

        two_labels = hasattr(self, 'label2')
        both_kinds = plot_hm and plot_gbp
        # one column for the image, then one column per label and per panel kind
        if two_labels:
            cols = 5 if both_kinds else 3
        else:
            cols = 3 if both_kinds else 2

        fig, row_axes = plt.subplots(1, cols, figsize=(cols*5, 5))
        size = self.xb_img.shape[-1]
        self.xb_img.show(row_axes[0])
        col = 1

        label1_title = f'1.{self.label1} {self.prob1:.3f}'
        col = self._draw_label(row_axes, col, label1_title, self.hmap1, self.xb_grad1,
                               size, plot_hm, plot_gbp)
        if two_labels:
            label2_title = f'2.{self.label2} {self.prob2:.3f}'
            self._draw_label(row_axes, col, label2_title, self.hmap2, self.xb_grad2,
                             size, plot_hm, plot_gbp)

        fig.subplots_adjust(wspace=0, hspace=0)
def minmax_norm(x):
    '''
    Linearly rescale `x` so that its minimum maps to 0 and its maximum to 1.

    x: numeric array
    Returns the rescaled array.
    When all elements of `x` are equal the range is zero and the plain
    formula would divide by zero (giving NaN); an all-zero array of the
    same shape is returned in that case.
    '''
    lo = np.min(x)
    span = np.max(x) - lo
    shifted = x - lo
    if span == 0:
        # constant input: multiplying by 0.0 yields zeros with the same
        # shape and the float dtype the division would have produced
        return shifted * 0.0
    return shifted / span
def scaleup(x, size):
    '''
    Upsample array `x` with `scipy.ndimage.zoom`.

    x: array to enlarge
    size: target length of the first axis
    The zoom factor is size / x.shape[0] and, being a scalar, is applied
    to every axis of `x`.
    '''
    factor = size / x.shape[0]
    return scipy.ndimage.zoom(x, factor)
# hook for Gradcam
def hooked_backward(m, xb, target_layer, clas):
    '''
    Run a forward pass of `m` on `xb` and backpropagate the score of class
    `clas` for the first item, while hooking `target_layer`.

    m: model, called as m(xb)
    xb: input batch
    target_layer: module whose output and gradient are hooked
    clas: class index (converted with int())
    Returns (hook_a, hook_g): the output hook and the gradient hook of
    `target_layer`.
    '''
    # hook_a captures the layer output (e.g. 512x7x7 after bn for a resnet34
    # body -- depends on the model), hook_g the gradient w.r.t. that layer
    with hook_output(target_layer) as hook_a, \
         hook_output(target_layer, grad=True) as hook_g:
        scores = m(xb)
        # backprop from a single logit, same as one-hot backprop
        scores[0, int(clas)].backward()
    return hook_a, hook_g
def clamp_gradients_hook(module, grad_in, grad_out):
    '''
    Backward hook that clamps negative gradient values to zero, in place.

    module: hooked module (unused)
    grad_in: iterable of gradient tensors; each one is modified in place
    grad_out: unused
    Returns None.
    Entries of `grad_in` that are None (which a backward hook may receive
    for inputs without a gradient) are skipped instead of being passed to
    `torch.clamp_`, which would raise on them.
    '''
    for grad in grad_in:
        if grad is None:
            continue
        torch.clamp_(grad, min=0.0)
# hook for guided backprop
def hooked_ReLU(m, xb, clas):
    '''
    Forward `xb` through `m` and backpropagate the score of class `clas`
    for the first item, with `clamp_gradients_hook` attached as backward
    hook to every ReLU module of `m`.

    m: model, called as m(xb)
    xb: input batch; its gradient is what the caller reads afterwards
    clas: class index (converted with int())
    Returns None.
    '''
    # Select ReLUs by type, not by their repr: the repr differs between
    # PyTorch versions ("ReLU(inplace)" vs "ReLU(inplace=True)") and
    # non-inplace ReLUs print as "ReLU()", so a string match can silently
    # hook nothing.
    relu_modules = [mod for mod in m.modules() if isinstance(mod, torch.nn.ReLU)]
    with callbacks.Hooks(relu_modules, clamp_gradients_hook, is_forward=False) as _:
        preds = m(xb)
        preds[0, int(clas)].backward()
def guided_backprop(learn, xb, y):
    '''
    Compute the guided-backprop gradient of class `y` w.r.t. the input.

    learn: object whose `model` attribute is the network (put in eval mode)
    xb: input batch; only the gradient of its first item is returned
    y: class index
    Returns the gradient of the first item as a numpy array.

    The batch is copied onto the device of the model's parameters instead
    of being sent unconditionally to CUDA, and the computation runs on a
    detached copy so the caller's tensor is left without `requires_grad`
    or a populated `.grad`.
    '''
    m = learn.model.eval()
    first_param = next(m.parameters(), None)
    # a model without parameters gives no device hint: stay where xb is
    device = xb.device if first_param is None else first_param.device
    # fresh leaf tensor: its .grad starts as None, nothing to zero out
    xb = xb.detach().clone().to(device).requires_grad_()
    hooked_ReLU(m, xb, y)
    return xb.grad[0].cpu().numpy()
def show_heatmap(hm, xb_im, size, ax=None):
    '''
    Show `xb_im` on an axis and overlay the heatmap `hm` on top of it.

    hm: 2d heatmap passed to `ax.imshow`
    xb_im: image object with a `show(ax)` method
    size: side length used for the overlay extent (0..size on both axes)
    ax: axis to draw on; a new one is created when None
    '''
    if ax is None:
        ax = plt.subplots()[1]
    xb_im.show(ax)
    # stretch the heatmap over the whole image, semi-transparent
    overlay_kwargs = dict(alpha=0.8, extent=(0, size, size, 0),
                          interpolation='bilinear', cmap='magma')
    ax.imshow(hm, **overlay_kwargs)
def get_grad_heatmap(learn, xb, y, size):
    '''
    Main function to get hmap for heatmap and xb_grad for guided backprop.

    learn: object whose `model` attribute is the network; the hooked layer
        is `model[0][-1][-1]` (last layer of group 0)
    xb: input batch
    y: class index to explain
    size: side length the heatmap is upsampled to before being combined
        with the guided-backprop gradient
    Returns (hmap, xb_grad):
        hmap: gradient-weighted sum of the hooked activations over the
            channel axis with negative values set to 0 (not upsampled,
            not normalized)
        xb_grad: min-max normalized guided-backprop gradient multiplied by
            the min-max normalized upsampled heatmap, channel axis last

    The batch is moved to the device of the model's parameters rather than
    unconditionally to CUDA.
    '''
    m = learn.model.eval()
    first_param = next(m.parameters(), None)
    if first_param is not None:
        xb = xb.to(first_param.device)

    target_layer = m[0][-1][-1]  # last layer of group 0
    hook_a, hook_g = hooked_backward(m, xb, target_layer, y)

    target_act = hook_a.stored[0].cpu().numpy()
    target_grad = hook_g.stored[0][0].cpu().numpy()

    # one weight per channel: the spatial mean of that channel's gradient
    mean_grad = target_grad.mean(1).mean(1)
    hmap = (target_act * mean_grad[..., None, None]).sum(0)
    # keep only positive evidence for the class
    hmap = np.where(hmap >= 0, hmap, 0)

    xb_grad = guided_backprop(learn, xb, y)  # e.g. (3,224,224)
    # minmax norm the grad
    xb_grad = minmax_norm(xb_grad)
    hmap_scaleup = minmax_norm(scaleup(hmap, size))  # e.g. (224,224)

    # multiply xb_grad and hmap_scaleup and move the channel axis last
    xb_grad = np.einsum('ijk, jk->jki', xb_grad, hmap_scaleup)  # e.g. (224,224,3)

    return hmap, xb_grad