diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 2a722eec6..3e40355a6 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -2,6 +2,7 @@ import logging import os import re +import warnings from copy import deepcopy from dataclasses import asdict from pathlib import Path @@ -222,8 +223,58 @@ def create_model( cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, + load_weights_only: bool = True, **model_kwargs, ): + """Creates and configures a contrastive vision-language model. + + Args: + model_name: Name of the model architecture to create. Can be a local model name + or a Hugging Face model ID prefixed with 'hf-hub:'. + pretrained: Tag/path for pretrained model weights. Can be: + - A pretrained tag name (e.g., 'openai') + - A path to local weights + - None to initialize with random weights + precision: Model precision/AMP configuration. Options: + - 'fp32': 32-bit floating point + - 'fp16'/'bf16': Mixed precision with FP32 for certain layers + - 'pure_fp16'/'pure_bf16': Pure 16-bit precision + device: Device to load the model on ('cpu', 'cuda', or torch.device object) + jit: If True, JIT compile the model + force_quick_gelu: Force use of QuickGELU activation + force_custom_text: Force use of custom text encoder + force_patch_dropout: Override default patch dropout value + force_image_size: Override default image size for vision encoder + force_preprocess_cfg: Override default preprocessing configuration + pretrained_image: Load pretrained weights for timm vision models + pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights + cache_dir: Override default cache directory for downloaded model files + output_dict: If True and model supports it, return dictionary of features + require_pretrained: Raise error if pretrained weights cannot be loaded + load_weights_only: Only deserialize model weights when unpickling torch checkpoints (for safety) + **model_kwargs: Additional keyword arguments passed to model constructor + + Returns: + Created and configured model instance + + Raises: + RuntimeError: If model config is not found or required pretrained weights + cannot be loaded + + Examples: + # Create basic CLIP model + model = create_model('ViT-B/32') + + # Create CLIP model with mixed precision on GPU + model = create_model('ViT-B/32', precision='fp16', device='cuda') + + # Load pretrained OpenAI weights + model = create_model('ViT-B/32', pretrained='openai') + + # Load Hugging Face model + model = create_model('hf-hub:organization/model-name') + """ + force_preprocess_cfg = force_preprocess_cfg or {} preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) @@ -242,112 +293,113 @@ def create_model( if isinstance(device, str): device = torch.device(device) - if pretrained and pretrained.lower() == 'openai': - logging.info(f'Loading pretrained {model_name} from OpenAI.') - model = load_openai_model( - model_name, - precision=precision, - device=device, - cache_dir=cache_dir, - ) + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') else: - model_cfg = model_cfg or get_model_config(model_name) - if model_cfg is not None: - logging.info(f'Loaded {model_name} model config.') + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True else: - logging.error(f'Model config for {model_name} not found; available models {list_models()}.') - raise RuntimeError(f'Model config for {model_name} not found.') - - if force_quick_gelu: - # override for use of QuickGELU on non-OpenAI transformer models - model_cfg["quick_gelu"] = True - - if force_patch_dropout is not None: - # override the default patch dropout value - model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout - - if force_image_size is not None: - # override model config's image size - model_cfg["vision_cfg"]["image_size"] = force_image_size - - is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) - if pretrained_image: - if is_timm_model: - # pretrained weight loading for timm models set via vision_cfg - model_cfg['vision_cfg']['timm_model_pretrained'] = True - else: - assert False, 'pretrained image towers currently only supported for timm models' - - # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes - cast_dtype = get_cast_dtype(precision) - is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) - if is_hf_model: - # load pretrained weights for HF text model IFF no CLIP weights being loaded - model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained - custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model - - model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) - if custom_text: - if "multimodal_cfg" in model_cfg: - model = CoCa(**model_cfg, cast_dtype=cast_dtype) - else: - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) + if custom_text: + if "multimodal_cfg" in model_cfg: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) - - if precision in ("fp16", "bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 - # manual mixed precision that matches original OpenAI behaviour - if is_timm_model: - # FIXME this is a bit janky, create timm based model in low-precision and - # then cast only LayerNormFp32 instances back to float32 so they don't break. - # Why? The convert_weights_to_lp fn only works with native models. 
- model.to(device=device, dtype=dtype) - from .transformer import LayerNormFp32 - - def _convert_ln(m): - if isinstance(m, LayerNormFp32): - m.weight.data = m.weight.data.to(torch.float32) - m.bias.data = m.bias.data.to(torch.float32) - model.apply(_convert_ln) - else: - model.to(device=device) - convert_weights_to_lp(model, dtype=dtype) - elif precision in ("pure_fp16", "pure_bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) else: model.to(device=device) - - pretrained_loaded = False - if pretrained: - checkpoint_path = '' - pretrained_cfg = get_pretrained_cfg(model_name, pretrained) - if pretrained_cfg: - checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) - preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) - elif os.path.exists(pretrained): - checkpoint_path = pretrained - - if checkpoint_path: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') - load_checkpoint(model, checkpoint_path) - else: - error_str = ( - f'Pretrained weights ({pretrained}) not found for model {model_name}.' - f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') - logging.warning(error_str) - raise RuntimeError(error_str) - pretrained_loaded = True - elif has_hf_hub_prefix: - logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') - load_checkpoint(model, checkpoint_path) - pretrained_loaded = True - - if require_pretrained and not pretrained_loaded: - # callers of create_model_from_pretrained always expect pretrained weights - raise RuntimeError( - f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) + pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False) + model_quick_gelu = model_cfg.get('quick_gelu', False) + if pretrained_quick_gelu and not model_quick_gelu: + warnings.warn( + f'These pretrained weights were trained with QuickGELU activation but the model config does ' + f'not have that enabled. 
Consider using a model config with a "-quickgelu" suffix or enable with a flag.') + elif not pretrained_quick_gelu and model_quick_gelu: + warnings.warn( + f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the ' + f'model config. Consider using a model config without QuickGELU or disable override flags.') + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') if output_dict and hasattr(model, "output_dict"): model.output_dict = True diff --git a/src/open_clip/model_configs/RN50x16-quickgelu.json b/src/open_clip/model_configs/RN50x16-quickgelu.json new file mode 100644 index 000000000..989bb87c6 --- /dev/null +++ b/src/open_clip/model_configs/RN50x16-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 768, + "quick_gelu": true, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/RN50x4-quickgelu.json b/src/open_clip/model_configs/RN50x4-quickgelu.json new file mode 100644 index 000000000..9bf11fc3a --- /dev/null +++ b/src/open_clip/model_configs/RN50x4-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 640, + "quick_gelu": true, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/RN50x64-quickgelu.json b/src/open_clip/model_configs/RN50x64-quickgelu.json new file mode 100644 index 000000000..6da9d7e21 --- /dev/null +++ b/src/open_clip/model_configs/RN50x64-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-378.json b/src/open_clip/model_configs/ViT-H-14-378.json new file mode 100644 index 000000000..04b2e62d6 --- /dev/null +++ b/src/open_clip/model_configs/ViT-H-14-378.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 378, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, +
"layers": 24 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json b/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json new file mode 100644 index 000000000..d928c0284 --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "quick_gelu": true, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 8c89d3035..aac87619d 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -1,3 +1,4 @@ +import copy import hashlib import os import urllib @@ -91,60 +92,81 @@ def _mccfg(url='', hf_hub='', **kwargs): _RN50 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), - cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), -) - -_RN50_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + hf_hub="timm/resnet50_clip.openai/", + quick_gelu=True, + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + hf_hub="timm/resnet50_clip.yfcc15m/", + quick_gelu=True, + ), cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", + hf_hub="timm/resnet50_clip.cc12m/", + quick_gelu=True, + ), ) _RN101 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), -) - -_RN101_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + hf_hub="timm/resnet101_clip.openai/", + quick_gelu=True, + ), yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", + hf_hub="timm/resnet101_clip.yfcc15m/", + quick_gelu=True, + ), ) _RN50x4 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), + 
url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + hf_hub="timm/resnet50x4_clip.openai/", + quick_gelu=True, + ), ) _RN50x16 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), + url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + hf_hub="timm/resnet50x16_clip.openai/", + quick_gelu=True, + ), ) _RN50x64 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), + url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + hf_hub="timm/resnet50x64_clip.openai/", + quick_gelu=True, + ), ) _VITB32 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + hf_hub="timm/vit_base_patch32_clip_224.openai/", + quick_gelu=True, + ), + # LAION 400M (quick gelu) laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/", + quick_gelu=True, + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/", + quick_gelu=True, + ), + # LAION 2B-en laion2b_e16=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", + hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/", + ), laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), @@ -164,19 +186,15 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), -) - -_VITB32_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + # MetaClip models (NOTE quick-gelu activation used) metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", + 
quick_gelu=True, + ), metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", + quick_gelu=True, + ), ) _VITB32_256 = dict( @@ -185,11 +203,20 @@ def _mccfg(url='', hf_hub='', **kwargs): _VITB16 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + hf_hub="timm/vit_base_patch16_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/", + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/", + ), + # LAION-2B laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), @@ -202,30 +229,50 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), # DFN - dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/') -) - -_VITB16_quickgelu = dict( + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-B-16/', + quick_gelu=True, + ), + # MetaCLIP (these are quick-gelu) metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/", + quick_gelu=True, + ), metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), ) _VITL14 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + hf_hub="timm/vit_large_patch14_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M laion400m_e31=_pcfg( - 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/", + ), laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/", + ), + # LAION-2B-en laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=INCEPTION_MEAN, std=INCEPTION_STD), @@ -234,38 +281,55 @@ def _mccfg(url='', hf_hub='', **kwargs): commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), -) - -_VITL14_quickgelu = dict( + # MetaCLIP metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/", + quick_gelu=True, + ), metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"), - dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + # DFN-2B (quick-gelu) + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-L-14/', + quick_gelu=True, + ), ) _VITL14_336 = dict( openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), + url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", + hf_hub="timm/vit_large_patch14_clip_336.openai/", + quick_gelu=True, + ), ) _VITH14 = dict( + # LAION-2B-en laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), -) - -_VITH14_quickgelu = dict( + # MetaCLIP (quick-gelu) metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"), + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", + hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + # DFN-5B (quick-gelu) dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14/', + quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), ) -_VITH14_378_quickgelu = dict( +_VITH14_378 = dict( + # DFN-5B (quick-gelu) dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', + quick_gelu=True, interpolation="bicubic", resize_mode="squash" ), @@ -277,11 +341,14 @@ def _mccfg(url='', hf_hub='', **kwargs): ) _VITbigG14 = dict( + # LAION-2B-en laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), -) - -_VITbigG14_quickgelu = dict( - metaclip_fullcc=_pcfg(url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt'), + # MetaCLIP (quick-gelu) + metaclip_fullcc=_pcfg( + url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt', + hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), ) _robertaViTB32 = dict( @@ -339,28 +406,21 @@ def _mccfg(url='', hf_hub='', **kwargs): _PRETRAINED = { "RN50": _RN50, - 
"RN50-quickgelu": _RN50_quickgelu, "RN101": _RN101, - "RN101-quickgelu": _RN101_quickgelu, "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, "ViT-B-32": _VITB32, "ViT-B-32-256": _VITB32_256, - "ViT-B-32-quickgelu": _VITB32_quickgelu, "ViT-B-16": _VITB16, - "ViT-B-16-quickgelu": _VITB16_quickgelu, "ViT-B-16-plus-240": _VITB16_PLUS_240, "ViT-L-14": _VITL14, - "ViT-L-14-quickgelu": _VITL14_quickgelu, "ViT-L-14-336": _VITL14_336, "ViT-H-14": _VITH14, - "ViT-H-14-quickgelu": _VITH14_quickgelu, - "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu, + "ViT-H-14-378": _VITH14_378, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, - "ViT-bigG-14-quickgelu": _VITbigG14_quickgelu, "roberta-ViT-B-32": _robertaViTB32, "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, @@ -531,6 +591,15 @@ def _mccfg(url='', hf_hub='', **kwargs): ), } +_PRETRAINED_quickgelu = {} +for k, v in _PRETRAINED.items(): + quick_gelu_tags = {} + for tk, tv in v.items(): + if tv.get('quick_gelu', False): + quick_gelu_tags[tk] = copy.deepcopy(tv) + if quick_gelu_tags: + _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags +_PRETRAINED.update(_PRETRAINED_quickgelu) def _clean_tag(tag: str): # normalize pretrained tags @@ -662,7 +731,11 @@ def download_pretrained_from_hf( for safe_filename in _get_safe_alternatives(filename): try: cached_file = hf_hub_download( - repo_id=model_id, filename=safe_filename, revision=revision, cache_dir=cache_dir) + repo_id=model_id, + filename=safe_filename, + revision=revision, + cache_dir=cache_dir, + ) return cached_file except Exception: pass @@ -670,7 +743,11 @@ def download_pretrained_from_hf( try: # Attempt to download the file cached_file = hf_hub_download( - repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir) + repo_id=model_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + ) return cached_file # Return the path to the downloaded file if successful except Exception as e: raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}") @@ -678,17 +755,18 @@ def download_pretrained_from_hf( def download_pretrained( cfg: Dict, - force_hf_hub: bool = False, + prefer_hf_hub: bool = True, cache_dir: Optional[str] = None, ): target = '' if not cfg: return target + has_hub = has_hf_hub() download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') - if download_hf_hub and force_hf_hub: - # use HF hub even if url exists + if has_hub and prefer_hf_hub and download_hf_hub: + # prefer to use HF hub, remove url info download_url = '' if download_url: diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py index 867a6d5f3..6a8eeedb9 100644 --- a/src/open_clip/push_to_hf_hub.py +++ b/src/open_clip/push_to_hf_hub.py @@ -114,6 +114,7 @@ def push_to_hf_hub( try: repo_files = set(list_repo_files(repo_id)) repo_exists = True + print('Repo exists', repo_files) except Exception as e: print('Repo does not exist', e)