SentencePiece tokenizer support for TokenizerInfo #120

Merged: 9 commits, Dec 22, 2024
81 changes: 76 additions & 5 deletions python/xgrammar/tokenizer_info.py
@@ -4,6 +4,8 @@
from typing import List, Optional, Union

from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast
import tiktoken
import sentencepiece

from .base import XGRObject, _core
from .support import logging
@@ -39,6 +41,35 @@ class VocabType(Enum):
    BYTE_FALLBACK = "BYTE_FALLBACK"
    BYTE_LEVEL = "BYTE_LEVEL"

def is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
    # helper to check if tokenizer is a tiktoken tokenizer
    has_tiktoken_encoding = (
        hasattr(tokenizer, 'tokenizer') and
        isinstance(tokenizer.tokenizer, tiktoken.Encoding)
    )

    filename_pattern = (
        "vocab_file" in tokenizer.vocab_files_names and
        "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
    )

    return has_tiktoken_encoding or filename_pattern

def is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
    # helper to check if tokenizer is a SentencePiece tokenizer
    has_sp_model_attr = (
        hasattr(tokenizer, 'sp_model') and
        isinstance(tokenizer.sp_model, sentencepiece.SentencePieceProcessor)
    )

    has_nested_sp_model_attr = (
        hasattr(tokenizer, 'tokenizer') and
        hasattr(tokenizer.tokenizer, 'sp_model') and
        isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor)
    )

    return has_sp_model_attr or has_nested_sp_model_attr
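
A minimal sketch of how these two helpers classify a Hugging Face tokenizer. It assumes the helpers stay module-level in `xgrammar.tokenizer_info` (they are not necessarily part of the public API) and that the model can be downloaded with `trust_remote_code=True`; the model name is taken from the test list in this PR.

```python
# Sketch only: assumes network access to the Hugging Face Hub and that the
# helpers above are importable from xgrammar.tokenizer_info.
from transformers import AutoTokenizer

from xgrammar.tokenizer_info import is_sentencepiece_tokenizer, is_tiktoken_tokenizer

# ChatGLM3 bundles a SentencePiece model, so the sp_model attribute is reachable.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
print(is_sentencepiece_tokenizer(tokenizer))  # expected: True
print(is_tiktoken_tokenizer(tokenizer))       # expected: False
```

The nested check appears to cover wrappers such as ChatGLM3, where the SentencePiece processor sits one attribute deeper (`tokenizer.tokenizer.sp_model`) rather than directly on the tokenizer object.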


class TokenizerInfo(XGRObject):
"""The tokenizer info contains the vocabulary, the type of the vocabulary, and necessary
@@ -174,10 +205,7 @@ def from_huggingface(
                    encoded_vocab, backend_str, vocab_size, stop_token_ids
                )
            )
-        elif (
-            "vocab_file" in tokenizer.vocab_files_names
-            and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
-        ):
+        elif is_tiktoken_tokenizer(tokenizer):
            # tiktoken tokenizer
            # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
            if stop_token_ids is None:
@@ -196,8 +224,51 @@
                stop_token_ids=stop_token_ids,
                prepend_space_in_tokenization=False,
            )
        elif is_sentencepiece_tokenizer(tokenizer):
            # sentencepiece tokenizer
            # e.g. Chatglm3-6b
            if hasattr(tokenizer, 'sp_model'):
                sp_model = tokenizer.sp_model
            elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'):
                sp_model = tokenizer.tokenizer.sp_model

            if stop_token_ids is None:
                if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
                    stop_token_ids = [tokenizer.eos_token_id]
                else:
                    eos_id = sp_model.eos_id()
                    if eos_id != -1:
                        stop_token_ids = [eos_id]
                    else:
                        logger.warning(
                            "When constructing TokenizerInfo from a huggingface tokenizer, "
                            "stop_token_ids is neither provided by user nor found from the tokenizer. "
                            "It will be automatically detected."
                        )

            vocab_dict = tokenizer.get_vocab()
            vocab_size = len(vocab_dict) if vocab_size is None else vocab_size

            # fill in any special tokens from vocab_dict
            for token, idx in vocab_dict.items():
                if idx < vocab_size:
                    encoded_vocab[idx] = token

            # detect vocab_type of tokenizer
            if "<0x0A>" in vocab_dict:
                vocab_type = VocabType.BYTE_FALLBACK
            else:
                vocab_type = VocabType.RAW

            return TokenizerInfo(
                encoded_vocab,
                vocab_type=vocab_type,
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                prepend_space_in_tokenization=True,
            )
        else:
-            # TODO(yixin): sentencepiece tokenizer
+            # TODO(yixin): unsupported tokenizer
            raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")

@property
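
For orientation, a minimal sketch of how the new branch would be reached from user code. The model name comes from the test list below; the `vocab_type` and `prepend_space_in_tokenization` properties are assumptions inferred from the constructor arguments in this diff and may differ in the actual API.

```python
# Sketch only: assumes network access and that TokenizerInfo exposes the
# vocab_type / prepend_space_in_tokenization properties suggested by this diff.
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
info = xgr.TokenizerInfo.from_huggingface(tokenizer)

print(info.vocab_type)                     # expected: VocabType.BYTE_FALLBACK, per the new test entry
print(info.prepend_space_in_tokenization)  # expected: True
```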
2 changes: 2 additions & 0 deletions tests/python/test_tokenizer_info.py
@@ -41,6 +41,8 @@ def tokenizer_info_storage() -> Dict[str, Tuple[PreTrainedTokenizerBase, xgr.Tok
("Qwen/Qwen2.5-1.5B", xgr.VocabType.BYTE_LEVEL, False),
("internlm/internlm2_5-7b-chat", xgr.VocabType.BYTE_FALLBACK, False),
("mistralai/Mixtral-8x22B-Instruct-v0.1", xgr.VocabType.BYTE_FALLBACK, True),
("THUDM/glm-4-9b-chat", xgr.VocabType.RAW, False),
("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True),
]

tokenizer_paths = [path for path, *_ in tokenizer_paths_metadata]
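
As a rough illustration of what the two new entries assert, here is a hypothetical standalone check; it is not the actual fixture code in this file, and it assumes `TokenizerInfo.from_huggingface` plus the property names used above.

```python
# Hypothetical check mirroring the metadata tuples; check_metadata_entry is
# illustrative only and not part of the real test suite.
import xgrammar as xgr
from transformers import AutoTokenizer


def check_metadata_entry(path: str, vocab_type: xgr.VocabType, prepend_space: bool) -> None:
    # Build TokenizerInfo the same way the sketch above does and compare
    # against the expected metadata for this model.
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    assert info.vocab_type == vocab_type
    assert info.prepend_space_in_tokenization == prepend_space


check_metadata_entry("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True)
check_metadata_entry("THUDM/glm-4-9b-chat", xgr.VocabType.RAW, False)
```

Per the new tuples, chatglm3-6b is expected to yield a byte-fallback vocabulary with a prepended space, while glm-4-9b-chat is expected to resolve to a RAW vocabulary without one.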