# pruning_utils.py
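"""Utilities for unstructured pruning of Conv2d layers, built on torch.nn.utils.prune:
global L1-magnitude and random pruning, sparsity reporting, mask extraction from a
state dict, and re-applying saved masks to a freshly initialized model.
"""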
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def pruning_model(model, px, conv1=False):
    """Globally prune a fraction `px` of all Conv2d weights by L1 magnitude."""
    print('start unstructured pruning for all conv layers')
    parameters_to_prune = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            # skip the first conv layer unless conv1=True
            if (name == 'conv1' and conv1) or (name != 'conv1'):
                parameters_to_prune.append((m, 'weight'))

    parameters_to_prune = tuple(parameters_to_prune)

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=px,
    )


def check_sparsity(model, conv1=True):
    """Report the percentage of remaining (non-zero) Conv2d weights."""
    sum_list = 0
    zero_sum = 0
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if name == 'conv1':
                if conv1:
                    sum_list = sum_list + float(m.weight.nelement())
                    zero_sum = zero_sum + float(torch.sum(m.weight == 0))
                else:
                    print('skip conv1 for sparsity checking')
            else:
                sum_list = sum_list + float(m.weight.nelement())
                zero_sum = zero_sum + float(torch.sum(m.weight == 0))

    remain_weight = 100 * (1 - zero_sum / sum_list)
    print('* remain weight = ', remain_weight, '%')
    return remain_weight


def remove_prune(model, conv1=True):
    """Make pruning permanent: fold weight_mask into weight and drop the reparametrization."""
    print('remove pruning')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if (name == 'conv1' and conv1) or (name != 'conv1'):
                prune.remove(m, 'weight')


def extract_mask(model_dict):
    """Collect all pruning masks (keys containing 'mask') from a state dict."""
    new_dict = {}
    for key in model_dict.keys():
        if 'mask' in key:
            new_dict[key] = model_dict[key]
    return new_dict


def extract_main_weight(model_dict):
    """Collect all non-mask entries (the actual weights) from a state dict."""
    new_dict = {}
    for key in model_dict.keys():
        if 'mask' not in key:
            new_dict[key] = model_dict[key]
    return new_dict


def prune_model_custom(model, mask_dict, conv1=False):
    """Re-apply previously extracted binary masks to the Conv2d layers of `model`."""
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if (name == 'conv1' and conv1) or (name != 'conv1'):
                print('pruning layer with custom mask:', name)
                prune.CustomFromMask.apply(m, 'weight', mask=mask_dict[name + '.weight_mask'].to(m.weight.device))


def pruning_model_random(model, px):
    """Globally prune a fraction `px` of all Conv2d weights at random, then report per-layer sparsity."""
    print('start unstructured pruning')
    parameters_to_prune = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            parameters_to_prune.append((m, 'weight'))

    parameters_to_prune = tuple(parameters_to_prune)

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.RandomUnstructured,
        amount=px,
    )

    # Sanity check: print the fraction of zeroed weights in each pruned conv layer.
    index = 0
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            origin_mask = m.weight_mask
            print((origin_mask == 0).sum().float() / origin_mask.numel())
            print(index)
            print(name, (origin_mask == 0).sum())
            index += 1
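

if __name__ == '__main__':
    # Minimal usage sketch of the utilities above. ToyConvNet is a hypothetical
    # stand-in defined only for this demo; in practice these helpers apply to any
    # model containing nn.Conv2d layers.
    class ToyConvNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
            self.conv2 = nn.Conv2d(8, 16, 3, padding=1)

        def forward(self, x):
            return self.conv2(self.conv1(x))

    model = ToyConvNet()

    # Globally prune 50% of conv weights by L1 magnitude; conv1 stays dense by default.
    pruning_model(model, 0.5)
    check_sparsity(model, conv1=False)

    # Extract the binary masks and re-apply them to a freshly initialized copy,
    # as in lottery-ticket style rewinding.
    mask_dict = extract_mask(model.state_dict())
    fresh_model = ToyConvNet()
    prune_model_custom(fresh_model, mask_dict)

    # Fold the masks into the weights so the model no longer carries the
    # pruning reparametrization (weight_orig / weight_mask buffers).
    remove_prune(model, conv1=False)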