
Commit

add eps in config
Yanqing0327 committed Dec 17, 2024
1 parent fa9152f · commit d1cba05
Showing 24 changed files with 167 additions and 558 deletions.
6 changes: 2 additions & 4 deletions src/open_clip/coca_model.py
@@ -96,7 +96,6 @@ def __init__(
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
@@ -132,10 +131,9 @@ def __init__(
cast_dtype=cast_dtype,
)

lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
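The hunk above swaps the optional non-scalar logit scale for a plain scalar parameter. As a minimal standalone sketch (not this repository's code) of how the two shapes behave identically under broadcasting:

import numpy as np
import torch
import torch.nn as nn

# Scalar temperature parameter, stored as a log so it stays positive after exp().
init_logit_scale = np.log(1 / 0.07)
logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

# The nonscalar variant shown in the diff uses shape [1] instead of [].
logit_scale_1d = nn.Parameter(torch.ones([1]) * init_logit_scale)

# Both broadcast the same way when scaling a similarity matrix.
image_features = torch.nn.functional.normalize(torch.randn(4, 8), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(4, 8), dim=-1)
logits = logit_scale.exp() * image_features @ text_features.T
print(logits.shape)  # torch.Size([4, 4])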
5 changes: 0 additions & 5 deletions src/open_clip/constants.py
@@ -4,8 +4,3 @@
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
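For context, the mean/std tuples above are standard image-preprocessing constants; a small illustrative sketch of how they are typically applied to an image tensor. The IMAGENET_MEAN value is not visible in this hunk and is assumed here to be the usual (0.485, 0.456, 0.406):

import torch

IMAGENET_MEAN = (0.485, 0.456, 0.406)  # assumed; not shown in the hunk above
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize(img: torch.Tensor) -> torch.Tensor:
    # img: float tensor in [0, 1] with shape (3, H, W)
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return (img - mean) / std

print(normalize(torch.rand(3, 224, 224)).shape)  # torch.Size([3, 224, 224])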
48 changes: 19 additions & 29 deletions src/open_clip/convert.py
@@ -18,9 +18,7 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
@@ -68,28 +66,21 @@ def _convert_timm_img(module, prefix):

mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@@ -138,14 +129,13 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
if module.text_projection is not None:
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'img/')
_convert_openclip_txt(model.text, 'txt/')
model.logit_bias.copy_(_n2p(w['b'])[0])
model.logit_scale.copy_(_n2p(w['t'])[0])
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])


@torch.no_grad()
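The `_n2p` helper touched above maps big_vision (JAX/Flax) arrays into PyTorch layout; a condensed standalone sketch of the same idea, shown for the dense-kernel case (illustrative, not the repository's exact code):

import numpy as np
import torch

def n2p(w: np.ndarray, transpose: bool = True) -> torch.Tensor:
    # Flax sometimes stores 1x1x1xC params; squeeze those down to a vector.
    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
        w = w.flatten()
    if transpose:
        if w.ndim == 4:
            w = w.transpose([3, 2, 0, 1])   # HWIO -> OIHW for conv kernels
        elif w.ndim == 3:
            w = w.transpose([2, 0, 1])
        elif w.ndim == 2:
            w = w.transpose([1, 0])         # Flax Dense kernel (in, out) -> (out, in)
    return torch.from_numpy(np.ascontiguousarray(w))

dense_kernel = np.random.randn(768, 512).astype(np.float32)  # (in_features, out_features)
print(n2p(dense_kernel).shape)  # torch.Size([512, 768]), matching nn.Linear.weight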
3 changes: 0 additions & 3 deletions src/open_clip/factory.py
@@ -346,9 +346,6 @@ def create_model(
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
if 'CLIPS' in model_name:
model_cfg['vision_cfg']['eps'] = 1e-6
model_cfg['text_cfg']['eps'] = 1e-6
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

if precision in ("fp16", "bf16"):
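The removed branch hard-coded eps = 1e-6 for model names containing 'CLIPS'; per the commit message, that value is now meant to come from the model config itself. A hedged sketch of what that could look like (the keys below are assumptions based on this commit, not the repository's published schema):

# Hypothetical model_cfg following the general structure of open_clip JSON configs;
# the placement of the "eps" keys is an assumption, not confirmed by this diff.
model_cfg = {
    "embed_dim": 512,
    "vision_cfg": {"image_size": 224, "layers": 12, "width": 768, "patch_size": 16, "eps": 1e-6},
    "text_cfg": {"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "eps": 1e-6},
}

# With eps carried in the config, the factory no longer needs a model-name check:
#     if 'CLIPS' in model_name:
#         model_cfg['vision_cfg']['eps'] = 1e-6
#         model_cfg['text_cfg']['eps'] = 1e-6
# and can simply construct: model = CLIP(**model_cfg, cast_dtype=cast_dtype)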
80 changes: 23 additions & 57 deletions src/open_clip/loss.py
@@ -1,5 +1,3 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
@@ -104,14 +102,8 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features,
text_features,
local_loss=self.local_loss,
gather_with_grad=self.gather_with_grad,
rank=self.rank,
world_size=self.world_size,
use_horovod=self.use_horovod,
)
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
@@ -166,11 +158,12 @@ def __init__(
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):

clip_loss = torch.tensor(0)

if self.clip_loss_weight:
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
else:
clip_loss = torch.tensor(0, device=logits.device)

caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
@@ -323,17 +316,19 @@ class SigLipLoss(nn.Module):
"""
def __init__(
self,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1,
dist_impl: Optional[str] = None,
cache_labels=False,
rank=0,
world_size=1,
bidir=True,
use_horovod=False,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.bidir = bidir

# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
@@ -366,9 +361,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
loss = self._loss(image_features, text_features, logit_scale, logit_bias)

if self.world_size > 1:
if self.dist_impl == 'bidir':
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
# exchange text features w/ neighbour world_size - 1 times
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
if self.bidir:
text_features_to_right = text_features_to_left = text_features
num_bidir, remainder = divmod(self.world_size - 1, 2)
for i in range(num_bidir):
@@ -378,6 +374,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
text_features_to_left,
text_features_to_right,
)

for f in text_features_recv:
loss += self._loss(
image_features,
@@ -390,27 +387,21 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output

if remainder:
text_features_recv = neighbour_exchange_with_grad(
left_rank,
right_rank,
text_features_to_right
)
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_recv,
logit_scale,
logit_bias,
negative_only=True,
)
elif self.dist_impl == "shift":
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
else:
text_features_to_right = text_features
for i in range(self.world_size - 1):
text_features_from_left = neighbour_exchange_with_grad(
left_rank,
right_rank,
text_features_to_right,
)
left_rank, right_rank, text_features_to_right)

loss += self._loss(
image_features,
text_features_from_left,
@@ -419,30 +410,5 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
negative_only=True,
)
text_features_to_right = text_features_from_left
elif self.dist_impl == "reduce":
for i in range(self.world_size):
text_from_other = torch.distributed.nn.all_reduce(
text_features * (self.rank == i),
torch.distributed.ReduceOp.SUM,
)
loss += float(i != self.rank) * self._loss(
image_features,
text_from_other,
logit_scale,
logit_bias,
negative_only=True,
)
elif self.dist_impl == "gather":
all_text = torch.distributed.nn.all_gather(text_features)
for i in range(self.world_size):
loss += float(i != self.rank) * self._loss(
image_features,
all_text[i],
logit_scale,
logit_bias,
negative_only=True,
)
else:
assert False

return {"contrastive_loss": loss} if output_dict else loss
4 changes: 0 additions & 4 deletions src/open_clip/model.py
@@ -31,7 +31,6 @@ class CLIPVisionCfg:
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
eps: float = 1e-5

ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
@@ -77,7 +76,6 @@ class CLIPTextCfg:
output_tokens: bool = False
act_kwargs: dict = None
norm_kwargs: dict = None
eps: float = 1e-5

# HuggingFace specific text tower config
hf_model_name: Optional[str] = None
@@ -168,7 +166,6 @@ def _build_vision_tower(
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
eps=vision_cfg.eps,
)

return visual
@@ -218,7 +215,6 @@ def _build_text_tower(
output_tokens=text_cfg.output_tokens,
act_layer=act_layer,
norm_layer=norm_layer,
eps=text_cfg.eps,
)
return text

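The eps fields touched in CLIPVisionCfg / CLIPTextCfg are forwarded into the tower builders alongside the norm layer; a hedged sketch of one common way such a value is threaded through (the use of functools.partial here is an assumption, not necessarily how this repository wires it):

from functools import partial

import torch
import torch.nn as nn

# Hypothetical helper: build a norm_layer callable with the configured eps baked in,
# so downstream transformer blocks can keep calling norm_layer(width) as usual.
def build_norm_layer(eps: float = 1e-5):
    return partial(nn.LayerNorm, eps=eps)

norm_layer = build_norm_layer(eps=1e-6)
ln = norm_layer(768)          # equivalent to nn.LayerNorm(768, eps=1e-6)
print(ln.eps)                 # 1e-06
print(ln(torch.randn(2, 768)).shape)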
22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x16-quickgelu.json

This file was deleted.

22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x4-quickgelu.json

This file was deleted.

22 changes: 0 additions & 22 deletions src/open_clip/model_configs/RN50x64-quickgelu.json

This file was deleted.

17 changes: 0 additions & 17 deletions src/open_clip/model_configs/ViT-H-14-378.json

This file was deleted.
