# modeling_siglip.py
from typing import Tuple, Optional
import torch
import torch.nn as nn
# Architecture of SigLIP
# (VisionEmbeddings, Encoder) -> VisionTransformer -> VisionModel
# (Attention, LayerNorm, MLP) -> EncoderLayer -> Encoder
class SiglipVisionConfig:
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 num_channels: int = 3,
                 hidden_size: int = 768,        # embed_dim of each image patch (text tokens are projected separately with a Linear layer)
                 intermediate_size: int = 3072, # hidden dim the FFN projects to
                 num_hidden_layers: int = 12,   # Nx: number of encoder blocks
                 num_attention_heads: int = 12, # number of heads in multi-head attention
                 attention_dropout: float = 0.0,
                 layer_norm_eps: float = 1e-6,
                 num_image_tokens: Optional[int] = None,
                 **kwargs
                 ):
        '''
        PaliGemma comes in different configurations;
        the values above are the defaults of the vision encoder we use here.
        '''
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.num_image_tokens = num_image_tokens
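# A quick worked example (an added illustration, not part of the original file) of what the
# default config implies for the patch grid: with image_size=224 and patch_size=16 we get a
# 14x14 grid, i.e. 196 image tokens, each embedded in hidden_size=768 dimensions.
def _example_config_shapes():
    cfg = SiglipVisionConfig()
    patches_per_side = cfg.image_size // cfg.patch_size   # 224 // 16 = 14
    num_patches = patches_per_side ** 2                   # 14 ** 2 = 196
    assert (patches_per_side, num_patches, cfg.hidden_size) == (14, 196, 768)
    return num_patches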
class SiglipAttention(nn.Module):
def __init__(self, config: SiglipVisionConfig):
        '''
        Pre-norm encoder block: (LayerNorm) -> Attention -> (residual add) -> (LayerNorm) -> (MLP) -> (residual add)
        Idea:
        - Project the input x into queries xq, keys xk and values xv
          via the corresponding learned weight matrices
        - Split them into heads, attend, then merge the heads back together
        '''
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = 1 / (self.head_dim ** 0.5)
self.dropout = config.attention_dropout
self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(self, x):
batch_size, num_patches, _ = x.shape
        # Contextualize the embeddings, i.e.
        # each patch interacts with the learned key/query/value projection matrices
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
xk = self.key_proj(x)
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
xq = self.query_proj(x)
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
xv = self.value_proj(x)
        # split into heads, so that each head focuses on a specific slice of embed_dim
# (B, num_patches, embed_dim) -> (B, num_patches, num_heads, head_dim)
xk = xk.reshape(batch_size, num_patches, self.num_heads, self.head_dim)
xq = xq.reshape(batch_size, num_patches, self.num_heads, self.head_dim)
xv = xv.reshape(batch_size, num_patches, self.num_heads, self.head_dim)
        # Since the heads can be computed in parallel, we transpose so that,
        # for a particular head, the slices of all patches belonging to that head
        # can be processed independently of the other heads
# (B, num_patches, num_heads, head_dim) -> (B, num_heads, num_patches, head_dim)
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)
queries = xq.transpose(1, 2)
        # How strongly are the i-th query and the j-th key related?
        # The higher the score, the more closely they are linked.
        # In other words, we are contextualising the embeddings
# (B, num_heads, num_patches, head_dim) x (B, num_heads, head_dim, num_patches)
# => (B, num_heads, num_patches, num_patches)
attn_weights = torch.matmul(queries, keys.transpose(2, 3) )
# scaling down by sqrt(head_dim)
# (B, num_heads, num_patches, num_patches) -> (B, num_heads, num_patches, num_patches)
attn_weights = attn_weights * self.scale
if attn_weights.size() != (batch_size, self.num_heads, num_patches, num_patches):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, num_patches, num_patches)}, but is"
f" {attn_weights.size()}"
)
        # Now, we apply softmax.
        # In a causal (language) Transformer, after computing the attention scores
        # we mask the future tokens by setting their scores to -inf. Why?
        # Because softmax uses e**x, and e**(-inf) = 0,
        # so the influence of future tokens becomes 0 (no contribution).
        # Not so in the Vision Transformer here. Why?
        # An image isn't sequential; the argument is that
        # "any patch of an image may contain info about any other patch",
        # so every patch takes part in the weighted sum with the values.
        # (B, num_heads, num_patches, num_patches)
        attn_weights = torch.softmax(attn_weights, dim = -1, dtype = torch.float32).to(queries.dtype)
        # Dropout only has an effect during training -> for inference we pass dropout = 0.0
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Interact with the values: each row of attention weights tells how much
        # weight/importance a contextualised embedding gives to every value
        # when building its output representation
        # (B, num_heads, num_patches, num_patches) x (B, num_heads, num_patches, head_dim)
        # -> (B, num_heads, num_patches, head_dim)
        attn_outputs = torch.matmul(attn_weights, values)
if attn_outputs.size() != (batch_size, self.num_heads, num_patches, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, num_patches, self.head_dim)}, but is"
f" {attn_outputs.size()}"
)
        # concat all the heads
        # (B, num_heads, num_patches, head_dim)
        # rearrange so that, for a given patch, all of its heads sit next to each other and can be merged back into embed_dim
        # (B, num_heads, num_patches, head_dim) -> (B, num_patches, num_heads, head_dim)
attn_outputs = attn_outputs.transpose(1, 2)
# concat
# (B, num_patches, num_heads, head_dim) -> (B, num_patches, num_heads * head_dim)
# (B, num_patches, num_heads * head_dim) === (B, num_patches, embed_dim)
attn_outputs = attn_outputs.reshape(batch_size, num_patches, self.num_heads * self.head_dim)
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
attn_outputs = self.out_proj(attn_outputs)
return attn_outputs, attn_weights
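# Hedged sanity check (an added illustration, not part of the model): the forward pass above
# is plain scaled dot-product attention with no mask, so on the same q/k/v tensors it should
# agree with PyTorch's built-in F.scaled_dot_product_attention (available in PyTorch >= 2.0).
# The tensor shapes assume the default config (12 heads, head_dim 64, 196 patches).
def _check_attention_against_sdpa():
    torch.manual_seed(0)
    B, H, N, D = 2, 12, 196, 64                          # batch, heads, patches, head_dim
    q, k, v = (torch.rand(B, H, N, D) for _ in range(3))
    # manual path, mirroring SiglipAttention.forward (scale, softmax, weighted sum)
    scores = torch.matmul(q, k.transpose(2, 3)) / (D ** 0.5)
    manual = torch.matmul(torch.softmax(scores, dim=-1), v)
    # fused path: non-causal, no mask, no dropout -- the same setting as this vision encoder
    fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(manual, fused, atol=1e-5)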
class SiglipMLP(nn.Module):
"""
Idea behind this layer:-
- Increases degress of freedom of the model
- i.e more params to learn -> more flexible
- Has non-linear `gelu` activation function
- Therefore, model can be more flexible to learn non-linear transformations
making it more powerful, than to just learn linear-transformations [argument]
- Also, it generally is used, between 2 main blocks of an architecture
- To reshape tensors before it enters another block,
- The projections in MLP make it compatible for it to be fed
to the next block
"""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.intermediate_size = config.intermediate_size
self.embed_dim = config.hidden_size
self.fc1 = nn.Linear(self.embed_dim, self.intermediate_size)
self.fc2 = nn.Linear(self.intermediate_size, self.embed_dim)
def forward(self, x):
# (B, num_patches, embed_dim)
x = self.fc1(x)
x = nn.functional.gelu( x, approximate = "tanh" )
x = self.fc2(x)
return x
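# Small illustration (added, not part of the model) of the MLP's expand-then-contract pattern
# under the default config: embed_dim 768 is projected up to intermediate_size 3072, passed
# through tanh-approximated GELU, and projected back down to 768, so the
# (B, num_patches, embed_dim) shape is preserved.
def _example_mlp_shapes():
    cfg = SiglipVisionConfig()
    mlp = SiglipMLP(cfg)
    x = torch.rand(2, 196, cfg.hidden_size)              # (B, num_patches, embed_dim)
    hidden = mlp.fc1(x)                                   # (2, 196, 3072)
    out = mlp(x)                                          # (2, 196, 768)
    assert hidden.shape == (2, 196, cfg.intermediate_size)
    assert out.shape == x.shape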
class SiglipEncoderLayer(nn.Module):
'''
Recall:
(Attention, LayerNorm, MLP) -> SiglipEncoderLayer
'''
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.layer_norm_eps = config.layer_norm_eps
self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(normalized_shape = self.embed_dim,
                                        eps = self.layer_norm_eps) # normalise across the feature dimension to mitigate the covariate-shift problem
self.self_attn = SiglipAttention(config)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(normalized_shape = self.embed_dim,
eps = self.layer_norm_eps)
def forward(self, x):
residual = x
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
x = self.layer_norm1(x)
        # (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
        x, _ = self.self_attn(x)
        # (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
        x = residual + x
        residual = x
        # (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
        x = self.layer_norm2(x)
        # (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
        x = self.mlp(x)
return residual + x
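# Minimal check (added illustration only) that one encoder layer is shape-preserving, which
# is what lets SiglipEncoder below stack num_hidden_layers of them back to back. The layer
# follows the pre-norm pattern: LayerNorm is applied before attention and before the MLP,
# with a residual connection added around each sub-block.
def _example_encoder_layer_shapes():
    cfg = SiglipVisionConfig()
    layer = SiglipEncoderLayer(cfg)
    x = torch.rand(2, 196, cfg.hidden_size)              # (B, num_patches, embed_dim)
    out = layer(x)
    assert out.shape == x.shape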
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.num_hidden_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.layers = nn.ModuleList()
for _ in range(self.num_hidden_layers):
self.layers.append(SiglipEncoderLayer(config))
def forward(self, x):
# "x" comes after final_emb i.e concat( img_embs, pos_embs )
# (B, num_patches, hidden_size or embed_dim ) -> (B, num_patches, embed_dim) ... N times
for layer in self.layers:
x = layer(x)
return x
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
        '''
        Receives a batch of images as a float tensor of pixel values.
        Each image in the batch gets partitioned into patches,
        and each patch is then projected to an `embed_dim`-dimensional vector.
        '''
super().__init__()
self.patch_size = config.patch_size
self.image_size = config.image_size
self.num_patches = ( self.image_size // self.patch_size ) ** 2
self.num_channels = config.num_channels # 3 RGB
self.embed_dim = config.hidden_size
        # Partition each image into num_patches such that
        # each patch has shape (patch_size, patch_size).
        # With kernel_size == stride == patch_size, the convolution slides over
        # non-overlapping patches along H and W and projects each one to embed_dim.
        self.patch_embedding = nn.Conv2d( in_channels = self.num_channels,
                                          out_channels = self.embed_dim,
                                          stride = self.patch_size,
                                          kernel_size = self.patch_size,
                                          padding = "valid", # no padding applied
                                          )
        # Positional encodings (learned, unlike the vanilla Transformer)
        # Recall: in the vanilla Transformer, the positional encodings were deterministic sinusoids
        self.positional_embeddings = nn.Embedding( self.num_patches, self.embed_dim )
        # For book-keeping we keep the position_ids in a buffer:
        # they are not learned, they are trivial to reconstruct, and with
        # persistent=False they are not stored in the state_dict.
        # register_buffer (from nn.Module) also makes them follow the module's device/dtype.
        self.register_buffer( name = "position_ids",
                              tensor = torch.arange(self.num_patches).expand((1, -1)), # one position id per patch, plus a batch dim -> (1, num_patches)
                              persistent = False )
def forward(self, x: torch.FloatTensor):
        # x is a float tensor of pixel values
# (B, C, H, W) -> (B, embed_dim, num_patches_H, num_patches_W)
# where, num_patches_H = H // patch_size
# and, num_patches_W = W // patch_size
patch_embedding = self.patch_embedding(x)
        # Flatten the spatial grid of patches (as in ViT) so the shape can line up with the positional encodings
        # (B, embed_dim, num_patches_H, num_patches_W) -> (B, embed_dim, num_patches_H * num_patches_W)
image_embeddings = patch_embedding.flatten( start_dim = 2 )
        # (1, num_patches) -> (1, num_patches, embed_dim); broadcasts over the batch in the addition below
        positional_encodings = self.positional_embeddings( self.position_ids )
        # Re-order the image embeddings so they correctly associate with the positional encodings
# (B, embed_dim, num_patches_H * num_patches_W) -> (B, num_patches_H * num_patches_W, embed_dim)
image_embeddings = image_embeddings.transpose(1, 2)
final_embeddings = image_embeddings + positional_encodings
return final_embeddings
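# Added illustration (not part of the model): the Conv2d above with kernel_size == stride ==
# patch_size is equivalent to cutting the image into non-overlapping 16x16 patches, flattening
# each patch to a 3*16*16 = 768-dim vector, and applying a single Linear layer. The helper
# below checks that equivalence on random pixels by reusing the convolution's own weights.
def _check_conv_patchify_equals_linear():
    cfg = SiglipVisionConfig()
    emb = SiglipVisionEmbeddings(cfg)
    x = torch.rand(1, cfg.num_channels, cfg.image_size, cfg.image_size)
    # conv path: (1, 3, 224, 224) -> (1, 768, 14, 14) -> (1, 196, 768)
    conv_out = emb.patch_embedding(x).flatten(start_dim=2).transpose(1, 2)
    # unfold path: extract (3*16*16)-dim patches, then project them with the same weights
    patches = nn.functional.unfold(x, kernel_size=cfg.patch_size, stride=cfg.patch_size)  # (1, 768, 196)
    patches = patches.transpose(1, 2)                                                     # (1, 196, 768)
    weight = emb.patch_embedding.weight.reshape(cfg.hidden_size, -1)                      # (768, 3*16*16)
    linear_out = patches @ weight.T + emb.patch_embedding.bias
    assert torch.allclose(conv_out, linear_out, atol=1e-5)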
class SiglipTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
# (VisionEmbeddings, Encoder)
self.config = config
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, x):
        # x is a float tensor of pixel values
# (B, C, H, W) -> (B, num_patches, embed_dim)
x = self.embeddings(x)
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
x = self.encoder(x)
# (B, num_patches, embed_dim) -> (B, num_patches, embed_dim)
x = self.post_layernorm(x)
return x
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
# sets the config and model
self.config = config
self.model = SiglipTransformer(self.config)
def forward(self, x):
        # x is a float tensor of pixel values
# (Batch, C, H, W)
return self.model(x)
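# Hedged sketch (added, not taken from this file) of how the vision features are typically
# consumed downstream in PaliGemma, as hinted at in the hidden_size comment of
# SiglipVisionConfig: a linear projection maps each of the image tokens from hidden_size
# (768) into the language model's embedding space. The text_embed_dim of 2048 below is an
# illustrative assumption, not a value defined anywhere in this module.
class _MultiModalProjectorSketch(nn.Module):
    def __init__(self, vision_hidden_size: int = 768, text_embed_dim: int = 2048):
        super().__init__()
        self.linear = nn.Linear(vision_hidden_size, text_embed_dim)
    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # (B, num_patches, vision_hidden_size) -> (B, num_patches, text_embed_dim)
        return self.linear(image_features)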
if __name__ == '__main__':
batch_size = 2
num_channels = 3
height, width = 224, 224
x = torch.rand(batch_size, num_channels, height, width)
    config = SiglipVisionConfig( num_channels=3,
image_size=224,
patch_size=16,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
attention_dropout=0.0,
layer_norm_eps=1e-6,
num_image_tokens=None
)
# print( config.num_channels )
    model = SiglipVisionModel(config)
    out = model(x)            # call the module (rather than model.forward) so nn.Module hooks run
    print( out.shape )        # torch.Size([2, 196, 768])
# num_patches = (224//16)**2 = 14**2 = 196
# verified!