
Commit

Add model defs & weights for new so400m i18n variant. Add a 378x378 config for the original 384x384 so400m because the patch size doesn't divide 384 properly.
rwightman committed Oct 9, 2024
1 parent fc5a37b commit 7153ee0
Showing 8 changed files with 108 additions and 25 deletions.
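
For context on the 378x378 config: a patch-14 ViT needs an input size that tiles evenly into 14-pixel patches; 384 does not (384 = 14 * 27 + 6), while 378 = 14 * 27 gives a clean 27x27 grid. A quick arithmetic check:

patch_size = 14
for image_size in (384, 378):
    grid, remainder = divmod(image_size, patch_size)
    print(f"{image_size}px -> {grid}x{grid} patches, remainder {remainder}")
# 384px -> 27x27 patches, remainder 6  (doesn't divide evenly)
# 378px -> 27x27 patches, remainder 0  (divides cleanly)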
2 changes: 1 addition & 1 deletion requirements-training.txt
@@ -8,5 +8,5 @@ pandas
 braceexpand
 huggingface_hub
 transformers[sentencepiece]
-timm>=1.0.7
+timm>=1.0.10
 fsspec
48 changes: 29 additions & 19 deletions src/open_clip/convert.py
Expand Up @@ -18,7 +18,9 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def _n2p(w, t=True):
def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
@@ -66,21 +68,28 @@ def _convert_timm_img(module, prefix):
 
         mha_sub, b_sub, ln1_sub = (0, 0, 1)
         for i, block in enumerate(module.blocks.children()):
-            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+                block_prefix = f'{prefix}Transformer/encoderblock/'
+                idx = i
+            else:
+                block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+                idx = None
             mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
             block.attn.qkv.weight.copy_(torch.cat([
-                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
             block.attn.qkv.bias.copy_(torch.cat([
-                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
             for r in range(2):
-                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
-            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+                getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+                getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
         module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
         module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
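
On the idx plumbing above: some big_vision checkpoints (presumably trained with a scan over layers) store all transformer blocks stacked under a single encoderblock/ entry with a leading depth axis instead of separate encoderblock_{i}/ keys, so _n2p(..., idx=i) slices out block i before the usual transpose. A minimal sketch of the same dispatch, using a hypothetical block_param helper and illustrative shapes:

import numpy as np

depth, width = 27, 1152  # illustrative sizes only

# stacked ("scan over layers") layout: one array with a leading depth axis
stacked = {'Transformer/encoderblock/LayerNorm_0/scale': np.ones((depth, width))}
# classic layout: one entry per block
per_block = {f'Transformer/encoderblock_{i}/LayerNorm_0/scale': np.ones(width) for i in range(depth)}

def block_param(w, name, i):
    # hypothetical helper mirroring the block_prefix / idx logic in the diff above
    if f'Transformer/encoderblock/{name}' in w:
        return w[f'Transformer/encoderblock/{name}'][i]  # slice block i off the leading axis
    return w[f'Transformer/encoderblock_{i}/{name}']

print(block_param(stacked, 'LayerNorm_0/scale', 3).shape)    # (1152,)
print(block_param(per_block, 'LayerNorm_0/scale', 3).shape)  # (1152,)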
@@ -129,13 +138,14 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
         _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
         module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
         module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
-        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
-        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
-
-    _convert_timm_img(model.visual.trunk, 'params/img/')
-    _convert_openclip_txt(model.text, 'params/txt/')
-    model.logit_bias.copy_(_n2p(w['params/b'])[0])
-    model.logit_scale.copy_(_n2p(w['params/t'])[0])
+        if module.text_projection is not None:
+            module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+            module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+
+    _convert_timm_img(model.visual.trunk, 'img/')
+    _convert_openclip_txt(model.text, 'txt/')
+    model.logit_bias.copy_(_n2p(w['b'])[0])
+    model.logit_scale.copy_(_n2p(w['t'])[0])
 
 
 @torch.no_grad()
2 changes: 2 additions & 0 deletions src/open_clip/model.py
@@ -72,6 +72,7 @@ class CLIPTextCfg:
     final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
     pool_type: str = 'argmax'
     proj_bias: bool = False
+    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
     output_tokens: bool = False
     act_kwargs: dict = None
     norm_kwargs: dict = None
@@ -209,6 +210,7 @@ def _build_text_tower(
         no_causal_mask=text_cfg.no_causal_mask,
         pad_id=text_cfg.pad_id,
         pool_type=text_cfg.pool_type,
+        proj_type=text_cfg.proj_type,
         proj_bias=text_cfg.proj_bias,
         output_tokens=text_cfg.output_tokens,
         act_layer=act_layer,
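
The new proj_type field threads from a config's text_cfg through CLIPTextCfg into the text tower, so a model config can now drop the final text projection entirely. A small sketch of the dataclass side (values mirror the i18n config added below; the default 'linear' keeps the previous behaviour):

from open_clip.model import CLIPTextCfg

text_cfg = CLIPTextCfg(
    context_length=64,
    vocab_size=250000,
    width=1152,
    heads=16,
    layers=27,
    no_causal_mask=True,
    pool_type='last',
    proj_type='none',  # forces text_projection = None in the text tower
)
print(text_cfg.proj_type)  # 'none'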
30 changes: 30 additions & 0 deletions src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json
@@ -0,0 +1,30 @@
+{
+    "embed_dim": 1152,
+    "init_logit_bias": -10,
+    "custom_text": true,
+    "vision_cfg": {
+        "image_size": 378,
+        "timm_model_name": "vit_so400m_patch14_siglip_378",
+        "timm_model_pretrained": false,
+        "timm_pool": "map",
+        "timm_proj": "none"
+    },
+    "text_cfg": {
+        "context_length": 64,
+        "vocab_size": 32000,
+        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
+        "tokenizer_kwargs": {
+            "clean": "canonicalize"
+        },
+        "width": 1152,
+        "heads": 16,
+        "layers": 27,
+        "mlp_ratio": 3.7362,
+        "no_causal_mask": true,
+        "proj_bias": true,
+        "pool_type": "last",
+        "norm_kwargs":{
+            "eps": 1e-6
+        }
+    }
+}
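
With this config plus the pretrained entry added further down, the 378 variant should load like any other name/tag pair. A hedged usage sketch (the 'webli' tag reuses the 384 release's weights, per the NOTE in pretrained.py below):

import open_clip

# sketch only: assumes this commit is installed and the hub weights are reachable
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-SO400M-14-SigLIP-378', pretrained='webli')
tokenizer = open_clip.get_tokenizer('ViT-SO400M-14-SigLIP-378')
print(preprocess)  # the image pipeline should now target 378x378 inputs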
30 changes: 30 additions & 0 deletions src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json
@@ -0,0 +1,30 @@
+{
+    "embed_dim": 1152,
+    "init_logit_bias": -10,
+    "custom_text": true,
+    "vision_cfg": {
+        "image_size": 256,
+        "timm_model_name": "vit_so400m_patch16_siglip_256",
+        "timm_model_pretrained": false,
+        "timm_pool": "map",
+        "timm_proj": "none"
+    },
+    "text_cfg": {
+        "context_length": 64,
+        "vocab_size": 250000,
+        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
+        "tokenizer_kwargs": {
+            "clean": "canonicalize"
+        },
+        "width": 1152,
+        "heads": 16,
+        "layers": 27,
+        "mlp_ratio": 3.7362,
+        "no_causal_mask": true,
+        "pool_type": "last",
+        "proj_type": "none",
+        "norm_kwargs":{
+            "eps": 1e-6
+        }
+    }
+}
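
A similar hedged sketch for the multilingual variant: with "proj_type": "none" the text features come straight from the final LayerNorm, which works here because width (1152) already equals embed_dim, and the tokenizer is the 250k-vocab multilingual one named in hf_tokenizer_name:

import torch
import open_clip

# sketch only: names assume this commit plus the corresponding hub weights
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-SO400M-16-SigLIP-i18n-256', pretrained='webli')
tokenizer = open_clip.get_tokenizer('ViT-SO400M-16-SigLIP-i18n-256')

texts = tokenizer(["a photo of a cat", "ein Foto einer Katze"])  # 64-token context
with torch.no_grad():
    text_features = model.encode_text(texts)
print(text_features.shape)  # expected: torch.Size([2, 1152])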
6 changes: 6 additions & 0 deletions src/open_clip/pretrained.py
@@ -414,6 +414,12 @@ def _mccfg(url='', hf_hub='', **kwargs):
     "ViT-SO400M-14-SigLIP": dict(
         webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
     ),
+    "ViT-SO400M-16-SigLIP-i18n-256": dict(
+        webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),
+    ),
+    "ViT-SO400M-14-SigLIP-378": dict(
+        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),  # NOTE using 384 weights, but diff img_size used
+    ),
     "ViT-SO400M-14-SigLIP-384": dict(
         webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
     ),
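
A quick way to confirm both new entries are registered (open_clip.list_pretrained() returns (model, tag) pairs):

import open_clip

pairs = open_clip.list_pretrained()
print(('ViT-SO400M-16-SigLIP-i18n-256', 'webli') in pairs)  # expect True with this commit
print(('ViT-SO400M-14-SigLIP-378', 'webli') in pairs)       # expect True with this commit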
3 changes: 2 additions & 1 deletion src/open_clip/push_to_hf_hub.py
@@ -193,7 +193,7 @@ def push_pretrained_to_hf_hub(
     tokenizer = get_tokenizer(model_name)
     if hf_tokenizer_self:
         # make hf tokenizer config in the uploaded model point to self instead of original location
-        model_config['text']['hf_tokenizer_name'] = repo_id
+        model_config['text_cfg']['hf_tokenizer_name'] = repo_id
 
     push_to_hf_hub(
         model=model,
@@ -316,6 +316,7 @@ def generate_readme(model_card: dict, model_name: str):
         image_std=args.image_std,
         image_interpolation=args.image_interpolation,
         image_resize_mode=args.image_resize_mode,
+        hf_tokenizer_self=args.hf_tokenizer_self,
     )
 
     print(f'{args.model} saved.')
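
The 'text' -> 'text_cfg' key fix above matters when hf_tokenizer_self is set, because the model config written to the hub keys its text settings under text_cfg (matching the JSON configs in this commit). A hedged sketch of the Python-level call; apart from repo_id and hf_tokenizer_self, which appear in the diff, the keyword names and repo id below are assumptions:

from open_clip.push_to_hf_hub import push_pretrained_to_hf_hub

# sketch only: repo_id is a placeholder; hf_tokenizer_self rewrites
# model_config['text_cfg']['hf_tokenizer_name'] to point at the uploaded repo
push_pretrained_to_hf_hub(
    model_name='ViT-SO400M-16-SigLIP-i18n-256',
    pretrained='webli',
    repo_id='your-org/ViT-SO400M-16-SigLIP-i18n-256',
    hf_tokenizer_self=True,
)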
12 changes: 8 additions & 4 deletions src/open_clip/transformer.py
@@ -677,11 +677,12 @@ def __init__(
             layers: int = 12,
             mlp_ratio: float = 4.0,
             ls_init_value: float = None,
-            output_dim: int = 512,
+            output_dim: Optional[int] = 512,
             embed_cls: bool = False,
             no_causal_mask: bool = False,
             pad_id: int = 0,
             pool_type: str = 'argmax',
+            proj_type: str = 'linear',
             proj_bias: bool = False,
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
@@ -721,10 +722,13 @@ def __init__(
         else:
             self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
 
-        if proj_bias:
-            self.text_projection = nn.Linear(width, output_dim)
+        if proj_type == 'none' or not output_dim:
+            self.text_projection = None
         else:
-            self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+            if proj_bias:
+                self.text_projection = nn.Linear(width, output_dim)
+            else:
+                self.text_projection = nn.Parameter(torch.empty(width, output_dim))
 
         self.init_parameters()
 
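
With proj_type='none' the module ends up with self.text_projection = None, and the forward path only applies a projection when one exists (calling it if it is an nn.Linear, otherwise matrix-multiplying the plain Parameter). A paraphrased, standalone sketch of that dispatch, not a verbatim copy of transformer.py:

import torch
import torch.nn as nn

def apply_text_projection(x: torch.Tensor, text_projection):
    # paraphrased sketch of how the (optional) text projection gets applied
    if text_projection is None:                 # proj_type == 'none'
        return x
    if isinstance(text_projection, nn.Linear):  # proj_bias=True path
        return text_projection(x)
    return x @ text_projection                  # plain nn.Parameter of shape (width, output_dim)

width = 1152
x = torch.randn(2, width)
print(apply_text_projection(x, None).shape)                            # torch.Size([2, 1152])
print(apply_text_projection(x, nn.Linear(width, width)).shape)         # torch.Size([2, 1152])
print(apply_text_projection(x, nn.Parameter(torch.eye(width))).shape)  # torch.Size([2, 1152])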
