-
Notifications
You must be signed in to change notification settings - Fork 115
/
prune_utils.py
57 lines (45 loc) · 1.58 KB
/
prune_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
# @Time : 2021/5/24 4:36 PM
# @Author : midaskong
# @File : prune_utils.py
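# @Description: Channel-pruning helpers: gather BatchNorm / conv weight
#               magnitudes and build per-channel keep masks from a threshold.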
import torch
from copy import deepcopy
import numpy as np
import torch.nn.functional as F
def gather_bn_weights(module_list):
    """Concatenate the absolute weights of every module into one flat tensor.

    Args:
        module_list: dict-like container of modules; only ``.values()`` is
            used. Each module's ``weight`` must be 1-D (e.g. a BatchNorm
            scale vector), since it is written into a 1-D slice.

    Returns:
        A 1-D tensor allocated on the CPU with the default dtype, holding
        ``|weight|`` of each module back to back in iteration order. It is
        empty when ``module_list`` is empty.
    """
    modules = list(module_list.values())
    size_list = [module.weight.data.shape[0] for module in modules]
    bn_weights = torch.zeros(sum(size_list))
    index = 0
    for module, size in zip(modules, size_list):
        # Slice assignment copies the values into the CPU buffer, so no
        # explicit .clone() is needed.
        bn_weights[index:index + size] = module.weight.data.abs()
        index += size
    return bn_weights
def gather_conv_weights(module_list):
    """Concatenate the per-output-channel L1 norms of every module's weight.

    For each module, ``|weight|`` is summed over dims 1, 2 and 3, which gives
    one value per entry of dim 0. The result only fits the 1-D slice when
    the weight is 4-D. Presumably these are ``nn.Conv2d`` weights of shape
    (out_channels, in_channels/groups, kH, kW); TODO confirm with callers.

    Args:
        module_list: dict-like container of modules; only ``.values()`` is
            used.

    Returns:
        A 1-D tensor allocated on the CPU with the default dtype, holding the
        per-channel sums of each module back to back in iteration order. It
        is empty when ``module_list`` is empty.
    """
    modules = list(module_list.values())
    size_list = [module.weight.data.shape[0] for module in modules]
    conv_weights = torch.zeros(sum(size_list))
    index = 0
    for module, size in zip(modules, size_list):
        # Keep the exact chained-sum reduction used by obtain_conv_mask so
        # that thresholds derived from these values compare bit-for-bit.
        conv_weights[index:index + size] = (
            module.weight.data.abs().sum(dim=1).sum(dim=1).sum(dim=1)
        )
        index += size
    return conv_weights
def obtain_bn_mask(bn_module, thre):
    """Return a float keep-mask of channels whose ``|weight|`` is >= ``thre``.

    Args:
        bn_module: module exposing a ``weight`` tensor (e.g. a BatchNorm
            layer).
        thre: threshold, either a tensor or a Python number.

    Returns:
        A float tensor shaped like ``bn_module.weight``, on the same device,
        containing 1.0 where ``|weight| >= thre`` and 0.0 elsewhere.
    """
    weight = bn_module.weight.data
    # Place the threshold on the weights' device instead of hard-coding
    # .cuda(): this works on CPU-only hosts and for weights on CPU or on a
    # non-default GPU.
    thre = torch.as_tensor(thre, device=weight.device)
    return weight.abs().ge(thre).float()
def obtain_conv_mask(conv_module, thre):
    """Return a float keep-mask of output channels whose L1 norm is >= ``thre``.

    The per-channel score is ``|weight|`` summed over dims 1, 2 and 3. This
    matches ``gather_conv_weights`` exactly and yields one value per entry of
    dim 0 when the weight is 4-D.

    Args:
        conv_module: module exposing a ``weight`` tensor (presumably a 4-D
            conv kernel; TODO confirm with callers).
        thre: threshold, either a tensor or a Python number.

    Returns:
        A 1-D float tensor on the weights' device, containing 1.0 for kept
        channels and 0.0 for pruned ones.
    """
    weight = conv_module.weight.data
    # Follow the weights' device rather than forcing .cuda(), so this works
    # on CPU-only hosts and for weights on CPU or on a non-default GPU.
    thre = torch.as_tensor(thre, device=weight.device)
    scores = weight.abs().sum(dim=1).sum(dim=1).sum(dim=1)
    return scores.ge(thre).float()
def uodate_pruned_yolov5_cfg(model, maskbndict):
    """Placeholder for saving a pruned YOLOv5 model's config in YAML format.

    Not implemented yet: the body does nothing and always returns ``None``.

    Args:
        model: the model to be pruned.
        maskbndict: mapping whose keys are module names and whose values are
            the BatchNorm-layer mask (index) for that module.
    """
    # TODO: implement. Kept as a no-op so existing callers are unaffected.
    return None


# Correctly spelled alias. The original misspelled name is kept for backward
# compatibility with existing callers.
update_pruned_yolov5_cfg = uodate_pruned_yolov5_cfg