From d1cba0565076e7dd1217762d86477c9ecce232bc Mon Sep 17 00:00:00 2001 From: Yanqing0327 Date: Tue, 17 Dec 2024 15:06:16 -0800 Subject: [PATCH] add eps in config --- src/open_clip/coca_model.py | 6 +- src/open_clip/constants.py | 5 - src/open_clip/convert.py | 48 ++- src/open_clip/factory.py | 3 - src/open_clip/loss.py | 80 ++--- src/open_clip/model.py | 4 - .../model_configs/RN50x16-quickgelu.json | 22 -- .../model_configs/RN50x4-quickgelu.json | 22 -- .../model_configs/RN50x64-quickgelu.json | 22 -- src/open_clip/model_configs/ViT-H-14-378.json | 17 - .../model_configs/ViT-H-14-CLIPS-224.json | 8 +- .../model_configs/ViT-L-14-336-quickgelu.json | 17 - .../model_configs/ViT-L-14-CLIPS-224.json | 8 +- .../model_configs/ViT-L-14-CLIPS-336.json | 8 +- .../ViT-SO400M-14-SigLIP-378.json | 30 -- .../ViT-SO400M-16-SigLIP-i18n-256.json | 30 -- .../model_configs/ViT-bigG-14-quickgelu.json | 19 -- .../model_configs/ViTamin-L-384.json | 20 -- .../model_configs/ViTamin-L2-384.json | 20 -- src/open_clip/pretrained.py | 298 +++++------------- src/open_clip/push_to_hf_hub.py | 11 +- src/open_clip/timm_model.py | 11 +- src/open_clip/transformer.py | 14 +- src/open_clip/version.py | 2 +- 24 files changed, 167 insertions(+), 558 deletions(-) delete mode 100644 src/open_clip/model_configs/RN50x16-quickgelu.json delete mode 100644 src/open_clip/model_configs/RN50x4-quickgelu.json delete mode 100644 src/open_clip/model_configs/RN50x64-quickgelu.json delete mode 100644 src/open_clip/model_configs/ViT-H-14-378.json delete mode 100644 src/open_clip/model_configs/ViT-L-14-336-quickgelu.json delete mode 100644 src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json delete mode 100644 src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json delete mode 100644 src/open_clip/model_configs/ViT-bigG-14-quickgelu.json delete mode 100644 src/open_clip/model_configs/ViTamin-L-384.json delete mode 100644 src/open_clip/model_configs/ViTamin-L2-384.json diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 539616332..dda3faba5 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -96,7 +96,6 @@ def __init__( quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, - nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, pad_id: int = 0, ): @@ -132,10 +131,9 @@ def __init__( cast_dtype=cast_dtype, ) - lshape = [1] if nonscalar_logit_scale else [] - self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) + self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) if init_logit_bias is not None: - self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) + self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) else: self.logit_bias = None self.pad_id = pad_id diff --git a/src/open_clip/constants.py b/src/open_clip/constants.py index 5bdfc2451..599c48c03 100644 --- a/src/open_clip/constants.py +++ b/src/open_clip/constants.py @@ -4,8 +4,3 @@ IMAGENET_STD = (0.229, 0.224, 0.225) INCEPTION_MEAN = (0.5, 0.5, 0.5) INCEPTION_STD = (0.5, 0.5, 0.5) - -# Default name for a weights file hosted on the Huggingface Hub. -HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl -HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version -HF_CONFIG_NAME = 'open_clip_config.json' diff --git a/src/open_clip/convert.py b/src/open_clip/convert.py index f0c06ffba..84571e0f1 100644 --- a/src/open_clip/convert.py +++ b/src/open_clip/convert.py @@ -18,9 +18,7 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): """ from timm.layers import resample_patch_embed, resample_abs_pos_embed - def _n2p(w, t=True, idx=None): - if idx is not None: - w = w[idx] + def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: @@ -68,28 +66,21 @@ def _convert_timm_img(module, prefix): mha_sub, b_sub, ln1_sub = (0, 0, 1) for i, block in enumerate(module.blocks.children()): - if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w: - block_prefix = f'{prefix}Transformer/encoderblock/' - idx = i - else: - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - idx = None + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx)) + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')])) + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx)) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx)) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx)) + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_( - _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx)) - getattr(block.mlp, f'fc{r + 1}').bias.copy_( - _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx)) + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -138,14 +129,13 @@ def _convert_openclip_txt(module: TextTransformer, prefix): _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) - if module.text_projection is not None: - module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) - module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - - _convert_timm_img(model.visual.trunk, 'img/') - _convert_openclip_txt(model.text, 'txt/') - model.logit_bias.copy_(_n2p(w['b'])[0]) - model.logit_scale.copy_(_n2p(w['t'])[0]) + module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + + _convert_timm_img(model.visual.trunk, 'params/img/') + _convert_openclip_txt(model.text, 'params/txt/') + model.logit_bias.copy_(_n2p(w['params/b'])[0]) + model.logit_scale.copy_(_n2p(w['params/t'])[0]) @torch.no_grad() diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 49f59a6e3..d8a0bd90b 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -346,9 +346,6 @@ def create_model( else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: - if 'CLIPS' in model_name: - model_cfg['vision_cfg']['eps'] = 1e-6 - model_cfg['text_cfg']['eps'] = 1e-6 model = CLIP(**model_cfg, cast_dtype=cast_dtype) if precision in ("fp16", "bf16"): diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index b3e6dd256..5beaab1c3 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from torch.nn import functional as F @@ -104,14 +102,8 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor: def get_logits(self, image_features, text_features, logit_scale): if self.world_size > 1: all_image_features, all_text_features = gather_features( - image_features, - text_features, - local_loss=self.local_loss, - gather_with_grad=self.gather_with_grad, - rank=self.rank, - world_size=self.world_size, - use_horovod=self.use_horovod, - ) + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T @@ -166,11 +158,12 @@ def __init__( self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = torch.tensor(0) + if self.clip_loss_weight: clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - else: - clip_loss = torch.tensor(0, device=logits.device) caption_loss = self.caption_loss( logits.permute(0, 2, 1), @@ -323,17 +316,19 @@ class SigLipLoss(nn.Module): """ def __init__( self, - cache_labels: bool = False, - rank: int = 0, - world_size: int = 1, - dist_impl: Optional[str] = None, + cache_labels=False, + rank=0, + world_size=1, + bidir=True, + use_horovod=False, ): super().__init__() self.cache_labels = cache_labels self.rank = rank self.world_size = world_size - self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change - assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather') + assert not use_horovod # FIXME need to look at hvd ops for ring transfers + self.use_horovod = use_horovod + self.bidir = bidir # cache state FIXME cache not currently used, worthwhile? self.prev_num_logits = 0 @@ -366,9 +361,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output loss = self._loss(image_features, text_features, logit_scale, logit_bias) if self.world_size > 1: - if self.dist_impl == 'bidir': - right_rank = (self.rank + 1) % self.world_size - left_rank = (self.rank - 1 + self.world_size) % self.world_size + # exchange text features w/ neighbour world_size - 1 times + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + if self.bidir: text_features_to_right = text_features_to_left = text_features num_bidir, remainder = divmod(self.world_size - 1, 2) for i in range(num_bidir): @@ -378,6 +374,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output text_features_to_left, text_features_to_right, ) + for f in text_features_recv: loss += self._loss( image_features, @@ -390,10 +387,8 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output if remainder: text_features_recv = neighbour_exchange_with_grad( - left_rank, - right_rank, - text_features_to_right - ) + left_rank, right_rank, text_features_to_right) + loss += self._loss( image_features, text_features_recv, @@ -401,16 +396,12 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output logit_bias, negative_only=True, ) - elif self.dist_impl == "shift": - right_rank = (self.rank + 1) % self.world_size - left_rank = (self.rank - 1 + self.world_size) % self.world_size + else: text_features_to_right = text_features for i in range(self.world_size - 1): text_features_from_left = neighbour_exchange_with_grad( - left_rank, - right_rank, - text_features_to_right, - ) + left_rank, right_rank, text_features_to_right) + loss += self._loss( image_features, text_features_from_left, @@ -419,30 +410,5 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output negative_only=True, ) text_features_to_right = text_features_from_left - elif self.dist_impl == "reduce": - for i in range(self.world_size): - text_from_other = torch.distributed.nn.all_reduce( - text_features * (self.rank == i), - torch.distributed.ReduceOp.SUM, - ) - loss += float(i != self.rank) * self._loss( - image_features, - text_from_other, - logit_scale, - logit_bias, - negative_only=True, - ) - elif self.dist_impl == "gather": - all_text = torch.distributed.nn.all_gather(text_features) - for i in range(self.world_size): - loss += float(i != self.rank) * self._loss( - image_features, - all_text[i], - logit_scale, - logit_bias, - negative_only=True, - ) - else: - assert False return {"contrastive_loss": loss} if output_dict else loss diff --git a/src/open_clip/model.py b/src/open_clip/model.py index bf2a02c0e..9a9443603 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -31,7 +31,6 @@ class CLIPVisionCfg: mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 - eps: float = 1e-5 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results @@ -77,7 +76,6 @@ class CLIPTextCfg: output_tokens: bool = False act_kwargs: dict = None norm_kwargs: dict = None - eps: float = 1e-5 # HuggingFace specific text tower config hf_model_name: Optional[str] = None @@ -168,7 +166,6 @@ def _build_vision_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, - eps=vision_cfg.eps, ) return visual @@ -218,7 +215,6 @@ def _build_text_tower( output_tokens=text_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, - eps=text_cfg.eps, ) return text diff --git a/src/open_clip/model_configs/RN50x16-quickgelu.json b/src/open_clip/model_configs/RN50x16-quickgelu.json deleted file mode 100644 index 989bb87c6..000000000 --- a/src/open_clip/model_configs/RN50x16-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 768, - "quick_gelu": true, - "vision_cfg": { - "image_size": 384, - "layers": [ - 6, - 8, - 18, - 8 - ], - "width": 96, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/RN50x4-quickgelu.json b/src/open_clip/model_configs/RN50x4-quickgelu.json deleted file mode 100644 index 9bf11fc3a..000000000 --- a/src/open_clip/model_configs/RN50x4-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 640, - "quick_gelu": true, - "vision_cfg": { - "image_size": 288, - "layers": [ - 4, - 6, - 10, - 6 - ], - "width": 80, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/RN50x64-quickgelu.json b/src/open_clip/model_configs/RN50x64-quickgelu.json deleted file mode 100644 index 6da9d7e21..000000000 --- a/src/open_clip/model_configs/RN50x64-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 1024, - "quick_gelu": true, - "vision_cfg": { - "image_size": 448, - "layers": [ - 3, - 15, - 36, - 10 - ], - "width": 128, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-378.json b/src/open_clip/model_configs/ViT-H-14-378.json deleted file mode 100644 index 04b2e62d6..000000000 --- a/src/open_clip/model_configs/ViT-H-14-378.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 378, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json b/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json index 92433c625..dcea499a6 100644 --- a/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json +++ b/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json @@ -9,7 +9,10 @@ "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", - "final_ln_after_pool": true + "final_ln_after_pool": true, + "norm_kwargs": { + "eps": 1e-6 + } }, "text_cfg": { "context_length": 80, @@ -25,6 +28,9 @@ "no_causal_mask": true, "act_kwargs": { "approximate": "tanh" + }, + "norm_kwargs": { + "eps": 1e-6 } } }, diff --git a/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json b/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json deleted file mode 100644 index d928c0284..000000000 --- a/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 768, - "quick_gelu": true, - "vision_cfg": { - "image_size": 336, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json b/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json index 94f75cd7d..f4b5f85e2 100644 --- a/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json +++ b/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json @@ -8,7 +8,10 @@ "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", - "final_ln_after_pool": true + "final_ln_after_pool": true, + "norm_kwargs": { + "eps": 1e-6 + } }, "text_cfg": { "context_length": 80, @@ -24,6 +27,9 @@ "no_causal_mask": true, "act_kwargs": { "approximate": "tanh" + }, + "norm_kwargs": { + "eps": 1e-6 } } }, diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json b/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json index 2cef0d2ce..a166a4ba8 100644 --- a/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json +++ b/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json @@ -8,7 +8,10 @@ "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", - "final_ln_after_pool": true + "final_ln_after_pool": true, + "norm_kwargs": { + "eps": 1e-6 + } }, "text_cfg": { "context_length": 80, @@ -24,6 +27,9 @@ "no_causal_mask": true, "act_kwargs": { "approximate": "tanh" + }, + "norm_kwargs": { + "eps": 1e-6 } } }, diff --git a/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json deleted file mode 100644 index 6bc14fabc..000000000 --- a/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 1152, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 378, - "timm_model_name": "vit_so400m_patch14_siglip_378", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1152, - "heads": 16, - "layers": 27, - "mlp_ratio": 3.7362, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json b/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json deleted file mode 100644 index 4e39b1b46..000000000 --- a/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 1152, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 256, - "timm_model_name": "vit_so400m_patch16_siglip_256", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 250000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1152, - "heads": 16, - "layers": 27, - "mlp_ratio": 3.7362, - "no_causal_mask": true, - "pool_type": "last", - "proj_type": "none", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json b/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json deleted file mode 100644 index fed567cc6..000000000 --- a/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "embed_dim": 1280, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 48, - "width": 1664, - "head_width": 104, - "mlp_ratio": 4.9231, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1280, - "heads": 20, - "layers": 32 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L-384.json b/src/open_clip/model_configs/ViTamin-L-384.json deleted file mode 100644 index 1278d8393..000000000 --- a/src/open_clip/model_configs/ViTamin-L-384.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "timm_model_name": "vitamin_large_384", - "timm_model_pretrained": false, - "timm_pool": "", - "timm_proj": "linear", - "timm_drop": 0.0, - "timm_drop_path": 0.1, - "image_size": 384 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - }, - "custom_text": true -} diff --git a/src/open_clip/model_configs/ViTamin-L2-384.json b/src/open_clip/model_configs/ViTamin-L2-384.json deleted file mode 100644 index cc0faaae7..000000000 --- a/src/open_clip/model_configs/ViTamin-L2-384.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "timm_model_name": "vitamin_large2_384", - "timm_model_pretrained": false, - "timm_pool": "", - "timm_proj": "linear", - "timm_drop": 0.0, - "timm_drop_path": 0.1, - "image_size": 384 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - }, - "custom_text": true -} diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index c3629ef4a..4dcbf4ae5 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -1,21 +1,12 @@ -import copy import hashlib import os import urllib import warnings from functools import partial -from typing import Dict, Iterable, Optional, Union +from typing import Dict, Union from tqdm import tqdm - -try: - import safetensors.torch - _has_safetensors = True -except ImportError: - _has_safetensors = False - - from .constants import ( IMAGENET_MEAN, IMAGENET_STD, @@ -23,8 +14,6 @@ INCEPTION_STD, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, - HF_WEIGHTS_NAME, - HF_SAFE_WEIGHTS_NAME, ) from .version import __version__ @@ -92,81 +81,60 @@ def _mccfg(url='', hf_hub='', **kwargs): _RN50 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - hf_hub="timm/resnet50_clip.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), yfcc15m=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", - hf_hub="timm/resnet50_clip.yfcc15m/", - quick_gelu=True, - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), cc12m=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", - hf_hub="timm/resnet50_clip.cc12m/", - quick_gelu=True, - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), ) _RN101 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - hf_hub="timm/resnet101_clip.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), yfcc15m=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", - hf_hub="timm/resnet101_clip.yfcc15m/", - quick_gelu=True, - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), ) _RN50x4 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - hf_hub="timm/resnet50x4_clip.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), ) _RN50x16 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", - hf_hub="timm/resnet50x16_clip.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), ) _RN50x64 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", - hf_hub="timm/resnet50x64_clip.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), ) _VITB32 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - hf_hub="timm/vit_base_patch32_clip_224.openai/", - quick_gelu=True, - ), - # LAION 400M (quick gelu) + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", - hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/", - quick_gelu=True, - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", - hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/", - quick_gelu=True, - ), - # LAION 2B-en + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), laion2b_e16=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", - hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/", - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), @@ -186,17 +154,19 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), - # MetaClip models (NOTE quick-gelu activation used) +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), metaclip_400m=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", - hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/", - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"), metaclip_fullcc=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", - hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/", - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"), ) _VITB32_256 = dict( @@ -205,20 +175,11 @@ def _mccfg(url='', hf_hub='', **kwargs): _VITB16 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", - hf_hub="timm/vit_base_patch16_clip_224.openai/", - quick_gelu=True, - ), - # LAION-400M + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), laion400m_e31=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", - hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/", - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), laion400m_e32=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", - hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/", - ), - # LAION-2B + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), @@ -231,50 +192,30 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), # DFN - dfn2b=_pcfg( - hf_hub='apple/DFN2B-CLIP-ViT-B-16/', - quick_gelu=True, - ), - # MetaCLIP (these are quick-gelu) + dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/') +) + +_VITB16_quickgelu = dict( metaclip_400m=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", - hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/", - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"), metaclip_fullcc=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", - hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/", - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", - hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), laion400m_e32=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", - hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), ) _VITL14 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", - hf_hub="timm/vit_large_patch14_clip_224.openai/", - quick_gelu=True, - ), - # LAION-400M + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), laion400m_e31=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt", - hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/", - ), + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), laion400m_e32=_pcfg( - url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt", - hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/", - ), - # LAION-2B-en + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=INCEPTION_MEAN, std=INCEPTION_STD), @@ -283,55 +224,38 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), - # MetaCLIP +) + +_VITL14_quickgelu = dict( metaclip_400m=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", - hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/", - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"), metaclip_fullcc=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", - hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/", - quick_gelu=True, - ), - # DFN-2B (quick-gelu) - dfn2b=_pcfg( - hf_hub='apple/DFN2B-CLIP-ViT-L-14/', - quick_gelu=True, - ), + "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"), + dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'), ) _VITL14_336 = dict( openai=_pcfg( - url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", - hf_hub="timm/vit_large_patch14_clip_336.openai/", - quick_gelu=True, - ), + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), ) _VITH14 = dict( - # LAION-2B-en laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), - # MetaCLIP (quick-gelu) +) + +_VITH14_quickgelu = dict( metaclip_fullcc=_pcfg( - url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", - hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/", - quick_gelu=True, - ), - # DFN-5B (quick-gelu) + "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"), dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14/', - quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), ) -_VITH14_378 = dict( - # DFN-5B (quick-gelu) +_VITH14_378_quickgelu = dict( dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', - quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), @@ -343,14 +267,7 @@ def _mccfg(url='', hf_hub='', **kwargs): ) _VITbigG14 = dict( - # LAION-2B-en laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), - # MetaCLIP (quick-gelu) - metaclip_fullcc=_pcfg( - url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt', - hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/", - quick_gelu=True, - ), ) _robertaViTB32 = dict( @@ -408,19 +325,25 @@ def _mccfg(url='', hf_hub='', **kwargs): _PRETRAINED = { "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, "ViT-B-32": _VITB32, "ViT-B-32-256": _VITB32_256, + "ViT-B-32-quickgelu": _VITB32_quickgelu, "ViT-B-16": _VITB16, + "ViT-B-16-quickgelu": _VITB16_quickgelu, "ViT-B-16-plus-240": _VITB16_PLUS_240, "ViT-L-14": _VITL14, + "ViT-L-14-quickgelu": _VITL14_quickgelu, "ViT-L-14-336": _VITL14_336, "ViT-H-14": _VITH14, - "ViT-H-14-378": _VITH14_378, + "ViT-H-14-quickgelu": _VITH14_quickgelu, + "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, @@ -491,12 +414,6 @@ def _mccfg(url='', hf_hub='', **kwargs): "ViT-SO400M-14-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), ), - "ViT-SO400M-16-SigLIP-i18n-256": dict( - webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'), - ), - "ViT-SO400M-14-SigLIP-378": dict( - webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used - ), "ViT-SO400M-14-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), ), @@ -593,15 +510,6 @@ def _mccfg(url='', hf_hub='', **kwargs): ), } -_PRETRAINED_quickgelu = {} -for k, v in _PRETRAINED.items(): - quick_gelu_tags = {} - for tk, tv in v.items(): - if tv.get('quick_gelu', False): - quick_gelu_tags[tk] = copy.deepcopy(tv) - if quick_gelu_tags: - _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags -_PRETRAINED.update(_PRETRAINED_quickgelu) def _clean_tag(tag: str): # normalize pretrained tags @@ -653,7 +561,7 @@ def get_pretrained_url(model: str, tag: str): def download_pretrained_from_url( url: str, - cache_dir: Optional[str] = None, + cache_dir: Union[str, None] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") @@ -705,70 +613,30 @@ def has_hf_hub(necessary=False): return _has_hf_hub -def _get_safe_alternatives(filename: str) -> Iterable[str]: - """Returns potential safetensors alternatives for a given filename. - - Use case: - When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it. - """ - if filename == HF_WEIGHTS_NAME: - yield HF_SAFE_WEIGHTS_NAME - - if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")): - yield filename[:-4] + ".safetensors" - - def download_pretrained_from_hf( model_id: str, - filename: Optional[str] = None, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, ): has_hf_hub(True) - - filename = filename or HF_WEIGHTS_NAME - - # Look for .safetensors alternatives and load from it if it exists - if _has_safetensors: - for safe_filename in _get_safe_alternatives(filename): - try: - cached_file = hf_hub_download( - repo_id=model_id, - filename=safe_filename, - revision=revision, - cache_dir=cache_dir, - ) - return cached_file - except Exception: - pass - - try: - # Attempt to download the file - cached_file = hf_hub_download( - repo_id=model_id, - filename=filename, - revision=revision, - cache_dir=cache_dir, - ) - return cached_file # Return the path to the downloaded file if successful - except Exception as e: - raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file def download_pretrained( cfg: Dict, - prefer_hf_hub: bool = True, - cache_dir: Optional[str] = None, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, ): target = '' if not cfg: return target - has_hub = has_hf_hub() download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') - if has_hub and prefer_hf_hub and download_hf_hub: - # prefer to use HF hub, remove url info + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists download_url = '' if download_url: diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py index 6a8eeedb9..26b5594e4 100644 --- a/src/open_clip/push_to_hf_hub.py +++ b/src/open_clip/push_to_hf_hub.py @@ -1,5 +1,6 @@ import argparse import json +import os from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union @@ -27,10 +28,14 @@ except ImportError: _has_safetensors = False -from .constants import HF_WEIGHTS_NAME, HF_SAFE_WEIGHTS_NAME, HF_CONFIG_NAME from .factory import create_model_from_pretrained, get_model_config, get_tokenizer from .tokenizer import HFTokenizer +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + def save_config_for_hf( model, @@ -114,7 +119,6 @@ def push_to_hf_hub( try: repo_files = set(list_repo_files(repo_id)) repo_exists = True - print('Repo exists', repo_files) except Exception as e: print('Repo does not exist', e) @@ -189,7 +193,7 @@ def push_pretrained_to_hf_hub( tokenizer = get_tokenizer(model_name) if hf_tokenizer_self: # make hf tokenizer config in the uploaded model point to self instead of original location - model_config['text_cfg']['hf_tokenizer_name'] = repo_id + model_config['text']['hf_tokenizer_name'] = repo_id push_to_hf_hub( model=model, @@ -312,7 +316,6 @@ def generate_readme(model_card: dict, model_name: str): image_std=args.image_std, image_interpolation=args.image_interpolation, image_resize_mode=args.image_resize_mode, - hf_tokenizer_self=args.hf_tokenizer_self, ) print(f'{args.model} saved.') diff --git a/src/open_clip/timm_model.py b/src/open_clip/timm_model.py index 975e37a9d..5ddb9a76b 100644 --- a/src/open_clip/timm_model.py +++ b/src/open_clip/timm_model.py @@ -10,16 +10,15 @@ try: import timm + from timm.models.layers import Mlp, to_2tuple try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: # new timm imports >= 0.8.1 from timm.layers import RotAttentionPool2d from timm.layers import AttentionPool2d as AbsAttentionPool2d - from timm.layers import Mlp, to_2tuple - except ImportError as e: - # fallback, try old timm imports < 0.8.1 - from timm.models.layers.attention_pool2d import RotAttentionPool2d - from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d - from timm.models.layers import Mlp, to_2tuple except ImportError: timm = None diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 78a581fd7..d06fbdbfd 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -455,7 +455,6 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, - eps: float = 1e-5 # Add eps as a parameter ): super().__init__() assert pool_type in ('tok', 'avg', 'none') @@ -488,7 +487,7 @@ def __init__( # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() - self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width, eps=eps) + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) self.transformer = Transformer( width, layers, @@ -496,7 +495,7 @@ def __init__( mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, - norm_layer=lambda x: norm_layer(x, eps=eps), + norm_layer=norm_layer, ) if attentional_pool: @@ -534,7 +533,7 @@ def __init__( pool_dim = width self.pool_type = pool_type - self.ln_post = norm_layer(pool_dim, eps=eps) + self.ln_post = norm_layer(pool_dim) self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) self.init_parameters() @@ -694,7 +693,6 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, - eps: float = 1e-5 # Add eps as a parameter ): super().__init__() assert pool_type in ('first', 'last', 'argmax', 'none') @@ -721,9 +719,9 @@ def __init__( mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, - norm_layer=lambda x: norm_layer(x, eps=eps), + norm_layer=norm_layer, ) - self.ln_final = norm_layer(width, eps=eps) + self.ln_final = norm_layer(width) if no_causal_mask: self.attn_mask = None @@ -925,4 +923,4 @@ def forward(self, image_embs, text_embs): @torch.jit.ignore def set_grad_checkpointing(self, enable=True): - self.grad_checkpointing = enable + self.grad_checkpointing = enable \ No newline at end of file diff --git a/src/open_clip/version.py b/src/open_clip/version.py index 92e0aea0c..acb6e2fcc 100644 --- a/src/open_clip/version.py +++ b/src/open_clip/version.py @@ -1 +1 @@ -__version__ = '2.29.0' +__version__ = '2.26.1'