Add CLIPS to open_clip #1008

Status: Open. Wants to merge 2 commits into base: main.
6 changes: 2 additions & 4 deletions src/open_clip/coca_model.py
@@ -96,7 +96,6 @@ def __init__(
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
@@ -132,10 +131,9 @@ def __init__(
cast_dtype=cast_dtype,
)

lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
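For reference, after this change CoCa keeps only the scalar (0-dim) form of the learnable temperature and bias. A minimal sketch of the equivalent initialization; the class name and standalone module are illustrative, not part of the PR:

```python
from typing import Optional

import numpy as np
import torch
import torch.nn as nn


class ScalarLogitParams(nn.Module):
    """Illustrative sketch: scalar logit_scale / logit_bias as kept on this branch."""

    def __init__(self, init_logit_scale: float = np.log(1 / 0.07),
                 init_logit_bias: Optional[float] = None):
        super().__init__()
        # 0-dim learnable parameter; exp(logit_scale) starts at ~14.29 (temperature 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        self.logit_bias = (
            nn.Parameter(torch.ones([]) * init_logit_bias)
            if init_logit_bias is not None else None
        )
```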
5 changes: 0 additions & 5 deletions src/open_clip/constants.py
@@ -4,8 +4,3 @@
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
48 changes: 19 additions & 29 deletions src/open_clip/convert.py
@@ -18,9 +18,7 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
@@ -68,28 +66,21 @@ def _convert_timm_img(module, prefix):

mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@@ -138,14 +129,13 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
if module.text_projection is not None:
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'img/')
_convert_openclip_txt(model.text, 'txt/')
model.logit_bias.copy_(_n2p(w['b'])[0])
model.logit_scale.copy_(_n2p(w['t'])[0])
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])


@torch.no_grad()
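The rest of `_n2p` falls outside the hunk shown above. As background, here is a hedged sketch of what such a JAX-to-PyTorch conversion helper typically does, modeled on the timm convention this loader follows rather than copied from the branch:

```python
import numpy as np
import torch


def n2p_sketch(w: np.ndarray, t: bool = True) -> torch.Tensor:
    """Illustrative only: convert a big_vision/JAX array to a PyTorch tensor.
    1x1x1xC conv kernels collapse to 1-D; other kernels are transposed from
    JAX's (..., in, out) layout to PyTorch's (out, in, ...) layout."""
    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
        w = w.flatten()
    if t:
        if w.ndim == 4:
            w = w.transpose([3, 2, 0, 1])
        elif w.ndim == 3:
            w = w.transpose([2, 0, 1])
        elif w.ndim == 2:
            w = w.transpose([1, 0])
    return torch.from_numpy(np.ascontiguousarray(w))
```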
13 changes: 9 additions & 4 deletions src/open_clip/factory.py
@@ -18,7 +18,7 @@
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
from .tokenizer import HFTokenizer, SimpleTokenizer, CLIPS_Tokenizer, DEFAULT_CONTEXT_LENGTH

HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -123,12 +123,17 @@ def get_tokenizer(
context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

if 'hf_tokenizer_name' in text_config:
tokenizer = HFTokenizer(
text_config['hf_tokenizer_name'],
if 'CLIPS' in model_name:
tokenizer = CLIPS_Tokenizer(
context_length=context_length,
cache_dir=cache_dir,
**tokenizer_kwargs,
)
else:
tokenizer = HFTokenizer(
text_config['hf_tokenizer_name'],
context_length=context_length,
**tokenizer_kwargs,
)
else:
tokenizer = SimpleTokenizer(
context_length=context_length,
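With this dispatch in place, tokenizer selection is keyed off the model name: any model whose name contains "CLIPS" resolves to the new CLIPS_Tokenizer, while other HF-configured models keep using HFTokenizer. A hedged usage sketch; the model tag below is hypothetical, check the configs added by this PR for the real names:

```python
import open_clip

# Hypothetical CLIPS model tag; real names come from the PR's model_configs.
tokenizer = open_clip.get_tokenizer('ViT-L-14-CLIPS')

# Both tokenizer classes are callable on a list of strings and return a
# LongTensor of shape [batch, context_length].
tokens = tokenizer(["a photo of a cat", "a photo of a dog"])
```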
80 changes: 23 additions & 57 deletions src/open_clip/loss.py
@@ -1,5 +1,3 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
@@ -104,14 +102,8 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features,
text_features,
local_loss=self.local_loss,
gather_with_grad=self.gather_with_grad,
rank=self.rank,
world_size=self.world_size,
use_horovod=self.use_horovod,
)
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
@@ -166,11 +158,12 @@ def __init__(
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):

clip_loss = torch.tensor(0)

if self.clip_loss_weight:
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
else:
clip_loss = torch.tensor(0, device=logits.device)

caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
@@ -323,17 +316,19 @@ class SigLipLoss(nn.Module):
"""
def __init__(
self,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1,
dist_impl: Optional[str] = None,
cache_labels=False,
rank=0,
world_size=1,
bidir=True,
use_horovod=False,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.bidir = bidir

# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
@@ -366,9 +361,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
loss = self._loss(image_features, text_features, logit_scale, logit_bias)

if self.world_size > 1:
if self.dist_impl == 'bidir':
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
# exchange text features w/ neighbour world_size - 1 times
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
if self.bidir:
text_features_to_right = text_features_to_left = text_features
num_bidir, remainder = divmod(self.world_size - 1, 2)
for i in range(num_bidir):
@@ -378,6 +374,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
text_features_to_left,
text_features_to_right,
)

for f in text_features_recv:
loss += self._loss(
image_features,
@@ -390,27 +387,21 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output

if remainder:
text_features_recv = neighbour_exchange_with_grad(
left_rank,
right_rank,
text_features_to_right
)
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_recv,
logit_scale,
logit_bias,
negative_only=True,
)
elif self.dist_impl == "shift":
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
else:
text_features_to_right = text_features
for i in range(self.world_size - 1):
text_features_from_left = neighbour_exchange_with_grad(
left_rank,
right_rank,
text_features_to_right,
)
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_from_left,
@@ -419,30 +410,5 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
negative_only=True,
)
text_features_to_right = text_features_from_left
elif self.dist_impl == "reduce":
for i in range(self.world_size):
text_from_other = torch.distributed.nn.all_reduce(
text_features * (self.rank == i),
torch.distributed.ReduceOp.SUM,
)
loss += float(i != self.rank) * self._loss(
image_features,
text_from_other,
logit_scale,
logit_bias,
negative_only=True,
)
elif self.dist_impl == "gather":
all_text = torch.distributed.nn.all_gather(text_features)
for i in range(self.world_size):
loss += float(i != self.rank) * self._loss(
image_features,
all_text[i],
logit_scale,
logit_bias,
negative_only=True,
)
else:
assert False

return {"contrastive_loss": loss} if output_dict else loss
22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x16-quickgelu.json

This file was deleted.

22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x4-quickgelu.json

This file was deleted.

22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x64-quickgelu.json

This file was deleted.

17 changes: 0 additions & 17 deletions src/open_clip/model_configs/ViT-H-14-378.json

This file was deleted.
