Combining CLIPA-v2 and SigLIP (both big_vision based) models #660
Merged
Commits (24 total; the diff below shows changes from 13 commits)
64f4644  merge changes for clipa inference (zw615)
546e8ae  update get_tokenizer to pass CI test; replace gelu_appoximate with ac… (zw615)
8450c95  Merge remote-tracking branch 'zwclipa/clipa_torch_inference' into sig… (rwightman)
e3c2ea2  Temporary, cannot have a force tf dependency (rwightman)
0316911  Supporting SigLIP and CLIPA-v2 models (both sourced from big_vision j… (rwightman)
9d8385e  Merge remote-tracking branch 'origin/main' into siglip_clipa_models (rwightman)
39ba303  Fix some test failures, remove old v1 CLIPA configs, add add 336 H14 … (rwightman)
0724aab  Fix torchscript (rwightman)
f04eee8  Fix CoCa expand typo, force final LN after attentional pool (rwightman)
2f568cd  Merge branch 'main' into siglip_clipa_models (rwightman)
2c396d2  Used wrong default clean fn in SimpleTokenizer, put lower case back (rwightman)
e14f34b  Attempt to fix xlm roberta test w/ pretrained hf weight difference (rwightman)
3637f9d  SigLIP weights working. More changes to support differing image prepr… (rwightman)
72196f1  A typo and unused import (rwightman)
948d9e1  Merge remote-tracking branch 'origin/main' into siglip_clipa_models (rwightman)
c29cc9c  Fix two small issues, add hf_tokenizer_name to SigLIP models for non … (rwightman)
72b75bd  CLIPA reference temppory rwightman/ models for testing (rwightman)
b086ddb  Rename profile->profiler to avoid python naming conflict (rwightman)
05e9864  More tokenizer rework, add context_len as class attr set in factory, … (rwightman)
07f2c16  fix ViT-SO400M-14-SigLIP name (gabrielilharco)
d7542e4  Fix CoCa pool LN, improve clarity of ViT pooling logic (rwightman)
85f19b8  Exclude first/last tokens from tokens output of text models, should m… (rwightman)
a9d8d58  Add eval results for CLIPA + SigLIP models (gabrielilharco)
95ae868  Fixup bigG CLIPA config, 83.03 top-1 IN-1k (rwightman)
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
    --model ViT-L-16-CL32-GAP \
    --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
    --model ViT-H-14-CL32-GAP-BigVision \
    --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
    --force-image-size 336 \
    --square-resize-only \
    --interpolation 'bilinear' \
    --image-mean 0.485 0.456 0.406 \
    --image-std 0.229 0.224 0.225 \
    --seed 0 \
    --imagenet-val '/path/to/ImageNet/val'
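
For reference, the same evaluation setup can be approximated programmatically. This is a hedged sketch, not part of the PR; the mapping of the CLI flags above onto create_model_and_transforms arguments is an assumption to verify against this branch, and --square-resize-only / --interpolation are omitted here.

import open_clip

# Sketch only: mirror the key settings of the eval script above (model config,
# checkpoint, 336px eval resolution, ImageNet mean/std). Paths are placeholders.
model, _, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-H-14-CL32-GAP-BigVision',
    pretrained='/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt',
    force_image_size=336,
    image_mean=(0.485, 0.456, 0.406),
    image_std=(0.229, 0.224, 0.225),
)
model.eval()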
@@ -0,0 +1,136 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    _convert_timm_img(model.visual.trunk, 'params/img/')
    _convert_openclip_txt(model.text, 'params/txt/')
    model.logit_bias.copy_(_n2p(w['params/b'])[0])
    model.logit_scale.copy_(_n2p(w['params/t'])[0])
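
For orientation, a minimal usage sketch of the converter above. It is not part of the diff: the config name and checkpoint path are placeholders, and in the PR the function is reached through normal checkpoint loading when an npz/npy filename is given rather than being called directly.

import open_clip

# Illustrative only: build a SigLIP-style CustomTextCLIP (timm image tower +
# built-in text tower), then copy big_vision .npz weights into it.
# 'ViT-B-16-SigLIP' is one of the configs added in this PR; the path is a placeholder.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16-SigLIP')
load_big_vision_weights(model, '/path/to/big_vision_siglip_checkpoint.npz')
model.eval()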
@@ -1,2 +1,6 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
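
The new IMAGENET_MEAN/STD match the --image-mean/--image-std values passed in the CLIPA-v2 eval script above; a quick check, assuming the constants module is importable as open_clip.constants:

# Illustration: the constants added above equal the normalization used by the eval script.
from open_clip.constants import IMAGENET_MEAN, IMAGENET_STD

assert IMAGENET_MEAN == (0.485, 0.456, 0.406)
assert IMAGENET_STD == (0.229, 0.224, 0.225)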
rom1504: How stable do we expect this to be? Should we somehow (at least in a comment) lock to a specific commit of big_vision?
rwightman: @rom1504 It works for big_vision SigLIP weights -> timm + built-in text models only. I don't see why it wouldn't be stable; it has nothing to do with the code revision, only the weight files are relevant. This wouldn't work for their LiT models, but could be extended, etc.
I could drop this code after I convert and push to the HF hub, but it could be useful as a reference or for future models. It's isolated to checkpoint loading with an npz/npy filename and has no extra deps.
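
A minimal sketch of the kind of filename-based dispatch described here; the function and module names are assumptions for illustration, not the PR's actual loading code:

from pathlib import Path

import torch


def load_checkpoint(model, checkpoint_path: str, strict: bool = True):
    # Hypothetical dispatch: route .npz/.npy checkpoints to the big_vision converter,
    # everything else through the usual torch state_dict path.
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        from open_clip.big_vision import load_big_vision_weights  # module name assumed
        load_big_vision_weights(model, checkpoint_path)
        return {}
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    return model.load_state_dict(state_dict, strict=strict)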
rom1504: OK, you're saying this will work for some specific big_vision-trained checkpoints.
Then maybe we could indicate those above the function?
rwightman: Sure, I'll add a comment that it only supports the big_vision SigLIP weights right now; any other big_vision weights (present or future) could be added by expanding the support to check for various keys in the numpy archive...
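
For illustration, that key-based check could look roughly like this; the 'params/img/' and 'params/txt/' prefixes come from the SigLIP loader above, while any other branch is purely hypothetical:

import numpy as np


def detect_big_vision_variant(checkpoint_path: str) -> str:
    # Guess the big_vision family from the top-level keys of the numpy archive.
    # Only the SigLIP layout is handled by load_big_vision_weights right now;
    # other layouts (e.g. LiT) would need their own key checks and converters.
    w = np.load(checkpoint_path)
    keys = set(w.keys())
    if any(k.startswith('params/img/') for k in keys) and any(k.startswith('params/txt/') for k in keys):
        return 'siglip'
    raise ValueError(f'unrecognized big_vision checkpoint layout, sample keys: {sorted(keys)[:5]}')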