Combining CLIPA-v2 and SigLIP (both big_vision based) models #660

Merged Oct 20, 2023 · 24 commits

Commits
64f4644
merge changes for clipa inference
zw615 Jul 26, 2023
546e8ae
update get_tokenizer to pass CI test; replace gelu_appoximate with ac…
zw615 Oct 5, 2023
8450c95
Merge remote-tracking branch 'zwclipa/clipa_torch_inference' into sig…
rwightman Oct 6, 2023
e3c2ea2
Temporary, cannot have a force tf dependency
rwightman Oct 6, 2023
0316911
Supporting SigLIP and CLIPA-v2 models (both sourced from big_vision j…
rwightman Oct 11, 2023
9d8385e
Merge remote-tracking branch 'origin/main' into siglip_clipa_models
rwightman Oct 11, 2023
39ba303
Fix some test failures, remove old v1 CLIPA configs, add 336 H14 …
rwightman Oct 11, 2023
0724aab
Fix torchscript
rwightman Oct 11, 2023
f04eee8
Fix CoCa expand typo, force final LN after attentional pool
rwightman Oct 12, 2023
2f568cd
Merge branch 'main' into siglip_clipa_models
rwightman Oct 12, 2023
2c396d2
Used wrong default clean fn in SimpleTokenizer, put lower case back
rwightman Oct 12, 2023
e14f34b
Attempt to fix xlm roberta test w/ pretrained hf weight difference
rwightman Oct 12, 2023
3637f9d
SigLIP weights working. More changes to support differing image prepr…
rwightman Oct 17, 2023
72196f1
A typo and unused import
rwightman Oct 17, 2023
948d9e1
Merge remote-tracking branch 'origin/main' into siglip_clipa_models
rwightman Oct 17, 2023
c29cc9c
Fix two small issues, add hf_tokenizer_name to SigLIP models for non …
rwightman Oct 17, 2023
72b75bd
CLIPA reference temporary rwightman/ models for testing
rwightman Oct 17, 2023
b086ddb
Rename profile->profiler to avoid python naming conflict
rwightman Oct 18, 2023
05e9864
More tokenizer rework, add context_len as class attr set in factory, …
rwightman Oct 18, 2023
07f2c16
fix ViT-SO400M-14-SigLIP name
gabrielilharco Oct 19, 2023
d7542e4
Fix CoCa pool LN, improve clarity of ViT pooling logic
rwightman Oct 19, 2023
85f19b8
Exclude first/last tokens from tokens output of text models, should m…
rwightman Oct 19, 2023
a9d8d58
Add eval results for CLIPA + SigLIP models
gabrielilharco Oct 20, 2023
95ae868
Fixup bigG CLIPA config, 83.03 top-1 IN-1k
rwightman Oct 20, 2023
6 changes: 6 additions & 0 deletions scripts/clipav1_vit_l16_i37_t8.sh
@@ -0,0 +1,6 @@
# eval on a single gpu
CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \
--model ViT-L-16-CL32-GAP \
--pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
10 changes: 10 additions & 0 deletions scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh
@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=1 python3 -m training.main \
--model ViT-H-14-CL32-GAP-BigVision \
--pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \
--force-image-size 336 \
--square-resize-only \
--interpolation 'bilinear' \
--image-mean 0.485 0.456 0.406 \
--image-std 0.229 0.224 0.225 \
--seed 0 \
--imagenet-val '/path/to/ImageNet/val'
3 changes: 2 additions & 1 deletion src/open_clip/__init__.py
@@ -4,7 +4,8 @@
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
get_model_tokenize_cfg, get_model_context_len, get_model_preprocess_cfg, set_model_preprocess_cfg
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
136 changes: 136 additions & 0 deletions src/open_clip/big_vision.py
@@ -0,0 +1,136 @@
import torch
import numpy as np

from .model import CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
[Review discussion on this function]

rom1504 (Collaborator): How stable do we expect this to be? Should we somehow (at least a comment) lock to a specific commit of big_vision?

rwightman (Collaborator, Author): @rom1504 It works for big_vision SigLIP weights -> timm + builtin text models only. I don't see why it wouldn't be stable; it has nothing to do with the code revision, only the weight files are relevant. This wouldn't work for their LiT models, but could be extended, etc. I could drop this code after I convert and push to the HF hub, but it could be a useful reference or for future models. It's isolated to checkpoint loading with an npz/npy filename and has no extra deps.

rom1504 (Collaborator): OK, so you're saying this will work for some specific big_vision-trained checkpoints. Then maybe we could indicate those above the function?

rwightman (Collaborator, Author): Sure, will add a comment that it only supports the big_vision SigLIP weights right now; any other big_vision weights (present or future) could be added by expanding the support, checking for various keys in the numpy archive...

""" Load weights from .npz checkpoints for official Google big_vision image-text models

Currently the SigLIP source models are supported and a CustomTextCLIP destination model
w/ timm image encoder.
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def _n2p(w, t=True):
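# big_vision/Flax convention: conv kernels are stored HWIO and linear kernels
# as (in, out); flatten degenerate 1x1x1 kernels, then transpose to PyTorch's
# OIHW / (out, in) layouts before wrapping in a tensor.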
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)

w = np.load(checkpoint_path)
interpolation = 'bilinear'
antialias = False

def _convert_timm_img(module, prefix):
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
module.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.patch_embed.proj.weight.copy_(embed_conv_w)
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

if module.cls_token is not None:
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
if pos_embed_w.shape != module.pos_embed.shape:
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
new_size=module.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.pos_embed.copy_(pos_embed_w)

mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
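# big_vision stores separate q/k/v kernels of shape (dim, heads, head_dim);
# flatten(1).T yields (dim, dim) each, concatenated into the fused qkv projection.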
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

if module.attn_pool is not None:
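# big_vision's MAP head ("probe") maps onto timm's attn_pool: a learned
# latent query attends over the encoder output sequence for pooling.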
block_prefix = f'{prefix}MAPHead_0/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
module.attn_pool.kv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
module.attn_pool.kv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
for r in range(2):
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

def _convert_openclip_transformer(module: Transformer, prefix):
for i, block in enumerate(module.resblocks.children()):
block_prefix = f'{prefix}encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.in_proj_weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.in_proj_bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

def _convert_openclip_txt(module: TextTransformer, prefix):
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
module.positional_embedding.copy_(pos_embed_w)
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
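# SigLIP's sigmoid loss uses a learnable bias b and temperature t; both are
# stored as shape-(1,) arrays, so [0] extracts the scalar.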
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])
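
A minimal usage sketch for the loader above, assuming one of the SigLIP model configs added in this PR; the checkpoint path is a placeholder:

import open_clip
from open_clip.big_vision import load_big_vision_weights

# build the destination CustomTextCLIP: timm image tower + builtin text tower
model = open_clip.create_model('ViT-B-16-SigLIP')
# copy the official big_vision SigLIP weights into it
load_big_vision_weights(model, '/path/to/siglip_b16.npz')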


4 changes: 4 additions & 0 deletions src/open_clip/constants.py
@@ -1,2 +1,6 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
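
A quick cross-check (a sketch, not part of the diff): the new ImageNet constants match the --image-mean / --image-std flags passed in the CLIPA-v2 eval script above.

from open_clip.constants import IMAGENET_MEAN, IMAGENET_STD

assert IMAGENET_MEAN == (0.485, 0.456, 0.406)
assert IMAGENET_STD == (0.229, 0.224, 0.225)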