convert_weights.py
"""Script to download the pre-trained tensorflow weights and convert them to pytorch weights."""
import os
import argparse
import torch
import numpy as np
from tensorflow.python.training import py_checkpoint_reader
from repnet import utils
from repnet.model import RepNet
# Relevant paths
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
# Mapping of ndim -> permutation to go from tf to pytorch
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2)
}
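# For reference (standard TF/Keras vs. PyTorch weight layouts):
#   Dense/Linear: TF kernel (in_features, out_features)  -> PyTorch weight (out_features, in_features)  via (1, 0)
#   Conv2D:       TF kernel (kH, kW, C_in, C_out)        -> PyTorch weight (C_out, C_in, kH, kW)        via (3, 2, 0, 1)
#   Conv3D:       TF kernel (kT, kH, kW, C_in, C_out)    -> PyTorch weight (C_out, C_in, kT, kH, kW)    via (4, 3, 0, 1, 2)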
# Mapping of tf attributes -> pytorch attributes
ATTR_MAPPING = {
    'kernel': 'weight',
    'bias': 'bias',
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var'
}
# Mapping of (tf checkpoint key, original tf/Keras layer name, pytorch module path); the middle entry is kept only for reference and is not used by the conversion
WEIGHTS_MAPPING = [
    # Base frame encoder
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),
    # Temporal convolution
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),
    # Period length head
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),
    # Periodicity head
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]
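# Example (illustrative): the entry ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv') copies the
# checkpoint weights stored under `base_model.layer-2` (the Keras layer `conv1_conv`, listed only for
# reference) into the PyTorch module `encoder.stem.conv`; its `kernel` attribute becomes the state dict
# key `encoder.stem.conv.weight` once ATTR_MAPPING and the permutation are applied.
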
# Script arguments
parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.')
if __name__ == '__main__':
    args = parser.parse_args()

    # Download tensorflow checkpoints
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for file in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, file)
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file}', dst)

    # Load tensorflow weights into a dictionary
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map.keys()):
        var_tensor = checkpoint_reader.get_tensor(var_name)
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue  # Skip variables that are not part of the model, e.g. from the optimizer
        # Split var_name into a path
        var_path = var_name.split('/')[1:]  # Remove the `model` key from the path
        var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
        # Map weights into a nested dictionary
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor
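    # Example (illustrative): a checkpoint variable named something like
    # 'model/temporal_conv_layers/0/kernel/.ATTRIBUTES/VARIABLE_VALUE' ends up stored at
    # tf_state_dict['temporal_conv_layers']['0']['kernel'] once the bookkeeping tokens are dropped.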

    # Merge the separate q/k/v projection weights into a single tensor, matching torch.nn.MultiheadAttention,
    # which stacks them as in_proj_weight of shape (3*E, E) and in_proj_bias of shape (3*E,)
    for k in ['transformer_layers', 'transformer_layers2']:
        v = tf_state_dict[k]['0']['mha']
        v['w_weight'] = np.concatenate([v['wq']['kernel'].T, v['wk']['kernel'].T, v['wv']['kernel'].T], axis=0)
        v['w_bias'] = np.concatenate([v['wq']['bias'].T, v['wk']['bias'].T, v['wv']['bias'].T], axis=0)
        del v['wk'], v['wq'], v['wv']
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)

    # Add a missing final level for weights that are stored as bare tensors rather than dicts of attributes
    for k, v in tf_state_dict.items():
        if not isinstance(v, dict):
            tf_state_dict[k] = {None: v}
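    # For example (illustrative), a standalone variable such as `pos_encoding` is stored as a bare tensor;
    # wrapping it as {None: tensor} lets the conversion loop below treat it like any other layer, and
    # flatten_dict(..., skip_none=True) later removes the placeholder level from the final key.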

    # Convert to a format compatible with PyTorch and save
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        assert k_pt not in pt_state_dict
        pt_state_dict[k_pt] = {}
        for attr in tf_state_dict[k_tf]:
            new_attr = ATTR_MAPPING.get(attr, attr)
            pt_state_dict[k_pt][new_attr] = torch.from_numpy(tf_state_dict[k_tf][attr])
            if attr == 'kernel':
                weights_permutation = WEIGHTS_PERMUTATION[pt_state_dict[k_pt][new_attr].ndim]  # Permute kernel weights to the PyTorch layout
                pt_state_dict[k_pt][new_attr] = pt_state_dict[k_pt][new_attr].permute(weights_permutation)
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)

    # Initialize the model and check that the converted weights can be loaded
    print('Checking that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path)
    model.load_state_dict(pt_state_dict)
    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')