Add CLIPS to open_clip
Yanqing0327 committed Dec 13, 2024
1 parent aeaf2a0 commit fa9152f
Showing 7 changed files with 239 additions and 9 deletions.
16 changes: 12 additions & 4 deletions src/open_clip/factory.py
@@ -18,7 +18,7 @@
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
-from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
+from .tokenizer import HFTokenizer, SimpleTokenizer, CLIPS_Tokenizer, DEFAULT_CONTEXT_LENGTH
 
 HF_HUB_PREFIX = 'hf-hub:'
 _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -123,12 +123,17 @@ def get_tokenizer(
     context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
 
     if 'hf_tokenizer_name' in text_config:
-        tokenizer = HFTokenizer(
-            text_config['hf_tokenizer_name'],
+        if 'CLIPS' in model_name:
+            tokenizer = CLIPS_Tokenizer(
             context_length=context_length,
             cache_dir=cache_dir,
             **tokenizer_kwargs,
         )
+        else:
+            tokenizer = HFTokenizer(
+                text_config['hf_tokenizer_name'],
+                context_length=context_length,
+                **tokenizer_kwargs,
+            )
     else:
         tokenizer = SimpleTokenizer(
             context_length=context_length,
@@ -341,6 +346,9 @@ def create_model(
         else:
             model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
     else:
+        if 'CLIPS' in model_name:
+            model_cfg['vision_cfg']['eps'] = 1e-6
+            model_cfg['text_cfg']['eps'] = 1e-6
         model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
     if precision in ("fp16", "bf16"):
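Taken together with the configs added below, this dispatch means any model whose name contains the substring 'CLIPS' gets the new WordPiece tokenizer and 1e-6 LayerNorm epsilons. A usage sketch, assuming the nested model_cfg/preprocess_cfg config format resolves through the local config scan (the commit registers no pretrained weights, and the tokenizer expects a vocab.txt in the working directory):

import open_clip

# model name comes from this commit's configs; weights must be loaded separately
tokenizer = open_clip.get_tokenizer('ViT-L-14-CLIPS-224')  # -> CLIPS_Tokenizer
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14-CLIPS-224')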
4 changes: 4 additions & 0 deletions src/open_clip/model.py
@@ -31,6 +31,7 @@ class CLIPVisionCfg:
     mlp_ratio: float = 4.0
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
+    eps: float = 1e-5
 
     ls_init_value: Optional[float] = None  # layer scale initial value
     patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
@@ -76,6 +77,7 @@ class CLIPTextCfg:
     output_tokens: bool = False
     act_kwargs: dict = None
     norm_kwargs: dict = None
+    eps: float = 1e-5
 
     # HuggingFace specific text tower config
     hf_model_name: Optional[str] = None
@@ -166,6 +168,7 @@ def _build_vision_tower(
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            eps=vision_cfg.eps,
         )
 
     return visual
@@ -215,6 +218,7 @@ def _build_text_tower(
             output_tokens=text_cfg.output_tokens,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            eps=text_cfg.eps,
         )
     return text
 
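Because the new eps fields default to LayerNorm's previous value, configs that never mention eps build exactly the same models as before; only the explicit 1e-6 override in create_model changes anything. A quick sanity check of the defaults:

from open_clip.model import CLIPVisionCfg, CLIPTextCfg

assert CLIPVisionCfg().eps == 1e-5  # unchanged default for non-CLIPS models
assert CLIPTextCfg().eps == 1e-5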
45 changes: 45 additions & 0 deletions src/open_clip/model_configs/ViT-H-14-CLIPS-224.json
@@ -0,0 +1,45 @@
{
    "model_cfg": {
        "embed_dim": 1024,
        "vision_cfg": {
            "image_size": 224,
            "layers": 32,
            "width": 1280,
            "head_width": 80,
            "patch_size": 14,
            "no_ln_pre": true,
            "pool_type": "avg",
            "final_ln_after_pool": true
        },
        "text_cfg": {
            "context_length": 80,
            "vocab_size": 32000,
            "hf_tokenizer_name": "bert-base-uncased",
            "tokenizer_kwargs": {
                "strip_sep_token": true
            },
            "width": 1024,
            "heads": 16,
            "layers": 24,
            "pool_type": "last",
            "no_causal_mask": true,
            "act_kwargs": {
                "approximate": "tanh"
            }
        }
    },
    "preprocess_cfg": {
        "mean": [
            0.485,
            0.456,
            0.406
        ],
        "std": [
            0.229,
            0.224,
            0.225
        ],
        "interpolation": "bilinear",
        "resize_mode": "squash"
    }
}
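Two details worth noting: the vision tower's head count is implied by width / head_width = 1280 / 80 = 16, and the preprocessing uses ImageNet mean/std rather than the OpenAI CLIP statistics, with a 'squash' resize that does not preserve aspect ratio. A torchvision approximation of the resulting eval transform (a sketch; open_clip builds its own transform from preprocess_cfg):

from torchvision import transforms

preprocess = transforms.Compose([
    # 'squash' resize: force 224x224 without preserving aspect ratio
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])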
44 changes: 44 additions & 0 deletions src/open_clip/model_configs/ViT-L-14-CLIPS-224.json
@@ -0,0 +1,44 @@
{
    "model_cfg": {
        "embed_dim": 768,
        "vision_cfg": {
            "image_size": 224,
            "layers": 24,
            "width": 1024,
            "patch_size": 14,
            "no_ln_pre": true,
            "pool_type": "avg",
            "final_ln_after_pool": true
        },
        "text_cfg": {
            "context_length": 80,
            "vocab_size": 32000,
            "hf_tokenizer_name": "bert-base-uncased",
            "tokenizer_kwargs": {
                "strip_sep_token": true
            },
            "width": 768,
            "heads": 12,
            "layers": 12,
            "pool_type": "last",
            "no_causal_mask": true,
            "act_kwargs": {
                "approximate": "tanh"
            }
        }
    },
    "preprocess_cfg": {
        "mean": [
            0.485,
            0.456,
            0.406
        ],
        "std": [
            0.229,
            0.224,
            0.225
        ],
        "interpolation": "bilinear",
        "resize_mode": "squash"
    }
}
44 changes: 44 additions & 0 deletions src/open_clip/model_configs/ViT-L-14-CLIPS-336.json
@@ -0,0 +1,44 @@
{
    "model_cfg": {
        "embed_dim": 768,
        "vision_cfg": {
            "image_size": 336,
            "layers": 24,
            "width": 1024,
            "patch_size": 14,
            "no_ln_pre": true,
            "pool_type": "avg",
            "final_ln_after_pool": true
        },
        "text_cfg": {
            "context_length": 80,
            "vocab_size": 32000,
            "hf_tokenizer_name": "bert-base-uncased",
            "tokenizer_kwargs": {
                "strip_sep_token": true
            },
            "width": 768,
            "heads": 12,
            "layers": 12,
            "pool_type": "last",
            "no_causal_mask": true,
            "act_kwargs": {
                "approximate": "tanh"
            }
        }
    },
    "preprocess_cfg": {
        "mean": [
            0.485,
            0.456,
            0.406
        ],
        "std": [
            0.229,
            0.224,
            0.225
        ],
        "interpolation": "bilinear",
        "resize_mode": "squash"
    }
}
83 changes: 83 additions & 0 deletions src/open_clip/tokenizer.py
@@ -400,6 +400,89 @@ def get_reduction_mask_fn(type: str):
         return syntax_mask_tokenize  # randomly drop prioritized by syntax
 
 
+from tokenizers import BertWordPieceTokenizer
+
+
+class CustomTokenizer:
+    """Custom tokenizer using WordPiece-based subword tokenization."""
+
+    def __init__(self, vocab_file, context_length=512, bos_token=1, eos_token=2, class_token=101, pad_token=0):
+        # Load the WordPiece vocab directly; lowercase to match the BERT-uncased vocab.
+        self.tokenizer = BertWordPieceTokenizer.from_file(vocab_file, lowercase=True)
+        self.context_length = context_length
+        self.bos_token = bos_token
+        self.eos_token = eos_token
+        self.class_token = class_token
+        self.pad_token = pad_token
+
+    def tokenize(self, text):
+        encoding = self.tokenizer.encode(text, add_special_tokens=False)
+        # Reserve three slots: BOS, EOS, and the class token appended after padding.
+        tokens = encoding.ids[:self.context_length - 3]
+        return [self.bos_token] + tokens + [self.eos_token]
+
+    def batch_encode_plus(self, texts, max_length=None):
+        max_length = max_length or self.context_length
+        encoded = [self.tokenize(text) for text in texts]
+        return {
+            'input_ids': torch.tensor([self.pad_and_add_class_token(e, max_length) for e in encoded])
+        }
+
+    def pad_and_add_class_token(self, encoded_text, max_length):
+        # Pad to max_length - 1, then put the class token in the final slot.
+        if len(encoded_text) < max_length - 1:
+            encoded_text += [self.pad_token] * (max_length - 1 - len(encoded_text))
+        return encoded_text + [self.class_token]
+
+
+class CLIPS_Tokenizer:
+    """Wrapper exposing the CLIPS WordPiece tokenizer through the open_clip tokenizer interface."""
+
+    def __init__(
+            self,
+            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
+            clean: str = 'whitespace',
+            strip_sep_token: bool = False,
+            language: Optional[str] = None,
+            **kwargs
+    ):
+        vocab_file = './vocab.txt'  # WordPiece vocab expected in the working directory
+        self.tokenizer = CustomTokenizer(
+            vocab_file, context_length=context_length,
+            bos_token=1, eos_token=2, class_token=101, pad_token=0)
+        print("Loaded CLIPS tokenizer.")
+        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
+        if callable(set_lang_fn):
+            self.set_lang_fn = set_lang_fn
+        if language is not None:
+            self.set_language(language)
+        self.context_length = context_length
+        self.clean_fn = get_clean_fn(clean)
+        self.strip_sep_token = strip_sep_token
+
+    def save_pretrained(self, dest):
+        # Delegate to the underlying WordPiece tokenizer, which writes vocab.txt to dest.
+        self.tokenizer.tokenizer.save_model(dest)
+
+    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
+        # Same cleaning as for the default tokenizer; lowercasing is handled
+        # by the WordPiece tokenizer itself.
+        if isinstance(texts, str):
+            texts = [texts]
+        context_length = context_length or self.context_length
+        assert context_length, 'Please set a valid context length in class init or call.'
+
+        texts = [self.clean_fn(text) for text in texts]
+        encoded_outputs = self.tokenizer.batch_encode_plus(
+            texts,
+            max_length=context_length,
+        )
+        return encoded_outputs['input_ids']
+
+    def set_language(self, src_lang):
+        if hasattr(self, 'set_lang_fn'):
+            self.set_lang_fn(src_lang)
+        else:
+            warnings.warn('Cannot set language for the tokenizer.')
+
+
 class HFTokenizer:
     """HuggingFace tokenizer wrapper"""
 
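A usage sketch for the new tokenizer (it assumes a BERT WordPiece vocab.txt in the working directory, since the path is hardcoded above):

from open_clip.tokenizer import CLIPS_Tokenizer

tokenizer = CLIPS_Tokenizer(context_length=80)
ids = tokenizer(["a photo of a cat", "a photo of a dog"])
print(ids.shape)  # torch.Size([2, 80]); the final slot holds the class token (101)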
12 changes: 7 additions & 5 deletions src/open_clip/transformer.py
@@ -455,6 +455,7 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            eps: float = 1e-5,  # LayerNorm epsilon (CLIPS models override to 1e-6)
     ):
         super().__init__()
         assert pool_type in ('tok', 'avg', 'none')
@@ -487,15 +488,15 @@ def __init__(
         # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
         self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
 
-        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
+        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width, eps=eps)
         self.transformer = Transformer(
             width,
             layers,
             heads,
             mlp_ratio,
             ls_init_value=ls_init_value,
             act_layer=act_layer,
-            norm_layer=norm_layer,
+            norm_layer=lambda x: norm_layer(x, eps=eps),
         )
 
         if attentional_pool:
@@ -533,7 +534,7 @@ def __init__(
             pool_dim = width
         self.pool_type = pool_type
 
-        self.ln_post = norm_layer(pool_dim)
+        self.ln_post = norm_layer(pool_dim, eps=eps)
         self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
 
         self.init_parameters()
@@ -693,6 +694,7 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            eps: float = 1e-5,  # LayerNorm epsilon (CLIPS models override to 1e-6)
     ):
         super().__init__()
         assert pool_type in ('first', 'last', 'argmax', 'none')
@@ -719,9 +721,9 @@
             mlp_ratio=mlp_ratio,
             ls_init_value=ls_init_value,
             act_layer=act_layer,
-            norm_layer=norm_layer,
+            norm_layer=lambda x: norm_layer(x, eps=eps),
         )
-        self.ln_final = norm_layer(width)
+        self.ln_final = norm_layer(width, eps=eps)
 
         if no_causal_mask:
             self.attn_mask = None
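The lambda x: norm_layer(x, eps=eps) pattern above threads eps into every LayerNorm the inner Transformer constructs. A functools.partial achieves the same binding and, unlike a lambda, can be pickled; a minimal sketch (not part of the commit):

from functools import partial

import torch.nn as nn

# partial binds eps once; norm_layer(width) then behaves like nn.LayerNorm(width, eps=1e-6)
norm_layer = partial(nn.LayerNorm, eps=1e-6)
ln = norm_layer(1024)
print(ln.eps)  # 1e-06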
