-
Notifications
You must be signed in to change notification settings - Fork 0
/
P_Pruner.py
136 lines (93 loc) · 4.79 KB
/
P_Pruner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import torch.nn as nn
import copy
import numpy as np
class Pruner():
    """Structured pruning helper: physically removes all-zero conv filters.

    Operates on models whose layers are direct children (reachable via
    ``named_children``). A conv filter is prunable when *every* weight in it
    is exactly zero (e.g. after magnitude pruning zeroed it out). Removing a
    filter shrinks that conv's output channels and the matching input
    channels of the next conv (or the first fully-connected layer), keeping
    the network shape-consistent.
    """

    def __init__(self):
        super().__init__()

    def _nonzero_filter_ixs(self, conv):
        """Return a 1-D LongTensor (on `conv`'s device) of filter indices
        whose weights are not all zero.

        Uses the sum of absolute values: it is zero iff every weight in the
        filter is zero. (A plain sum could cancel to zero for a non-zero
        filter and wrongly mark it prunable.)
        """
        flat = conv.weight.data.view(conv.out_channels, -1)
        return torch.nonzero(flat.abs().sum(dim=1) != 0).flatten()

    def filters_to_keep(self, layer, nxt_layer):
        """Gather the surviving tensors for `layer` and (optionally) `nxt_layer`.

        Returns a tuple ``(filters_keep, biases_keep, nxt_filters_keep)``:
        the non-zero filters of `layer`, the matching biases (``None`` when
        `layer` has no bias), and `nxt_layer`'s weights restricted to the
        surviving input channels (``None`` when `nxt_layer` is None).
        """
        ixs = self._nonzero_filter_ixs(layer)
        filters_keep = layer.weight.data.index_select(0, ixs)
        biases_keep = None if layer.bias is None else layer.bias.data.index_select(0, ixs)
        if nxt_layer is not None:
            # Drop the *input* channels of the following conv that were fed
            # by the removed filters (dim 1 = in_channels).
            nxt_filters_keep = nxt_layer.weight.data.index_select(1, ixs)
        else:
            nxt_filters_keep = None
        return filters_keep, biases_keep, nxt_filters_keep

    def prune_conv(self, layer, nxt_layer):
        """Prune all-zero filters out of conv `layer`, mutating it in place.

        When `nxt_layer` is given, its input channels are narrowed to match.
        Returns the (mutated) ``(layer, nxt_layer)`` pair.
        """
        assert layer.__class__.__name__ == 'Conv2d'
        new_weights, new_biases, new_next_weights = self.filters_to_keep(layer, nxt_layer)
        layer.out_channels = new_weights.shape[0]
        layer.in_channels = new_weights.shape[1]
        layer.weight = nn.Parameter(new_weights)
        if new_biases is not None:
            layer.bias = nn.Parameter(new_biases)
        if new_next_weights is not None:
            nxt_layer.weight = nn.Parameter(new_next_weights)
            nxt_layer.in_channels = new_next_weights.shape[1]
        return layer, nxt_layer

    def delete_fc_weights(self, layer, last_conv):
        """Narrow fc `layer`'s input features to `last_conv`'s surviving filters.

        Mutates and returns `layer`.

        NOTE(review): this maps one fc input feature per conv channel, so it
        assumes the conv output is spatially 1x1 (e.g. globally pooled)
        before the fc — confirm against the model's forward pass.
        """
        ixs = self._nonzero_filter_ixs(last_conv)
        weights_keep = layer.weight.data.index_select(1, ixs)
        layer.in_features = weights_keep.shape[1]
        layer.weight = nn.Parameter(weights_keep)
        return layer

    def _find_next_conv(self, model, conv_ix):
        """Index of the first Conv2d child after position `conv_ix`, or None."""
        for k, m in enumerate(model.children()):
            if k > conv_ix and isinstance(m, nn.Conv2d):
                return k
        return None

    def _get_last_conv_ix(self, model):
        """Index of the last Conv2d child (0 when the model has none)."""
        last_conv_ix = 0
        for k, m in enumerate(model.children()):
            if isinstance(m, nn.Conv2d):
                last_conv_ix = k
        return last_conv_ix

    def _get_first_fc_ix(self, model):
        """Index of the first Linear child (0 when the model has none)."""
        for k, m in enumerate(model.children()):
            if isinstance(m, nn.Linear):
                return k
        return 0

    def prune_model(self, model):
        """Return a deep copy of `model` with all-zero conv filters removed.

        The input `model` is left untouched. The last conv layer itself is
        deliberately NOT pruned (there is no following conv to repair);
        instead the first fully-connected layer is narrowed to the last
        conv's surviving filters so the head stays consistent.
        """
        pruned_model = copy.deepcopy(model)
        layer_names = [name for name, _ in pruned_model.named_children()]
        # Child order never changes below, so these are loop-invariant.
        last_conv_ix = self._get_last_conv_ix(pruned_model)
        first_fc_ix = self._get_first_fc_ix(pruned_model)
        for k, m in enumerate(list(pruned_model.children())):
            if isinstance(m, nn.Conv2d):
                next_conv_ix = self._find_next_conv(pruned_model, k)
                if next_conv_ix is not None:  # not the last conv layer
                    next_conv = getattr(pruned_model, layer_names[next_conv_ix])
                    new_m, new_next_m = self.prune_conv(m, next_conv)
                    setattr(pruned_model, layer_names[k], new_m)
                    setattr(pruned_model, layer_names[next_conv_ix], new_next_m)
                # else: last conv is left as-is; delete_fc_weights compensates
                # on the fc side.
            elif isinstance(m, nn.Linear) and k == first_fc_ix:
                last_conv = getattr(pruned_model, layer_names[last_conv_ix])
                new_fc = self.delete_fc_weights(m, last_conv)
                setattr(pruned_model, layer_names[k], new_fc)
        return pruned_model