import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import collections
import re
import torch
from glob import glob
from collections import OrderedDict


def check_args(args, rank=0):
    """Resolve output paths and, on rank 0, dump all options to the setting file."""
    args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
    args.log_file = os.path.join(args.checkpoint_dir, args.log_file)
    if rank == 0:
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        with open(args.setting_file, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            print('------------ Options -------------')
            for k in args.__dict__:
                v = args.__dict__[k]
                opt_file.write('%s: %s\n' % (str(k), str(v)))
                print('%s: %s' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
            print('------------ End -------------')
    return args
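

# A minimal usage sketch for check_args, assuming an argparse-style namespace
# with the attributes the function reads (checkpoint_dir, setting_file,
# log_file). The attribute values below are hypothetical.
def _example_check_args():
    import argparse
    args = argparse.Namespace(
        checkpoint_dir='checkpoints/run1',  # hypothetical output directory
        setting_file='settings.txt',
        log_file='train.log',
    )
    # Rank 0 creates checkpoint_dir and dumps every option to settings.txt;
    # other ranks only get the joined paths back.
    return check_args(args, rank=0)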


def show_all_variables(rank=0):
    """Print a per-variable summary of all trainable TF variables (TF1.x / slim)."""
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=(rank == 0))


def torch_show_all_params(model, rank=0):
    """Print the total number of parameters in a PyTorch model."""
    total = sum(p.numel() for p in model.parameters())
    if rank == 0:
        print("Total param num: " + str(total))


def get_assigment_map_from_checkpoint(tvars, init_checkpoint):
    """Map checkpoint variable names to the current graph's variables.

    Returns the assignment map plus bookkeeping sets: names initialized from
    the checkpoint, graph variables missing from the checkpoint (new), and
    checkpoint variables absent from the graph (unused).
    """
    initialized_variable_names = {}
    new_variable_names = set()
    unused_variable_names = set()
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        # Strip the ":<output_index>" suffix from the variable name.
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var
    init_vars = tf.train.list_variables(init_checkpoint)
    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            # Optimizer slots (Adam/LAMB/accumulators) are expected to be absent.
            if 'adam' not in name and 'lamb' not in name and 'accum' not in name:
                unused_variable_names.add(name)
            continue
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1
    for name in name_to_variable:
        if name not in initialized_variable_names:
            new_variable_names.add(name)
    return assignment_map, initialized_variable_names, new_variable_names, unused_variable_names
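

# A quick illustration of the name normalization above: TF variable names
# carry a ":<output_index>" suffix, while checkpoint keys do not, so the
# regex strips it before matching. The variable name is hypothetical.
def _example_variable_name_matching():
    name = 'bert/embeddings/word_embeddings:0'
    m = re.match("^(.*):\\d+$", name)
    assert m.group(1) == 'bert/embeddings/word_embeddings'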


# Loading weights from a TF checkpoint.
def init_from_checkpoint(init_checkpoint, tvars=None, rank=0):
    if not tvars:
        tvars = tf.trainable_variables()
    assignment_map, initialized_variable_names, new_variable_names, unused_variable_names = \
        get_assigment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    if rank == 0:
        for t in initialized_variable_names:
            if ":0" not in t:
                print("Loading weights success: " + t)
        print('New parameters:', new_variable_names)
        print('Unused parameters:', unused_variable_names)
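

# A minimal TF1-style usage sketch, assuming the model graph has already been
# built and 'pretrained/bert_model.ckpt' is a hypothetical checkpoint path.
# init_from_checkpoint must run before variables are initialized, since it
# only rewires their initializers.
def _example_tf_init():
    init_from_checkpoint('pretrained/bert_model.ckpt', rank=0)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # now loads checkpoint values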


def torch_init_model(model, init_checkpoint):
    """Load a PyTorch state dict into `model`, reporting any key mismatches."""
    state_dict = torch.load(init_checkpoint, map_location='cpu')
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # Copy state_dict so _load_from_state_dict can modify it.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # If the model does not expose a `.bert` submodule, assume it is the bare
    # encoder and the checkpoint keys carry a "bert." prefix.
    load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
    print("missing keys: {}".format(missing_keys))
    print("unexpected keys: {}".format(unexpected_keys))
    print("error msgs: {}".format(error_msgs))
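

# A self-contained usage sketch for torch_init_model. The toy module below is
# a stand-in for a model that wraps its encoder as `.bert`, and the save path
# is hypothetical; any nn.Module with a matching state dict works the same way.
def _example_torch_init():
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.bert = nn.Linear(4, 4)  # hasattr(model, 'bert') -> prefix ''

    model = Toy()
    torch.save(model.state_dict(), 'toy.pth')  # hypothetical path
    torch_init_model(model, 'toy.pth')  # expect empty missing/unexpected key lists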


# TODO: fix here.
def torch_init_pretrain_model(model, init_checkpoint):
    """Load a pretrained state dict, remapping old-style key names first."""
    state_dict = torch.load(init_checkpoint, map_location='cpu')
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # Copy state_dict so _load_from_state_dict can modify it.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    # Drop classifier heads and remap old TF-style LayerNorm names
    # (gamma/beta -> weight/bias), prefixing every key with "bert.".
    new_state_dict = OrderedDict()
    for state, value in state_dict.items():
        if not state.startswith("classifier"):
            new_state_dict["bert." + state.replace(".gamma", ".weight").replace(".beta", ".bias")] = value
    '''
    new_state_dict = {k[len("module."):]: v for k, v in new_state_dict.items() if
                      k not in {"module.classifier.weight", "module.classifier.bias", "classifier.weight", "classifier.bias"}}
    '''

    def load(module, prefix=''):
        local_metadata = {}
        module._load_from_state_dict(
            new_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
    print("missing keys: {}".format(missing_keys))
    print("unexpected keys: {}".format(unexpected_keys))
    print("error msgs: {}".format(error_msgs))
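

# A toy illustration of the key remapping above: TF-converted BERT checkpoints
# name LayerNorm parameters "gamma"/"beta", while PyTorch uses "weight"/"bias",
# and encoder keys gain a "bert." prefix. The keys here are hypothetical.
def _example_key_remap():
    old = OrderedDict([('encoder.LayerNorm.gamma', 1), ('classifier.weight', 2)])
    new = OrderedDict()
    for k, v in old.items():
        if not k.startswith('classifier'):
            new['bert.' + k.replace('.gamma', '.weight').replace('.beta', '.bias')] = v
    assert list(new) == ['bert.encoder.LayerNorm.weight']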


def torch_save_model(model, output_dir, scores, max_save_num=1):
    """Save a model checkpoint, keeping at most `max_save_num` .pth files."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Take care of distributed/parallel training.
    model_to_save = model.module if hasattr(model, 'module') else model
    # Prune existing checkpoints (in sorted filename order) until there is
    # room for the new one.
    saved_pths = glob(os.path.join(output_dir, '*.pth'))
    saved_pths.sort()
    while len(saved_pths) >= max_save_num:
        if os.path.exists(saved_pths[0].replace('//', '/')):
            os.remove(saved_pths[0].replace('//', '/'))
        del saved_pths[0]
    # Encode the scores in the filename, e.g. "checkpoint_score_f1-0.8812.pth".
    save_prefix = "checkpoint_score"
    for k in scores:
        save_prefix += '_' + k + '-' + str(scores[k])[:6]
    save_prefix += '.pth'
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, save_prefix))
    print("Saving model checkpoint to %s" % output_dir)