-
Notifications
You must be signed in to change notification settings - Fork 1
/
models_mae.py
444 lines (363 loc) · 16.1 KB
/
models_mae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
from helpers import video_vit
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with a (video) VisionTransformer backbone.

    ``patch_embed`` turns the input clip into space-time patch tokens, a subset
    of the tokens is kept by one of the masking strategies (``random``,
    ``temporal`` or ``center``), the kept tokens go through the ViT encoder,
    and a ViT decoder reconstructs per-patch pixels for every token from the
    encoded tokens plus a learned mask token.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=4,
        decoder_num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        norm_pix_loss=False,
        num_frames=16,
        t_patch_size=2,
        patch_embed=video_vit.PatchEmbed,
        no_qkv_bias=False,
        sep_pos_embed=True,
        trunc_init=False,
        cls_embed=True,
        pred_t_dim=16,
    ):
        """Build the encoder, decoder, tokens and positional embeddings.

        Args:
            img_size, patch_size, in_chans, embed_dim, num_frames, t_patch_size:
                passed positionally (in that order) to ``patch_embed``.
                ``patch_size`` and ``in_chans`` also size the decoder output.
            depth, num_heads: number of encoder blocks / attention heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder width,
                number of blocks and heads.
            mlp_ratio: MLP ratio of every ``video_vit.Block``.
            norm_layer: normalization constructor used after each stack and
                inside the blocks.
            norm_pix_loss: regress per-patch mean/var-normalized pixels.
            patch_embed: patch-embedding class. Its instance must expose
                ``num_patches``, ``input_size`` (T, h, w), ``proj``,
                ``patch_size``, ``t_grid_size`` and ``grid_size``.
            no_qkv_bias: disable the q/k/v bias in attention.
            sep_pos_embed: use separate spatial + temporal positional tables
                instead of a single joint table.
            trunc_init: truncated-normal init instead of xavier/normal.
            cls_embed: prepend a learnable class token (encoder and decoder).
            pred_t_dim: number of frames sampled from the input as the
                reconstruction target.
        """
        super().__init__()
        self.trunc_init = trunc_init
        self.sep_pos_embed = sep_pos_embed
        self.cls_embed = cls_embed
        self.pred_t_dim = pred_t_dim
        # Temporal extent of one *target* patch: targets use pred_t_dim frames
        # while the tokens were embedded from num_frames frames.
        self.t_pred_patch_size = t_patch_size * pred_t_dim // num_frames

        self.patch_embed = patch_embed(
            img_size,
            patch_size,
            in_chans,
            embed_dim,
            num_frames,
            t_patch_size,
        )
        num_patches = self.patch_embed.num_patches
        input_size = self.patch_embed.input_size
        self.input_size = input_size
        # Length of a joint (non-separable) positional table.
        total_tokens = num_patches + 1 if self.cls_embed else num_patches

        # NOTE: parameters/modules are registered in the original order so that
        # optimizer state, checkpoints and RNG-driven init stay compatible.
        if self.cls_embed:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.decoder_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # --- encoder ---
        if sep_pos_embed:
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(1, input_size[1] * input_size[2], embed_dim)
            )
            self.pos_embed_temporal = nn.Parameter(torch.zeros(1, input_size[0], embed_dim))
            if self.cls_embed:
                self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, total_tokens, embed_dim))

        self.blocks = nn.ModuleList(
            [
                video_vit.Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=not no_qkv_bias,
                    qk_scale=None,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # --- decoder ---
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        if sep_pos_embed:
            self.decoder_pos_embed_spatial = nn.Parameter(
                torch.zeros(1, input_size[1] * input_size[2], decoder_embed_dim)
            )
            self.decoder_pos_embed_temporal = nn.Parameter(
                torch.zeros(1, input_size[0], decoder_embed_dim)
            )
            if self.cls_embed:
                self.decoder_pos_embed_class = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        else:
            self.decoder_pos_embed = nn.Parameter(torch.zeros(1, total_tokens, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            [
                video_vit.Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=not no_qkv_bias,
                    qk_scale=None,
                    norm_layer=norm_layer,
                )
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # One prediction per token: u * p * p * in_chans pixel values.
        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            self.t_pred_patch_size * patch_size**2 * in_chans,
            bias=True,
        )

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()
        print("model initialized")

    def initialize_weights(self):
        """Initialize tokens, positional tables, the patch projection and all
        Linear/LayerNorm submodules.

        The call order is fixed because every init draws from the global torch
        RNG; reordering would change the initial weights for a given seed.
        ``decoder_cls_token`` is not touched here and stays at zeros.
        """
        if self.cls_embed:
            torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
        if self.sep_pos_embed:
            torch.nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
            torch.nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
            torch.nn.init.trunc_normal_(self.decoder_pos_embed_spatial, std=0.02)
            torch.nn.init.trunc_normal_(self.decoder_pos_embed_temporal, std=0.02)
            if self.cls_embed:
                torch.nn.init.trunc_normal_(self.pos_embed_class, std=0.02)
                torch.nn.init.trunc_normal_(self.decoder_pos_embed_class, std=0.02)
        else:
            torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
            torch.nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        w = self.patch_embed.proj.weight.data
        if self.trunc_init:
            torch.nn.init.trunc_normal_(w)
            torch.nn.init.trunc_normal_(self.mask_token, std=0.02)
        else:
            # Treat the (possibly conv) projection kernel as a 2-D matrix.
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init applied via ``nn.Module.apply``.

        Linear: xavier-uniform weights (following the official JAX ViT), or
        truncated normal (std 0.02) when ``trunc_init``; zero bias.
        LayerNorm: weight 1, bias 0.
        """
        if isinstance(m, nn.Linear):
            if self.trunc_init:
                nn.init.trunc_normal_(m.weight, std=0.02)
            else:
                torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split clips into flattened space-time target patches.

        imgs: (N, C, T, H, W), with H == W, H % p == 0 and T % u == 0
        returns: (N, t*h*w, u*p*p*C), where p is the spatial patch size,
            u = self.t_pred_patch_size, t = T // u and h = w = H // p.

        The geometry is cached in ``self.patch_info`` for ``unpatchify``.
        """
        N, C, T, H, W = imgs.shape
        p = self.patch_embed.patch_size[0]
        u = self.t_pred_patch_size
        assert H == W and H % p == 0 and T % u == 0
        h = w = H // p
        t = T // u

        x = imgs.reshape(shape=(N, C, t, u, h, p, w, p))
        x = torch.einsum("nctuhpwq->nthwupqc", x)
        x = x.reshape(shape=(N, t * h * w, u * p**2 * C))
        self.patch_info = (N, T, H, W, p, u, t, h, w)
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify`` using the geometry cached by its last call.

        x: (N, t*h*w, u*p*p*C)
        returns: (N, C, T, H, W)
        """
        N, T, H, W, p, u, t, h, w = self.patch_info
        C = x.shape[-1] // (u * p * p)  # channel count recovered from the patch width
        x = x.reshape(shape=(N, t, h, w, u, p, p, C))
        x = torch.einsum("nthwupqc->nctuhpwq", x)
        return x.reshape(shape=(N, C, T, H, W))

    @staticmethod
    def _mask_from_noise(x, noise, mask_ratio):
        """Keep, per sample, the ``int(L * (1 - mask_ratio))`` tokens with the
        smallest ``noise`` value.

        x: (N, L, D); noise: (N, L)
        returns (x_masked, mask, ids_restore, ids_keep):
            x_masked: (N, len_keep, D) kept tokens, in ascending-noise order
            mask: (N, L) in original token order, 0 = keep, 1 = remove
            ids_restore: (N, L) permutation that undoes the noise sort
            ids_keep: (N, len_keep) original indices of the kept tokens
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Mask in sorted order first, then unshuffle it back to token order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore, ids_keep

    def center_masking(self, x, mask_ratio):
        """Mask the middle of the flattened token sequence and keep both ends.

        Tokens are ranked by their distance to the nearer end of the
        sequence; the tokens closest to either end are kept.
        x: [N, L, D], sequence
        """
        N, L, _ = x.shape
        idx = torch.arange(L, device=x.device)
        # Distance to the nearer end. Unlike cat(arange(L//2), reversed(arange(L//2))),
        # this has exactly L entries for odd L as well. "* L + idx" breaks
        # equal-distance ties deterministically in favour of the earlier token.
        noise = (torch.minimum(idx, L - 1 - idx) * L + idx).repeat(N, 1)
        return self._mask_from_noise(x, noise, mask_ratio)

    def temporal_masking(self, x, mask_ratio):
        """Keep the first ``len_keep`` tokens of the flattened sequence and mask the rest.

        x: [N, L, D], sequence
        """
        N, L, _ = x.shape
        # Increasing noise => the sort is the identity, so the earliest tokens are kept.
        noise = torch.arange(L, device=x.device).repeat(N, 1)
        return self._mask_from_noise(x, noise, mask_ratio)

    def random_masking(self, x, mask_ratio):
        """Per-sample random masking: argsort of uniform noise gives the shuffle.

        x: [N, L, D], sequence
        """
        N, L, _ = x.shape
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        return self._mask_from_noise(x, noise, mask_ratio)

    def _joint_pos_embed(self, spatial, temporal):
        """Combine separable positional tables into one (1, T*h*w, D) table.

        ``spatial`` (1, h*w, D) is tiled over the T time steps and ``temporal``
        (1, T, D) is repeated h*w times per step, so the token at (time t,
        spatial s) receives spatial[s] + temporal[t].
        """
        return spatial.repeat(1, self.input_size[0], 1) + torch.repeat_interleave(
            temporal, self.input_size[1] * self.input_size[2], dim=1
        )

    def forward_encoder(self, x, mask_ratio, mask_type='random'):
        """Embed, mask and encode a clip.

        Returns ``(latent, mask, ids_restore)``. ``latent`` holds the encoded
        kept tokens without the class token; ``mask`` and ``ids_restore`` are
        as returned by the masking method.

        Raises:
            NotImplementedError: for an unknown ``mask_type``.
        """
        maskers = {
            'random': self.random_masking,
            'temporal': self.temporal_masking,
            'center': self.center_masking,
        }
        if mask_type not in maskers:
            raise NotImplementedError("Does not support {} masking".format(mask_type))

        # embed patches: (N, T, L, C) -> flattened (N, T*L, C)
        x = self.patch_embed(x)
        N, T, L, C = x.shape
        x = x.reshape(N, T * L, C)

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore, ids_keep = maskers[mask_type](x, mask_ratio)
        x = x.view(N, -1, C)

        if self.cls_embed:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # Positional embedding for the kept tokens only (gathered with ids_keep),
        # with the class-token embedding prepended when present.
        if self.sep_pos_embed:
            pos_embed = self._joint_pos_embed(self.pos_embed_spatial, self.pos_embed_temporal)
            pos_embed = pos_embed.expand(x.shape[0], -1, -1)
            pos_embed = torch.gather(
                pos_embed, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2])
            )
            if self.cls_embed:
                pos_embed = torch.cat(
                    [self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), pos_embed], 1
                )
        else:
            cls_ind = 1 if self.cls_embed else 0
            pos_embed = self.pos_embed[:, cls_ind:, :].expand(x.shape[0], -1, -1)
            pos_embed = torch.gather(
                pos_embed, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2])
            )
            if self.cls_embed:
                pos_embed = torch.cat(
                    [self.pos_embed[:, :1, :].expand(x.shape[0], -1, -1), pos_embed], 1
                )
        x = x.view([N, -1, C]) + pos_embed

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if self.cls_embed:
            x = x[:, 1:, :]  # remove cls token
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Reconstruct per-patch pixels for every token.

        x: encoder output for the kept tokens (no class token).
        ids_restore: permutation returned by the masking method.
        returns: (N, T*H*W, t_pred_patch_size * p * p * in_chans), class token removed.
        """
        N = x.shape[0]
        T = self.patch_embed.t_grid_size
        H = W = self.patch_embed.grid_size
        num_tokens = T * H * W

        x = self.decoder_embed(x)
        C = x.shape[-1]

        # Pad with mask tokens to the full length, then unshuffle to token order.
        mask_tokens = self.mask_token.repeat(N, num_tokens - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1).view([N, num_tokens, C])
        x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C))

        if self.cls_embed:
            decoder_cls_tokens = self.decoder_cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((decoder_cls_tokens, x), dim=1)

        if self.sep_pos_embed:
            decoder_pos_embed = self._joint_pos_embed(
                self.decoder_pos_embed_spatial, self.decoder_pos_embed_temporal
            )
            if self.cls_embed:
                decoder_pos_embed = torch.cat(
                    [
                        self.decoder_pos_embed_class.expand(decoder_pos_embed.shape[0], -1, -1),
                        decoder_pos_embed,
                    ],
                    1,
                )
        else:
            decoder_pos_embed = self.decoder_pos_embed
        x = x + decoder_pos_embed

        # Some attention variants want (N, T, H*W, C) input.
        # NOTE(review): this view assumes no class token is present.
        attn = self.decoder_blocks[0].attn
        requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape
        if requires_t_shape:
            x = x.view([N, T, H * W, C])

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        if requires_t_shape:
            x = x.view([N, T * H * W, -1])

        if self.cls_embed:
            x = x[:, 1:, :]  # remove cls token
        return x

    def forward_loss(self, imgs, pred, mask, visualize):
        """Mean squared error over the removed patches only.

        imgs: [N, C, T, H, W]; ``pred_t_dim`` evenly spaced frames are taken
            as the target.
        pred: [N, t*h*w, u*p*p*C]
        mask: [N, t*h*w], 0 is keep, 1 is remove
        visualize: if True, stash the raw (un-normalized) target in ``self.target``.

        Returns NaN when no patch is removed (``mask.sum() == 0``).
        """
        frame_idx = torch.linspace(0, imgs.shape[2] - 1, self.pred_t_dim).long().to(imgs.device)
        _imgs = torch.index_select(imgs, 2, frame_idx)
        target = self.patchify(_imgs)
        if visualize:
            self.target = target
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        mask = mask.view(loss.shape)
        return (loss * mask).sum() / mask.sum()  # mean loss on removed patches

    def forward(self, imgs, mask_ratio=0.9, visualize=False, mask_type='random'):
        """Full pre-training step.

        Returns ``(loss, pred, mask)``. When ``visualize`` is True it returns
        ``(loss, pred, mask, comparison)``, where ``comparison`` stacks
        [target clip, visible-only clip, reconstruction] along dim 1.
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, mask_type)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask, visualize)
        if not visualize:
            return loss, pred, mask

        N, T, H, W, p, u, t, h, w = self.patch_info
        removed = mask.reshape(N, t * h * w, 1)
        visible_target = self.target * (1 - removed)
        # NOTE(review): with norm_pix_loss=True, pred is in normalized-pixel space
        # while self.target is raw, so the composite mixes the two scales.
        reconstruct = self.unpatchify(pred * removed + visible_target)
        masked = self.unpatchify(visible_target)
        comparison = torch.stack([self.unpatchify(self.target), masked, reconstruct], dim=1)
        return loss, pred, mask, comparison
def mae_vit_huge_patch14(**kwargs):
    """Build a ViT-Huge/14 masked autoencoder.

    Fixes patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4
    and LayerNorm with eps=1e-6. All remaining ``MaskedAutoencoderViT``
    arguments come from ``kwargs``; repeating a fixed one raises TypeError.
    """
    huge_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=huge_norm,
        **kwargs,
    )