From fa9152fcc09303bb9a8cd143595bd4ea71e007c5 Mon Sep 17 00:00:00 2001
From: Yanqing0327
Date: Thu, 12 Dec 2024 22:25:19 -0800
Subject: [PATCH] Add CLIPS to open_clip

---
 src/open_clip/factory.py                      | 16 +++-
 src/open_clip/model.py                        |  4 +
 .../model_configs/ViT-H-14-CLIPS-224.json     | 45 ++++++++++
 .../model_configs/ViT-L-14-CLIPS-224.json     | 44 ++++++++++
 .../model_configs/ViT-L-14-CLIPS-336.json     | 44 ++++++++++
 src/open_clip/tokenizer.py                    | 83 +++++++++++++++++++
 src/open_clip/transformer.py                  | 12 +--
 7 files changed, 239 insertions(+), 9 deletions(-)
 create mode 100644 src/open_clip/model_configs/ViT-H-14-CLIPS-224.json
 create mode 100644 src/open_clip/model_configs/ViT-L-14-CLIPS-224.json
 create mode 100644 src/open_clip/model_configs/ViT-L-14-CLIPS-336.json

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index c6a9e9eac..49f59a6e3 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -18,7 +18,7 @@ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
-from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
+from .tokenizer import HFTokenizer, SimpleTokenizer, CLIPS_Tokenizer, DEFAULT_CONTEXT_LENGTH
 
 HF_HUB_PREFIX = 'hf-hub:'
 _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -123,12 +123,17 @@ def get_tokenizer(
         context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
 
     if 'hf_tokenizer_name' in text_config:
-        tokenizer = HFTokenizer(
-            text_config['hf_tokenizer_name'],
+        if 'CLIPS' in model_name:
+            tokenizer = CLIPS_Tokenizer(
             context_length=context_length,
-            cache_dir=cache_dir,
             **tokenizer_kwargs,
         )
+        else:
+            tokenizer = HFTokenizer(
+                text_config['hf_tokenizer_name'],
+                context_length=context_length,
+                **tokenizer_kwargs,
+            )
     else:
         tokenizer = SimpleTokenizer(
             context_length=context_length,
@@ -341,6 +346,9 @@ def create_model(
         else:
             model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
     else:
+        if 'CLIPS' in model_name:
+            model_cfg['vision_cfg']['eps'] = 1e-6
+            model_cfg['text_cfg']['eps'] = 1e-6
         model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
     if precision in ("fp16", "bf16"):
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 9a9443603..bf2a02c0e 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -31,6 +31,7 @@ class CLIPVisionCfg:
     mlp_ratio: float = 4.0
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
+    eps: float = 1e-5
 
     ls_init_value: Optional[float] = None  # layer scale initial value
     patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
@@ -76,6 +77,7 @@ class CLIPTextCfg:
     output_tokens: bool = False
     act_kwargs: dict = None
     norm_kwargs: dict = None
+    eps: float = 1e-5
 
     # HuggingFace specific text tower config
     hf_model_name: Optional[str] = None
@@ -166,6 +168,7 @@ def _build_vision_tower(
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            eps=vision_cfg.eps,
         )
 
     return visual
@@ -215,6 +218,7 @@ def _build_text_tower(
             output_tokens=text_cfg.output_tokens,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            eps=text_cfg.eps,
         )
 
     return text
diff --git a/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json b/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json
new file mode 100644
index 000000000..92433c625
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-H-14-CLIPS-224.json
@@ -0,0 +1,45 @@
+{
+    "model_cfg": {
+        "embed_dim": 1024,
+        "vision_cfg": {
+            "image_size": 224,
+            "layers": 32,
+            "width": 1280,
+            "head_width": 80,
+            "patch_size": 14,
+            "no_ln_pre": true,
+            "pool_type": "avg",
+            "final_ln_after_pool": true
+        },
+        "text_cfg": {
+            "context_length": 80,
+            "vocab_size": 32000,
+            "hf_tokenizer_name": "bert-base-uncased",
+            "tokenizer_kwargs": {
+                "strip_sep_token": true
+            },
+            "width": 1024,
+            "heads": 16,
+            "layers": 24,
+            "pool_type": "last",
+            "no_causal_mask": true,
+            "act_kwargs": {
+                "approximate": "tanh"
+            }
+        }
+    },
+    "preprocess_cfg": {
+        "mean": [
+            0.485,
+            0.456,
+            0.406
+        ],
+        "std": [
+            0.229,
+            0.224,
+            0.225
+        ],
+        "interpolation": "bilinear",
+        "resize_mode": "squash"
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json b/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json
new file mode 100644
index 000000000..94f75cd7d
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-L-14-CLIPS-224.json
@@ -0,0 +1,44 @@
+{
+    "model_cfg": {
+        "embed_dim": 768,
+        "vision_cfg": {
+            "image_size": 224,
+            "layers": 24,
+            "width": 1024,
+            "patch_size": 14,
+            "no_ln_pre": true,
+            "pool_type": "avg",
+            "final_ln_after_pool": true
+        },
+        "text_cfg": {
+            "context_length": 80,
+            "vocab_size": 32000,
+            "hf_tokenizer_name": "bert-base-uncased",
+            "tokenizer_kwargs": {
+                "strip_sep_token": true
+            },
+            "width": 768,
+            "heads": 12,
+            "layers": 12,
+            "pool_type": "last",
+            "no_causal_mask": true,
+            "act_kwargs": {
+                "approximate": "tanh"
+            }
+        }
+    },
+    "preprocess_cfg": {
+        "mean": [
+            0.485,
+            0.456,
+            0.406
+        ],
+        "std": [
+            0.229,
+            0.224,
+            0.225
+        ],
+        "interpolation": "bilinear",
+        "resize_mode": "squash"
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json b/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json
new file mode 100644
index 000000000..2cef0d2ce
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-L-14-CLIPS-336.json
@@ -0,0 +1,44 @@
+{
+    "model_cfg": {
+        "embed_dim": 768,
+        "vision_cfg": {
+            "image_size": 336,
+            "layers": 24,
+            "width": 1024,
+            "patch_size": 14,
+            "no_ln_pre": true,
+            "pool_type": "avg",
+            "final_ln_after_pool": true
+        },
+        "text_cfg": {
+            "context_length": 80,
+            "vocab_size": 32000,
+            "hf_tokenizer_name": "bert-base-uncased",
+            "tokenizer_kwargs": {
+                "strip_sep_token": true
+            },
+            "width": 768,
+            "heads": 12,
+            "layers": 12,
+            "pool_type": "last",
+            "no_causal_mask": true,
+            "act_kwargs": {
+                "approximate": "tanh"
+            }
+        }
+    },
+    "preprocess_cfg": {
+        "mean": [
+            0.485,
+            0.456,
+            0.406
+        ],
+        "std": [
+            0.229,
+            0.224,
+            0.225
+        ],
"interpolation": "bilinear", + "resize_mode": "squash" + } + } \ No newline at end of file diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 872c1833b..a469c6a87 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -400,6 +400,89 @@ def get_reduction_mask_fn(type: str): return syntax_mask_tokenize # randomly drop prioritized by syntax +from tokenizers import BertWordPieceTokenizer + +class CustomTokenizer: + """Custom tokenizer using WordPiece-based subword tokenization""" + + def __init__(self, vocab_file, context_length=512, bos_token=1, eos_token=2, class_token=101, pad_token=0): + self.tokenizer = BertWordPieceTokenizer(lowercase=True) + self.tokenizer = self.tokenizer.from_file(vocab_file) + self.context_length = context_length + self.bos_token = bos_token + self.eos_token = eos_token + self.class_token = class_token + self.pad_token = pad_token + + def tokenize(self, text): + encoding = self.tokenizer.encode(text, add_special_tokens=False) + tokens = encoding.ids[:self.context_length - 3] + return [self.bos_token] + tokens + [self.eos_token] + + def batch_encode_plus(self, texts, max_length=None): + max_length = max_length or self.context_length + encoded = [self.tokenize(text) for text in texts] + import torch + return { + 'input_ids': torch.tensor([self.pad_and_add_class_token(e, max_length) for e in encoded]) + } + + def pad_and_add_class_token(self, encoded_text, max_length): + if len(encoded_text) < max_length - 1: + encoded_text += [self.pad_token] * (max_length - 1 - len(encoded_text)) + return encoded_text + [self.class_token] + +class CLIPS_Tokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__( + self, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'whitespace', + strip_sep_token: bool = False, + language: Optional[str] = None, + **kwargs + ): + vocab_file = './vocab.txt' + self.tokenizer = CustomTokenizer(vocab_file, context_length=80, bos_token=1, eos_token=2, class_token=101, pad_token=0) + print("Load CLIPS Tokenizer.") + set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) + if callable(set_lang_fn): + self.set_lang_fn = set_lang_fn + if language is not None: + self.set_language(language) + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + self.strip_sep_token = strip_sep_token + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length in class init or call.' 
+
+        texts = [self.clean_fn(text) for text in texts]
+        encoded_outputs = self.tokenizer.batch_encode_plus(
+            texts,
+            max_length=context_length
+        )
+
+        input_ids = encoded_outputs['input_ids']
+
+        return input_ids
+
+    def set_language(self, src_lang):
+        if hasattr(self, 'set_lang_fn'):
+            self.set_lang_fn(src_lang)
+        else:
+            warnings.warn('Cannot set language for the tokenizer.')
+
+
 class HFTokenizer:
     """HuggingFace tokenizer wrapper"""
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 860d77503..78a581fd7 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -455,6 +455,7 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            eps: float = 1e-5  # Add eps as a parameter
     ):
         super().__init__()
         assert pool_type in ('tok', 'avg', 'none')
@@ -487,7 +488,7 @@ def __init__(
 
         # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
         self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
-        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
+        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width, eps=eps)
         self.transformer = Transformer(
             width,
             layers,
@@ -495,7 +496,7 @@ def __init__(
             mlp_ratio,
             ls_init_value=ls_init_value,
             act_layer=act_layer,
-            norm_layer=norm_layer,
+            norm_layer=lambda x: norm_layer(x, eps=eps),
         )
 
         if attentional_pool:
@@ -533,7 +534,7 @@ def __init__(
             pool_dim = width
 
         self.pool_type = pool_type
-        self.ln_post = norm_layer(pool_dim)
+        self.ln_post = norm_layer(pool_dim, eps=eps)
         self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
 
         self.init_parameters()
@@ -693,6 +694,7 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_tokens: bool = False,
+            eps: float = 1e-5  # Add eps as a parameter
     ):
         super().__init__()
         assert pool_type in ('first', 'last', 'argmax', 'none')
@@ -719,9 +721,9 @@ def __init__(
             mlp_ratio=mlp_ratio,
             ls_init_value=ls_init_value,
             act_layer=act_layer,
-            norm_layer=norm_layer,
+            norm_layer=lambda x: norm_layer(x, eps=eps),
         )
-        self.ln_final = norm_layer(width)
+        self.ln_final = norm_layer(width, eps=eps)
 
         if no_causal_mask:
             self.attn_mask = None
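
For reviewers: a minimal usage sketch (not part of the patch) showing how the new pieces fit together. It assumes the WordPiece vocab file that CLIPS_Tokenizer hardcodes at ./vocab.txt is present in the working directory, and it builds the model with random weights rather than loading a pretrained CLIPS checkpoint.

    import torch
    import open_clip

    # 'ViT-L-14-CLIPS-224' is one of the configs added above; 'CLIPS' in the name
    # routes get_tokenizer() to CLIPS_Tokenizer and makes create_model() switch
    # both towers' LayerNorm eps to 1e-6.
    model = open_clip.create_model('ViT-L-14-CLIPS-224', pretrained=None)
    tokenizer = open_clip.get_tokenizer('ViT-L-14-CLIPS-224')  # reads ./vocab.txt

    texts = tokenizer(["a photo of a cat", "a photo of a dog"])  # LongTensor of shape [2, 80]
    image = torch.randn(1, 3, 224, 224)  # dummy input; real images go through the preprocess transforms

    with torch.no_grad():
        text_features = model.encode_text(texts)
        image_features = model.encode_image(image)
    print(text_features.shape, image_features.shape)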