
Commit

All default pretrained weights pushed to HF hub, stragglers uploaded to timm org for simplicity.

* OpenAI models no longer use a special path that loads from a torchscript archive; they go through the same path as other models
* QuickGELU handling is consistent between OpenAI and non-OpenAI models, made a bit safer, with a warning on mismatch (see the sketch below)
* safetensors is the default weight format for loading, if available
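A minimal usage sketch of the unified loading path (illustrative only; the 'ViT-B-32' model name and the pretrained tags shown are examples, not part of this commit):

import open_clip

# OpenAI weights now load through the same code path as any other pretrained tag;
# safetensors weights are preferred over the pickled .pt checkpoint when available.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')

# Forcing QuickGELU on weights that were not trained with it (or vice versa) now
# emits a warning instead of silently mismatching activations.
model = open_clip.create_model('ViT-B-32', pretrained='laion2b_s34b_b79k', force_quick_gelu=True)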
rwightman committed Oct 23, 2024
1 parent 427c434 commit 2a61d34
Showing 8 changed files with 413 additions and 182 deletions.
252 changes: 152 additions & 100 deletions src/open_clip/factory.py
@@ -2,6 +2,7 @@
import logging
import os
import re
import warnings
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
@@ -222,8 +223,58 @@ def create_model(
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
load_weights_only: bool = True,
**model_kwargs,
):
"""Creates and configures a contrastive vision-language model.
Args:
model_name: Name of the model architecture to create. Can be a local model name
or a Hugging Face model ID prefixed with 'hf-hub:'.
pretrained: Tag/path for pretrained model weights. Can be:
- A pretrained tag name (e.g., 'openai')
- A path to local weights
- None to initialize with random weights
precision: Model precision/AMP configuration. Options:
- 'fp32': 32-bit floating point
- 'fp16'/'bf16': Mixed precision with FP32 for certain layers
- 'pure_fp16'/'pure_bf16': Pure 16-bit precision
device: Device to load the model on ('cpu', 'cuda', or torch.device object)
jit: If True, JIT compile the model
force_quick_gelu: Force use of QuickGELU activation
force_custom_text: Force use of custom text encoder
force_patch_dropout: Override default patch dropout value
force_image_size: Override default image size for vision encoder
force_preprocess_cfg: Override default preprocessing configuration
pretrained_image: Load pretrained weights for timm vision models
pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
cache_dir: Override default cache directory for downloaded model files
output_dict: If True and model supports it, return dictionary of features
require_pretrained: Raise error if pretrained weights cannot be loaded
load_weights_only: Only deserialize model weights when unpickling torch checkpoints (for safety)
**model_kwargs: Additional keyword arguments passed to model constructor
Returns:
Created and configured model instance
Raises:
RuntimeError: If model config is not found or required pretrained weights
cannot be loaded
Examples:
# Create basic CLIP model
model = create_model('ViT-B-32')
# Create CLIP model with mixed precision on GPU
model = create_model('ViT-B-32', precision='fp16', device='cuda')
# Load pretrained OpenAI weights
model = create_model('ViT-B-32', pretrained='openai')
# Load Hugging Face model
model = create_model('hf-hub:organization/model-name')
"""

    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
@@ -242,112 +293,113 @@ def create_model(
    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = model_cfg or get_model_config(model_name)
    if model_cfg is not None:
        logging.info(f'Loaded {model_name} model config.')
    else:
        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
        raise RuntimeError(f'Model config for {model_name} not found.')

    if force_quick_gelu:
        # override for use of QuickGELU on non-OpenAI transformer models
        model_cfg["quick_gelu"] = True

    if force_patch_dropout is not None:
        # override the default patch dropout value
        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

    if force_image_size is not None:
        # override model config's image size
        model_cfg["vision_cfg"]["image_size"] = force_image_size

    is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
    if pretrained_image:
        if is_timm_model:
            # pretrained weight loading for timm models set via vision_cfg
            model_cfg['vision_cfg']['timm_model_pretrained'] = True
        else:
            assert False, 'pretrained image towers currently only supported for timm models'

    # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
    cast_dtype = get_cast_dtype(precision)
    is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_hf_model:
        # load pretrained weights for HF text model IFF no CLIP weights being loaded
        model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

    model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
    if custom_text:
        if "multimodal_cfg" in model_cfg:
            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
    else:
        model = CLIP(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        # manual mixed precision that matches original OpenAI behaviour
        if is_timm_model:
            # FIXME this is a bit janky, create timm based model in low-precision and
            # then cast only LayerNormFp32 instances back to float32 so they don't break.
            # Why? The convert_weights_to_lp fn only works with native models.
            model.to(device=device, dtype=dtype)
            from .transformer import LayerNormFp32

            def _convert_ln(m):
                if isinstance(m, LayerNormFp32):
                    m.weight.data = m.weight.data.to(torch.float32)
                    m.bias.data = m.bias.data.to(torch.float32)
            model.apply(_convert_ln)
        else:
            model.to(device=device)
            convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    pretrained_loaded = False
    if pretrained:
        checkpoint_path = ''
        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
        if pretrained_cfg:
            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
            model_quick_gelu = model_cfg.get('quick_gelu', False)
            if pretrained_quick_gelu and not model_quick_gelu:
                warnings.warn(
                    f'These pretrained weights were trained with QuickGELU activation but the model config does '
                    f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
            elif not pretrained_quick_gelu and model_quick_gelu:
                warnings.warn(
                    f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
                    f'model config, consider using a model config without QuickGELU or disable override flags.')
        elif os.path.exists(pretrained):
            checkpoint_path = pretrained

        if checkpoint_path:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        else:
            error_str = (
                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
            logging.warning(error_str)
            raise RuntimeError(error_str)
        pretrained_loaded = True
    elif has_hf_hub_prefix:
        logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
        load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
        pretrained_loaded = True

    if require_pretrained and not pretrained_loaded:
        # callers of create_model_from_pretrained always expect pretrained weights
        raise RuntimeError(
            f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True
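For orientation, the safetensors-by-default and load_weights_only behaviour amounts to roughly the loader below. This is a simplified sketch under stated assumptions, not the actual open_clip load_checkpoint implementation; the helper name load_checkpoint_sketch is hypothetical:

import torch

try:
    from safetensors.torch import load_file as load_safetensors
    _HAS_SAFETENSORS = True
except ImportError:
    _HAS_SAFETENSORS = False

def load_checkpoint_sketch(model, checkpoint_path: str, weights_only: bool = True):
    """Load a checkpoint into `model`, preferring the safetensors format when present."""
    if _HAS_SAFETENSORS and checkpoint_path.endswith('.safetensors'):
        # safetensors files contain only tensor data, so no arbitrary code runs on load
        state_dict = load_safetensors(checkpoint_path, device='cpu')
    else:
        # weights_only=True tells torch.load to refuse arbitrary pickled objects
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=weights_only)
        state_dict = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
    # strip a possible DistributedDataParallel 'module.' prefix before loading
    state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}
    return model.load_state_dict(state_dict, strict=True)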
22 changes: 22 additions & 0 deletions src/open_clip/model_configs/RN50x16-quickgelu.json
@@ -0,0 +1,22 @@
{
"embed_dim": 768,
"quick_gelu": true,
"vision_cfg": {
"image_size": 384,
"layers": [
6,
8,
18,
8
],
"width": 96,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
22 changes: 22 additions & 0 deletions src/open_clip/model_configs/RN50x4-quickgelu.json
@@ -0,0 +1,22 @@
{
"embed_dim": 640,
"quick_gelu": true,
"vision_cfg": {
"image_size": 288,
"layers": [
4,
6,
10,
6
],
"width": 80,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
22 changes: 22 additions & 0 deletions src/open_clip/model_configs/RN50x64-quickgelu.json
@@ -0,0 +1,22 @@
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 448,
"layers": [
3,
15,
36,
10
],
"width": 128,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 12
}
}
17 changes: 17 additions & 0 deletions src/open_clip/model_configs/ViT-H-14-378.json
@@ -0,0 +1,17 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 378,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
17 changes: 17 additions & 0 deletions src/open_clip/model_configs/ViT-L-14-336-quickgelu.json
@@ -0,0 +1,17 @@
{
"embed_dim": 768,
"quick_gelu": true,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
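Selecting one of the new -quickgelu configs is just a matter of the config name. A sketch, assuming the 'openai' pretrained tag is registered for this config on the hub:

import open_clip

# The '-quickgelu' suffix picks a config with "quick_gelu": true baked in, matching the
# activation the original OpenAI ViT-L/14-336px weights were trained with.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14-336-quickgelu', pretrained='openai')
tokenizer = open_clip.get_tokenizer('ViT-L-14-336-quickgelu')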
