-
Notifications
You must be signed in to change notification settings - Fork 1
/
models_vit.py
193 lines (160 loc) · 6.55 KB
/
models_vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
from util.video_vit import Attention, Block, PatchEmbed
class VisionTransformer(nn.Module):
    """Spatio-temporal Vision Transformer classifier.

    Pipeline:
      1. ``PatchEmbed`` turns the input into a ``[N, T, L, C]`` grid of patch
         tokens (T temporal positions, L spatial positions).
      2. The grid is flattened to ``[N, T * L, C]`` and an optional learnable
         CLS token is prepended.
      3. Learnable positional embeddings are added: either one joint table,
         or separate spatial / temporal (/ class) tables.
      4. The tokens run through ``depth`` transformer ``Block``s.

    Pooling for classification uses the final CLS token when ``cls_embed``
    is True. Otherwise it takes the global average over all (patch) tokens.
    The pooled feature then goes through ``norm``, ``dropout`` and ``head``.
    ``head`` is ``nn.Identity`` when ``num_classes`` is None, so in that case
    the model returns the ``embed_dim``-wide pooled feature.
    """

    def __init__(
        self,
        num_frames=16,
        t_patch_size=2,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        no_qkv_bias=False,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        dropout=0.3,
        sep_pos_embed=True,
        cls_embed=True,
    ):
        """Build the patch embedding, positional tables, blocks and head.

        Args:
            num_frames: number of input frames, forwarded to ``PatchEmbed``.
            t_patch_size: temporal patch size, forwarded to ``PatchEmbed``.
            img_size: input spatial size, forwarded to ``PatchEmbed``.
            patch_size: spatial patch size, forwarded to ``PatchEmbed``.
            in_chans: input channel count, forwarded to ``PatchEmbed``.
            num_classes: output width of the linear head. ``None`` replaces
                the head with ``nn.Identity``.
            embed_dim: token embedding width.
            depth: number of transformer blocks.
            num_heads: attention heads per block.
            mlp_ratio: MLP hidden-width multiplier per block.
            no_qkv_bias: if True, blocks are built with ``qkv_bias=False``.
            drop_path_rate: maximum stochastic-depth rate. Block ``i`` gets a
                rate linearly spaced between 0 and this value.
            norm_layer: normalisation layer factory, called with ``embed_dim``.
            dropout: dropout probability applied right before the head.
            sep_pos_embed: use separate spatial / temporal (/ class)
                positional tables instead of one joint table.
            cls_embed: prepend a learnable CLS token and pool from it.
        """
        super().__init__()
        self.sep_pos_embed = sep_pos_embed
        self.cls_embed = cls_embed

        # --------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim, num_frames, t_patch_size)
        num_patches = self.patch_embed.num_patches
        # Token-grid size reported by PatchEmbed. It is indexed below as
        # (temporal, spatial dim 1, spatial dim 2).
        self.input_size = self.patch_embed.input_size

        if cls_embed:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        if sep_pos_embed:
            # Factorised tables (one row per spatial position, one per temporal
            # position, plus one for the CLS token), combined in _build_pos_embed.
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(1, self.input_size[1] * self.input_size[2], embed_dim)
            )
            self.pos_embed_temporal = nn.Parameter(torch.zeros(1, self.input_size[0], embed_dim))
            if cls_embed:
                self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            # One joint learnable table covering every token (+1 for the CLS token).
            num_tokens = num_patches + 1 if cls_embed else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim), requires_grad=True)

        # Stochastic depth decay rule: rates grow linearly from 0 to drop_path_rate.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=not no_qkv_bias,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=dpr[i],
                    attn_func=partial(Attention, input_size=self.patch_embed.input_size),
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------

        self.dropout = nn.Dropout(dropout)
        if num_classes is not None:
            self.head = nn.Linear(embed_dim, num_classes)
            torch.nn.init.normal_(self.head.weight, std=0.02)
        else:
            self.head = nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the CLS-token / positional-embedding parameters.

        Presumably consumed by the training code to exclude these parameters
        from weight decay (verify against the optimizer setup). Names that
        do not exist for the current configuration are still listed.
        """
        return {
            "cls_token",
            "pos_embed",
            "pos_embed_spatial",
            "pos_embed_temporal",
            "pos_embed_class",
        }

    def _requires_t_shape(self):
        """Return truthy if the blocks' attention wants tokens as [N, T, L, C]."""
        # An empty block list is allowed (e.g. an empty decoder).
        return (
            len(self.blocks) > 0
            and hasattr(self.blocks[0].attn, "requires_t_shape")
            and self.blocks[0].attn.requires_t_shape
        )

    def _build_pos_embed(self):
        """Return the positional embedding to add to the flattened token sequence."""
        if not self.sep_pos_embed:
            return self.pos_embed
        num_t = self.input_size[0]
        num_hw = self.input_size[1] * self.input_size[2]
        # Token index t * num_hw + s receives spatial[s] + temporal[t]:
        # the spatial table is tiled over time, and each temporal row is
        # repeated once per spatial position.
        pos_embed = self.pos_embed_spatial.repeat(1, num_t, 1) + torch.repeat_interleave(
            self.pos_embed_temporal, num_hw, dim=1
        )
        if self.cls_embed:
            pos_embed = torch.cat(
                [self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), pos_embed], dim=1
            )
        return pos_embed

    def _embed(self, x):
        """Patchify ``x``, prepend the CLS token (if any) and add positions.

        Returns:
            ``(tokens, (N, T, L, C))``. ``tokens`` is ``[N, T * L (+1), C]``
            and the tuple is the patch-grid shape reported by ``patch_embed``.
        """
        x = self.patch_embed(x)
        N, T, L, C = x.shape  # T: temporal; L: spatial
        x = x.view([N, T * L, C])
        if self.cls_embed:
            cls_tokens = self.cls_token.expand(N, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = x + self._build_pos_embed()
        return x, (N, T, L, C)

    def forward(self, x):
        """Classify ``x`` (or return its pooled feature when ``head`` is Identity).

        ``x`` is assumed to be a video batch in whatever layout ``PatchEmbed``
        expects (likely ``[N, C, T, H, W]`` — TODO confirm against util.video_vit).
        """
        x, (N, T, L, C) = self._embed(x)

        requires_t_shape = self._requires_t_shape()
        if requires_t_shape:
            # NOTE(review): this view only fits T * L tokens, i.e. when no
            # CLS token was prepended.
            x = x.view([N, T, L, C])

        for blk in self.blocks:
            x = blk(x)

        if requires_t_shape:
            x = x.view([N, T * L, C])

        # Pool: the CLS token when one exists, otherwise a global average over
        # the tokens (previously the first patch token was used by mistake).
        if self.cls_embed:
            x = x[:, 0, :]
        else:
            x = x.mean(dim=1)

        x = self.norm(x)
        x = self.dropout(x)
        return self.head(x)

    def get_last_selfattention(self, x):
        """Run all blocks but the last normally; return the last block's
        output with ``return_attention=True``.

        What that call returns is defined by ``Block`` (outside this file).
        Returns None when there are no blocks.
        """
        x, (N, T, L, C) = self._embed(x)
        if self._requires_t_shape():
            x = x.view([N, T, L, C])
        if len(self.blocks) == 0:
            return None
        for blk in self.blocks[:-1]:
            x = blk(x)
        return self.blocks[-1](x, return_attention=True)
def vit_huge_patch14(**kwargs):
    """Construct a ``VisionTransformer`` with ViT-H/14 hyper-parameters.

    Fixed settings: 14-pixel patches, width 1280, 32 blocks, 16 heads,
    MLP ratio 4, and LayerNorm with eps=1e-6. Any other constructor
    argument is taken from ``kwargs``. Passing one of the fixed settings
    again raises ``TypeError`` (duplicate keyword argument).
    """
    huge_config = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**huge_config, **kwargs)