More tokenizer rework: add context_length as a class attribute set in the factory, and default the __call__() context_length argument to None. Clean up the reduction masking logic and fix #680.
rwightman committed Oct 18, 2023
1 parent b086ddb commit 05e9864
Showing 6 changed files with 201 additions and 200 deletions.
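Before the per-file diffs, a minimal illustrative sketch of the pattern the commit message describes (ToyTokenizer is a made-up name, not the open_clip class): the factory resolves a context length once and stores it on the tokenizer instance, and __call__() treats a context_length argument of None as "use the stored value".

# Illustrative sketch only, not the open_clip implementation: a tokenizer that
# stores context_length as an instance attribute (set once by a factory) and
# lets __call__() fall back to it when its context_length argument is None.
from typing import List, Optional


class ToyTokenizer:
    def __init__(self, context_length: int = 77):
        self.context_length = context_length  # instance attribute set by the factory

    def __call__(self, texts: List[str], context_length: Optional[int] = None) -> List[List[int]]:
        # None means "use the length configured at construction time"
        context_length = context_length or self.context_length
        # toy tokenization: one "token" (character ordinal) per character, padded/truncated
        return [
            ([ord(c) for c in t] + [0] * context_length)[:context_length]
            for t in texts
        ]


tok = ToyTokenizer(context_length=8)
assert len(tok(["a photo of a cat"])[0]) == 8     # stored default used
assert len(tok(["a photo of a cat"], 4)[0]) == 4  # explicit override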
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -5,7 +5,7 @@
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
-    get_model_tokenize_cfg, get_model_context_len, get_model_preprocess_cfg, set_model_preprocess_cfg
+    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
24 changes: 16 additions & 8 deletions src/open_clip/factory.py
@@ -1,27 +1,24 @@
import copy
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from functools import partial

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, get_model_preprocess_cfg, set_model_preprocess_cfg
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH


HF_HUB_PREFIX = 'hf-hub:'
@@ -86,15 +83,19 @@ def _get_hf_config(model_id, cache_dir=None):

def get_tokenizer(
        model_name: str = '',
-        text_mask: str = '',
+        context_length: Optional[int] = None,
        **kwargs,
):
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name)['model_cfg']
        except Exception:
-            tokenizer = HFTokenizer(model_name)
+            tokenizer = HFTokenizer(
+                model_name,
+                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
+                **kwargs,
+            )
            return tokenizer
    else:
        config = get_model_config(model_name)
@@ -106,13 +107,20 @@ def get_tokenizer(
    else:
        tokenizer_kwargs = kwargs

+    if context_length is None:
+        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
+
    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
+            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
-        tokenizer = SimpleTokenizer.create(text_mask=text_mask, **tokenizer_kwargs)
+        tokenizer = SimpleTokenizer(
+            context_length=context_length,
+            **tokenizer_kwargs,
+        )

    return tokenizer

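Against the reworked factory above, a hedged usage sketch (assumes open_clip is importable; 'ViT-B-32' is used purely as an example model name): context_length can now be passed to get_tokenizer() explicitly, otherwise it is taken from the model's text config with DEFAULT_CONTEXT_LENGTH as the fallback.

# Usage sketch, not part of the diff: both calls go through the factory above.
import open_clip

tokenizer = open_clip.get_tokenizer('ViT-B-32')                            # context_length from model cfg
short_tokenizer = open_clip.get_tokenizer('ViT-B-32', context_length=32)  # explicit override

tokens = tokenizer(["a photo of a cat"])
short = short_tokenizer(["a photo of a cat"])
print(tokens.shape)  # (1, tokenizer.context_length), typically (1, 77)
print(short.shape)   # (1, 32)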
13 changes: 4 additions & 9 deletions src/open_clip/model.py
@@ -594,18 +594,13 @@ def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict


-def get_model_context_len(model):
-    module = getattr(model, 'text', model)
-    return getattr(module, 'context_length', None)
-
-
def get_model_tokenize_cfg(model):
    module = getattr(model, 'text', model)
    cfg = {}
-    context_len = getattr(module, 'context_len', None)
-    if context_len is not None:
-        cfg['context_len'] = context_len
+    context_length = getattr(module, 'context_length', None)
+    if context_length is not None:
+        cfg['context_length'] = context_length
    vocab_size = getattr(module, 'vocab_size', None)
    if vocab_size is not None:
        cfg['vocab_size'] = vocab_size
-    return cfg
+    return cfg
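As a hedged illustration of how the renamed get_model_tokenize_cfg() keys might be consumed (the model name is again just an example, and context_length may be absent for some architectures): the cfg now reports context_length and vocab_size, which can be checked against the tokenizer built for the same model.

# Illustrative check, not part of the diff: compare a model's tokenize cfg
# against the tokenizer produced by get_tokenizer() for the same model name.
import open_clip
from open_clip import get_model_tokenize_cfg

model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')  # example model name
tokenizer = open_clip.get_tokenizer('ViT-B-32')

cfg = get_model_tokenize_cfg(model)
expected = cfg.get('context_length')
if expected is not None:
    assert tokenizer.context_length == expected, (
        f'tokenizer context_length {tokenizer.context_length} != model {expected}')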