diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py
index 315f6caf2..23856a3f1 100644
--- a/src/open_clip/__init__.py
+++ b/src/open_clip/__init__.py
@@ -5,7 +5,7 @@ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
-    get_model_tokenize_cfg, get_model_context_len, get_model_preprocess_cfg, set_model_preprocess_cfg
+    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
 from .openai import load_openai_model, list_openai_models
 from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 969e8de34..ef94b51f8 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -1,27 +1,24 @@
-import copy
 import json
 import logging
 import os
-import pathlib
 import re
 from copy import deepcopy
 from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
-from functools import partial

 import torch

 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
-    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, get_model_preprocess_cfg, set_model_preprocess_cfg
+    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
 from .coca_model import CoCa
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
 from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
-from .tokenizer import HFTokenizer, SimpleTokenizer
+from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH


 HF_HUB_PREFIX = 'hf-hub:'
@@ -86,7 +83,7 @@ def _get_hf_config(model_id, cache_dir=None):

 def get_tokenizer(
         model_name: str = '',
-        text_mask: str = '',
+        context_length: Optional[int] = None,
         **kwargs,
 ):
     if model_name.startswith(HF_HUB_PREFIX):
@@ -94,7 +91,11 @@ def get_tokenizer(
         try:
             config = _get_hf_config(model_name)['model_cfg']
         except Exception:
-            tokenizer = HFTokenizer(model_name)
+            tokenizer = HFTokenizer(
+                model_name,
+                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
+                **kwargs,
+            )
             return tokenizer
     else:
         config = get_model_config(model_name)
@@ -106,13 +107,20 @@
     else:
         tokenizer_kwargs = kwargs

+    if context_length is None:
+        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
+
     if 'hf_tokenizer_name' in text_config:
         tokenizer = HFTokenizer(
             text_config['hf_tokenizer_name'],
+            context_length=context_length,
             **tokenizer_kwargs,
         )
     else:
-        tokenizer = SimpleTokenizer.create(text_mask=text_mask, **tokenizer_kwargs)
+        tokenizer = SimpleTokenizer(
+            context_length=context_length,
+            **tokenizer_kwargs,
+        )

     return tokenizer
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 920eb6d80..0310ee560 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -594,18 +594,13 @@ def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
     module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict


-def get_model_context_len(model):
-    module = getattr(model, 'text', model)
-    return getattr(module, 'context_length', None)
-
-
 def get_model_tokenize_cfg(model):
     module = getattr(model, 'text', model)
     cfg = {}
-    context_len = getattr(module, 'context_len', None)
-    if context_len is not None:
-        cfg['context_len'] = context_len
+    context_length = getattr(module, 'context_length', None)
+    if context_length is not None:
+        cfg['context_length'] = context_length
     vocab_size = getattr(module, 'vocab_size', None)
     if vocab_size is not None:
         cfg['vocab_size'] = vocab_size
-    return cfg
\ No newline at end of file
+    return cfg
diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py
index f72093522..9d626315e 100644
--- a/src/open_clip/tokenizer.py
+++ b/src/open_clip/tokenizer.py
@@ -5,9 +5,10 @@
 import gzip
 import html
 import os
+import random
 import string
-from functools import lru_cache
-from typing import Optional, List, Union
+from functools import lru_cache, partial
+from typing import Callable, Optional, List, Union

 import ftfy
 import numpy as np
@@ -16,6 +17,9 @@
 # https://stackoverflow.com/q/62691279
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+_nltk_init = False
+
+DEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP


 @lru_cache()
@@ -123,8 +127,10 @@ class SimpleTokenizer(object):
     def __init__(
             self,
             bpe_path: str = default_bpe(),
-            special_tokens=None,
+            additional_special_tokens: Optional[List[str]] = None,
+            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
             clean: str = 'lower',
+            reduction_mask: str = ''
     ):
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
@@ -135,31 +141,26 @@ def __init__(
         vocab = vocab + [v+'</w>' for v in vocab]
         for merge in merges:
             vocab.append(''.join(merge))
-        if not special_tokens:
-            special_tokens = ['<start_of_text>', '<end_of_text>']
-        else:
-            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
+        special_tokens = ['<start_of_text>', '<end_of_text>']
+        if additional_special_tokens:
+            special_tokens += additional_special_tokens
         vocab.extend(special_tokens)
         self.encoder = dict(zip(vocab, range(len(vocab))))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {t:t for t in special_tokens}
         special = "|".join(special_tokens)
-        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
-        self.clean_fn = get_clean_fn(clean)
+        self.pat = re.compile(
+            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE,
+        )
         self.vocab_size = len(self.encoder)
         self.all_special_ids = [self.encoder[t] for t in special_tokens]
-
-    @staticmethod
-    def create(text_mask='', **kwargs) -> 'SimpleTokenizer':
-        if text_mask == 'simple':
-            return SimpleMaskTokenizer(**kwargs)
-        elif text_mask == 'random':
-            return RandomMaskTokenizer(**kwargs)
-        elif text_mask == 'syntax':
-            return SyntaxMaskTokenizer(**kwargs)
-        else:
-            return SimpleTokenizer(**kwargs)
+        self.sot_token_id = self.all_special_ids[0]
+        self.eot_token_id = self.all_special_ids[1]
+        self.context_length = context_length
+        self.clean_fn = get_clean_fn(clean)
+        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None

     def bpe(self, token):
         if token in self.cache:
@@ -215,7 +216,7 @@ def decode(self, tokens):
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
         return text

-    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
         """ Returns the tokenized representation of given input string(s)

         Parameters
@@ -232,15 +233,26 @@ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> to
         if isinstance(texts, str):
             texts = [texts]

-        sot_token = self.encoder["<start_of_text>"]
-        eot_token = self.encoder["<end_of_text>"]
-        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+        context_length = context_length or self.context_length
+        assert context_length, 'Please set a valid context length'
+
+        if self.reduction_fn is not None:
+            # use reduction strategy for tokenize if set, otherwise default to truncation below
+            return self.reduction_fn(
+                texts,
+                context_length=context_length,
+                sot_token_id=self.sot_token_id,
+                eot_token_id=self.eot_token_id,
+                encode_fn=self.encode,
+            )
+
+        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

         for i, tokens in enumerate(all_tokens):
             if len(tokens) > context_length:
                 tokens = tokens[:context_length]  # Truncate
-                tokens[-1] = eot_token
+                tokens[-1] = self.eot_token_id
             result[i, :len(tokens)] = torch.tensor(tokens)

         return result
@@ -254,33 +266,160 @@ def decode(output_ids: torch.Tensor):
     return _tokenizer.decode(output_ids)


-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
     return _tokenizer(texts, context_length=context_length)


+def random_mask_tokenize(
+        texts: Union[str, List[str]],
+        context_length: int,
+        sot_token_id: int,
+        eot_token_id: int,
+        encode_fn: Callable,
+        shuffle: bool = False,
+):
+    all_tokens = [encode_fn(text) for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        tokens = torch.tensor(tokens)
+        num_tokens = len(tokens)
+        if num_tokens > context_length - 2:  # 2 for sot and eot token
+            num_keep = context_length - 2
+            indices = torch.randperm(len(tokens))
+            indices = indices[:num_keep]
+            if not shuffle:
+                indices = indices.msort()
+            tokens = tokens[indices]
+            num_tokens = num_keep
+        result[i, 0] = sot_token_id
+        result[i, 1:num_tokens + 1] = tokens
+        result[i, num_tokens + 1] = eot_token_id
+
+    return result
+
+
+def simple_mask_tokenize(
+        texts: Union[str, List[str]],
+        context_length: int,
+        sot_token_id: int,
+        eot_token_id: int,
+        encode_fn: Callable,
+):
+    all_tokens = [encode_fn(text) for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        num_tokens = len(tokens)
+        if num_tokens > context_length - 2:  # 2 for sot and eot token
+            num_keep = context_length - 2
+            start_index = random.randint(0, num_tokens - num_keep)  # high is incl
+            tokens = tokens[start_index: start_index + num_keep]
+        tokens = [sot_token_id] + tokens + [eot_token_id]
+        result[i, :len(tokens)] = torch.tensor(tokens)
+
+    return result
+
+
+def syntax_mask_tokenize(
+        texts: Union[str, List[str]],
+        context_length: int,
+        sot_token_id: int,
+        eot_token_id: int,
+        encode_fn: Callable,
+) -> torch.LongTensor:
+    """ Returns the tokenized representation of given input string(s).
+    Apply syntax masking before tokenize.
+    """
+    import nltk
+    global _nltk_init
+    if not _nltk_init:
+        # run them for the first time
+        nltk.download('punkt')
+        nltk.download('averaged_perceptron_tagger')
+        _nltk_init = True
+
+    def get_order(x):
+        if x.startswith('NN'):
+            return 1
+        elif x.startswith('JJ'):
+            return 2
+        elif x.startswith('VB'):
+            return 3
+        else:
+            return 4
+
+    # syntax masking
+    new_texts = []
+    for text in texts:
+        list_tokens = nltk.tokenize.word_tokenize(text)
+        pos_tags = nltk.pos_tag(list_tokens)
+        # sample the words by get_order method
+        order_list = [get_order(tag) for _, tag in pos_tags]
+        sorted_ids = np.argsort(np.array(order_list))
+        sampled_ids = sorted(sorted_ids[:context_length - 2])  # need 2 slots for sot and eot tokens
+        sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)  # sample the tokens
+
+        new_text = ''
+        for token in sampled_tokens:
+            new_text = new_text + str(token) + ' '
+        new_text = new_text.strip()
+        new_texts.append(new_text)
+    texts = new_texts
+
+    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        # still need first truncate because some words produces two tokens
+        if len(tokens) > context_length:
+            tokens = tokens[:context_length]  # Truncate
+            tokens[-1] = eot_token_id
+        result[i, :len(tokens)] = torch.tensor(tokens)
+
+    return result
+
+
+def get_reduction_mask_fn(type: str):
+    """ Choose strategy for dropping (masking) tokens to achieve target context length"""
+    assert type in ('simple', 'random', 'shuffle', 'syntax')
+    if type == 'simple':
+        return simple_mask_tokenize  # randomly select block [start:end]
+    elif type == 'random':
+        return random_mask_tokenize  # randomly drop tokens (keep order)
+    elif type == 'shuffle':
+        return partial(random_mask_tokenize, shuffle=True)  # randomly drop tokens (shuffle order)
+    elif type == 'syntax':
+        return syntax_mask_tokenize  # randomly drop prioritized by syntax
+
+
 class HFTokenizer:
     """HuggingFace tokenizer wrapper"""

     def __init__(
             self,
             tokenizer_name: str,
+            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
             clean: str = 'whitespace',
-            strip_sep_token=False,
+            strip_sep_token: bool = False,
     ):
         from transformers import AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.context_length = context_length
         self.clean_fn = get_clean_fn(clean)
         self.strip_sep_token = strip_sep_token

     def save_pretrained(self, dest):
         self.tokenizer.save_pretrained(dest)

-    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
+    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
         # same cleaning as for default tokenizer, except lowercasing
         # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
         if isinstance(texts, str):
             texts = [texts]

+        context_length = context_length or self.context_length
+        assert context_length, 'Please set a valid context length in class init or call.'
+
         texts = [self.clean_fn(text) for text in texts]
         input_ids = self.tokenizer(
             texts,
@@ -300,148 +439,6 @@ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> to
         return input_ids


-class RandomMaskTokenizer(SimpleTokenizer):
-
-    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
-        """
-        Returns the tokenized representation of given input string(s)
-
-        Parameters
-        ----------
-        texts : Union[str, List[str]]
-            An input string or a list of input strings to tokenize
-        context_length : int
-            The context length to use; all CLIP models use 77 as the context length
-
-        Returns
-        -------
-        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-
-        sot_token = self.encoder["<start_of_text>"]
-        eot_token = self.encoder["<end_of_text>"]
-        all_tokens = [self.encode(text) for text in texts]
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-
-        for i, tokens in enumerate(all_tokens):
-            if len(tokens) > context_length - 2:  # 2 for sot and eot token
-                indices = np.random.permutation(len(tokens)).tolist()
-                indices = indices[:context_length - 2]
-                tokens = tokens[indices]
-            tokens = [sot_token] + tokens + [eot_token]
-            result[i, :len(tokens)] = torch.tensor(tokens)
-
-        return result
-
-
-class SimpleMaskTokenizer(SimpleTokenizer):
-    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
-        """
-        Returns the tokenized representation of given input string(s)
-
-        Parameters
-        ----------
-        texts : Union[str, List[str]]
-            An input string or a list of input strings to tokenize
-        context_length : int
-            The context length to use; all CLIP models use 77 as the context length
-
-        Returns
-        -------
-        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-
-        sot_token = self.encoder["<start_of_text>"]
-        eot_token = self.encoder["<end_of_text>"]
-        all_tokens = [self.encode(text) for text in texts]
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-
-        for i, tokens in enumerate(all_tokens):
-            if len(tokens) > context_length - 2:  # 2 for sot and eot token
-                start_index = np.random.randint(len(tokens) - context_length + 3)
-                tokens = tokens[start_index : start_index + context_length - 2]
-            tokens = [sot_token] + tokens + [eot_token]
-            result[i, :len(tokens)] = torch.tensor(tokens)
-
-        return result
-
-
-class SyntaxMaskTokenizer(SimpleTokenizer):
-
-    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
-        """
-        Returns the tokenized representation of given input string(s).
-        Apply syntax masking before tokenize.
-
-        Parameters
-        ----------
-        texts : Union[str, List[str]]
-            An input string or a list of input strings to tokenize
-        context_length : int
-            The context length to use; all CLIP models use 77 as the context length
-
-        Returns
-        -------
-        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
-        """
-        import nltk
-        if not hasattr(self, '_nltk_init'):
-            # run them for the first time
-            nltk.download('punkt')
-            nltk.download('averaged_perceptron_tagger')
-            self._nltk_init = True
-
-        if isinstance(texts, str):
-            texts = [texts]
-
-        def get_order(x):
-            if x.startswith('NN'):
-                return 1
-            elif x.startswith('JJ'):
-                return 2
-            elif x.startswith('VB'):
-                return 3
-            else:
-                return 4
-
-        # syntax masking
-        new_texts = []
-        for text in texts:
-            list_tokens = nltk.tokenize.word_tokenize(text)
-            pos_tags = nltk.pos_tag(list_tokens)
-            # sample the words by get_order method
-            order_list = [get_order(tag) for _, tag in pos_tags]
-            sorted_ids = np.argsort(np.array(order_list))
-            sampled_ids = sorted(sorted_ids[:context_length - 2])  # need 2 slots for sot and eot tokens
-            # sample the tokens and convert to tf.tensor
-            sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)
-
-            new_text = ''
-            for token in sampled_tokens:
-                new_text = new_text + str(token) + ' '
-            new_text = new_text.strip()
-            new_texts.append(new_text)
-        texts = new_texts
-
-        sot_token = self.encoder["<start_of_text>"]
-        eot_token = self.encoder["<end_of_text>"]
-        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-
-        for i, tokens in enumerate(all_tokens):
-            # still need first truncate because some words produces two tokens
-            if len(tokens) > context_length:
-                tokens = tokens[:context_length]  # Truncate
-                tokens[-1] = eot_token
-            result[i, :len(tokens)] = torch.tensor(tokens)
-
-        return result
-
-
 class SigLipTokenizer:
     """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
     """
@@ -452,7 +449,11 @@ class SigLipTokenizer:
         "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
     }

-    def __init__(self, tokenizer_name: str):
+    def __init__(
+            self,
+            tokenizer_name: str,
+            context_length: Optional[int] = 64,
+    ):
         from transformers import T5TokenizerFast

         if tokenizer_name in self.VOCAB_FILES:
@@ -469,15 +470,20 @@ def __init__(self, tokenizer_name: str):

         self.tokenizer.pad_token_id = 1
         self.tokenizer.eos_token_id = 1
+        self.context_length = context_length

     def save_pretrained(self, dest):
         self.tokenizer.save_pretrained(dest)

-    def __call__(self, texts: Union[str, List[str]], context_length: int = 64) -> torch.Tensor:
+    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
         # same cleaning as for default tokenizer, except lowercasing
         # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
         if isinstance(texts, str):
             texts = [texts]
+
+        context_length = context_length or self.context_length
+        assert context_length, 'Please set a valid context length in class init or call.'
+
         texts = [canonicalize_text(basic_clean(text)) for text in texts]
         output = self.tokenizer(
             texts,
diff --git a/src/training/main.py b/src/training/main.py
index 853f8586b..94496999f 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -28,7 +28,7 @@
 except ImportError:
     hvd = None

-from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss, get_model_context_len
+from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
 from training.data import get_data
 from training.distributed import is_master, init_distributed_device, broadcast_object
 from training.logger import setup_logging
@@ -354,9 +354,6 @@ def main(args):

     # initialize datasets
     tokenizer = get_tokenizer(args.model)
-    context_len = get_model_context_len(model)
-    if context_len is not None:
-        tokenizer = partial(tokenizer, context_length=context_len)
     data = get_data(
         args,
         (preprocess_train, preprocess_val),
diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py
index 5bcc8a7df..06ce7ac09 100644
--- a/src/training/zero_shot.py
+++ b/src/training/zero_shot.py
@@ -1,11 +1,9 @@
 import logging
-from functools import partial

 import torch
-import torch.nn.functional as F
 from tqdm import tqdm

-from open_clip import get_input_dtype, get_tokenizer, get_model_context_len, build_zero_shot_classifier, \
+from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
     IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
 from .precision import get_autocast

@@ -56,9 +54,6 @@ def zero_shot_eval(model, data, epoch, args, tokenizer=None):
     logging.info('Starting zero-shot imagenet.')
     if tokenizer is None:
         tokenizer = get_tokenizer(args.model)
-        context_len = get_model_context_len(model)
-        if context_len is not None:
-            tokenizer = partial(tokenizer, context_length=context_len)

     logging.info('Building zero-shot classifier')
     autocast = get_autocast(args.precision)
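
Reviewer note — usage sketch, not part of the patch. With this change, get_tokenizer() resolves context_length from the model's text_cfg (falling back to DEFAULT_CONTEXT_LENGTH), so training code no longer wraps the tokenizer in functools.partial, and the removed RandomMask/SimpleMask/SyntaxMaskTokenizer subclasses are replaced by SimpleTokenizer's reduction_mask argument. The model tag below is only an example; any bundled config should behave the same.

    # Illustrative sketch against the API changed above; 'ViT-B-32' is an example model tag.
    from open_clip import SimpleTokenizer, get_tokenizer

    tokenizer = get_tokenizer('ViT-B-32')  # context_length now read from the model config (77 here)
    tokens = tokenizer(['a photo of a cat', 'a photo of a dog'])
    assert tokens.shape == (2, tokenizer.context_length)

    # reduction_mask in {'simple', 'random', 'shuffle', 'syntax'} replaces the old mask tokenizer
    # classes; 'syntax' additionally requires nltk to be installed.
    masked_tokenizer = SimpleTokenizer(context_length=16, reduction_mask='random')
    masked = masked_tokenizer(['a very long caption ' * 20])
    assert masked.shape == (1, 16)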