forked from LTH14/mage
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path models_vit_mage.py
124 lines (94 loc) · 4.6 KB
/
models_vit_mage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from functools import partial
import torch
import torch.nn as nn
import timm.models.vision_transformer
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
class BertEmbeddings(nn.Module):
    """Sum of learned token and absolute-position embeddings, then LayerNorm and dropout.

    Args:
        vocab_size: number of rows in the token-id embedding table.
        hidden_size: width of every embedding vector.
        max_position_embeddings: number of learned positions; this bounds the
            sequence length that ``forward`` can embed.
        dropout: dropout probability applied after the LayerNorm.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        # Capitalised name kept on purpose: it keeps state_dict keys compatible
        # with BERT-style checkpoints converted from TensorFlow.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

        # Row vector [[0, 1, ..., max_position_embeddings - 1]]. As a buffer it
        # follows the module across devices and is saved in the state_dict.
        all_positions = torch.arange(max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        # Re-initialise both tables with std 0.02, token table first.
        for table in (self.word_embeddings, self.position_embeddings):
            torch.nn.init.normal_(table.weight, std=.02)

    def forward(self, input_ids):
        """Embed a batch of token ids.

        ``input_ids`` is indexed as (batch, seq_len); position ids
        0..seq_len-1 are used for every row. Returns a tensor with a trailing
        ``hidden_size`` dimension.
        """
        seq_len = input_ids.size(1)
        positions = self.position_ids[:, :seq_len]
        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))
class VisionTransformerMage(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer that runs on VQGAN token ids instead of image patches.

    Input images are tokenized by a frozen VQGAN. A fixed class-token id is
    prepended, the ids are embedded with ``BertEmbeddings`` and passed through
    the inherited transformer ``blocks``. The returned feature is either:

    * the mean of the image tokens, normalized by ``fc_norm``
      (``global_pool=True``), or
    * the first (class) token after ``norm``.

    Args:
        global_pool: average the image tokens instead of taking the class token.
        vqgan_ckpt_path: checkpoint loaded into the VQGAN tokenizer.
        vqgan_config_path: OmegaConf YAML whose ``model`` section configures
            the VQGAN. The default is the path that was previously hard-coded.
        **kwargs: forwarded unchanged to timm's ``VisionTransformer``.
    """

    def __init__(self, global_pool=False, vqgan_ckpt_path='vqgan_jax.ckpt',
                 vqgan_config_path='config/vqgan.yaml', **kwargs):
        super().__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            # Use the parent's stored width and fall back to timm's default norm
            # layer, so callers that rely on timm defaults no longer hit KeyError.
            norm_layer = kwargs.get('norm_layer') or partial(nn.LayerNorm, eps=1e-6)
            self.fc_norm = norm_layer(self.embed_dim)
            del self.norm  # pooled features are normalized by fc_norm instead

        # --------------------------------------------------------------------------
        # VQGAN tokenizer (frozen)
        config = OmegaConf.load(vqgan_config_path).model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,
                             embed_dim=config.params.embed_dim,
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False
        # NOTE(review): self.vqgan is a registered submodule, so model.train()
        # also puts it in train mode; confirm its config uses no dropout.

        codebook_size = config.params.n_embed
        # Token-id layout: codebook ids, then 1000 class ids, then one mask id
        # (the mask id is the last one, see mask_token_label below).
        vocab_size = codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        # Evaluates to codebook_size + 76, a fixed id inside the class-id range.
        # Presumably it matches the value used at pretraining time so the
        # pretrained embeddings line up; do not change without checking checkpoints.
        self.fake_class_label = codebook_size + 1100 - 1024
        # Not used by forward_features; kept as a public attribute.
        self.mask_token_label = vocab_size - 1

        # 256 + 1 positions: presumably a 16x16 VQGAN token grid plus the
        # class token. TODO confirm against the VQGAN config / input resolution.
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=self.embed_dim,
                                        max_position_embeddings=256 + 1,
                                        dropout=0.1)

    def forward_features(self, x):
        """Return one pooled feature vector per image in ``x``.

        The VQGAN's third output is a tuple whose third element holds the token
        ids. They are flattened to (batch, num_tokens) before embedding.
        """
        # Tokenization: no autograd graph is built for the frozen VQGAN.
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)
        _, _, token_indices = token_tuple
        token_indices = token_indices.reshape(z_q.size(0), -1).long()

        # Prepend the fixed class-token id. The column is created on the same
        # device and dtype as the image tokens; the old `.cuda(...)` call
        # raised for CPU tensors and needed a float round-trip.
        class_ids = torch.full((token_indices.size(0), 1), self.fake_class_label,
                               dtype=token_indices.dtype, device=token_indices.device)
        token_indices = torch.cat([class_ids, token_indices], dim=1)

        # BERT-style embedding followed by the inherited transformer blocks.
        x = self.token_emb(token_indices)
        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Average only the image tokens; index 0 is the class token.
            outcome = self.fc_norm(x[:, 1:, :].mean(dim=1))
        else:
            outcome = self.norm(x)[:, 0]
        return outcome
def vit_base_patch16(**kwargs):
    """Build the Base-size MAGE backbone (width 768, 12 blocks, 12 heads, patch 16).

    Any extra keyword arguments are forwarded to ``VisionTransformerMage``.
    """
    return VisionTransformerMage(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def vit_large_patch16(**kwargs):
    """Build the Large-size MAGE backbone (width 1024, 24 blocks, 16 heads, patch 16).

    Any extra keyword arguments are forwarded to ``VisionTransformerMage``.
    """
    return VisionTransformerMage(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )