Add split_special_tokens=True
NielsRogge committed Jan 2, 2024
1 parent f678d4c commit 26590d2
Showing 3 changed files with 32 additions and 11 deletions.
src/transformers/models/siglip/tokenization_siglip.py (4 additions, 0 deletions)
@@ -89,6 +89,8 @@ class SiglipTokenizer(PreTrainedTokenizer):
             The maximum length (in number of tokens) for model inputs.
         do_lower_case (`bool`, *optional*, defaults to `True`):
             Whether or not to lowercase the input when tokenizing.
+        split_special_tokens (`bool`, *optional*, defaults to `True`):
+            Whether or not to split the special tokens when tokenizing.
     """

     vocab_files_names = VOCAB_FILES_NAMES
@@ -106,6 +108,7 @@ def __init__(
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
         model_max_length=64,
         do_lower_case=True,
+        split_special_tokens=True,
         **kwargs,
     ) -> None:
         requires_backends(self, "protobuf")
@@ -142,6 +145,7 @@ def __init__(
             sp_model_kwargs=self.sp_model_kwargs,
             model_max_length=model_max_length,
             do_lower_case=do_lower_case,
+            split_special_tokens=split_special_tokens,
             **kwargs,
         )

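As a reading aid (not part of the commit), here is a minimal sketch of the behaviour the new default is meant to produce. It assumes the `nielsr/siglip-base-patch16-224` checkpoint used in the tests below; the exact pieces returned depend on the checkpoint and on the transformers version.

```python
from transformers import SiglipTokenizer

# Checkpoint name taken from the tests in this commit; outputs are illustrative.
tokenizer = SiglipTokenizer.from_pretrained("nielsr/siglip-base-patch16-224")

# With split_special_tokens=True (the new SiglipTokenizer default), a literal
# "</s>" inside the text is tokenized like ordinary text rather than being
# matched as the single EOS special token.
print(tokenizer.tokenize("a photo of a cat </s>"))

# Passing split_special_tokens=False restores the previous behaviour, where
# special tokens found in the text are kept intact.
tokenizer_keep = SiglipTokenizer.from_pretrained(
    "nielsr/siglip-base-patch16-224", split_special_tokens=False
)
print(tokenizer_keep.tokenize("a photo of a cat </s>"))
```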
tests/models/siglip/test_tokenization_siglip.py (4 additions, 10 deletions)
@@ -161,12 +161,6 @@ def test_rust_and_python_full_tokenizers(self):
         rust_ids = rust_tokenizer.encode(sequence)
         self.assertListEqual(ids, rust_ids)

-    def test_eos_treatment(self):
-        tokenizer = self.siglip_tokenizer
-        batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
-        batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
-        self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
-
     def test_prepare_batch(self):
         tokenizer = self.siglip_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
@@ -204,8 +198,8 @@ def test_eos_in_input(self):
         tokenizer = self.siglip_tokenizer
         src_text = ["A long paragraph for summarization. </s>"]
         tgt_text = ["Summary of the text. </s>"]
-        expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, 1]
-        expected_tgt_tokens = [6254, 267, 260, 1443, 1]
+        expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, 262, 264, 1]
+        expected_tgt_tokens = [6254, 267, 260, 1443, 262, 264, 1]

         batch = tokenizer(src_text, text_target=tgt_text)

@@ -366,13 +360,13 @@ def test_tokenizer_integration(self):
         )

     def test_some_edge_cases(self):
-        tokenizer = SiglipTokenizer.from_pretrained("nielsr/siglip-base-patch16-224", legacy=False)
+        tokenizer = SiglipTokenizer.from_pretrained("nielsr/siglip-base-patch16-224")

         sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
         self.assertEqual(sp_tokens, ["</", "s", ">", ">"])
         tokens = tokenizer.tokenize("</s>>")
         self.assertNotEqual(sp_tokens, tokens)
-        self.assertEqual(tokens, ["</s>"])
+        self.assertEqual(tokens, ["▁", "s"])

         tokens = tokenizer.tokenize("")
         self.assertEqual(tokens, [])
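A short sketch (not part of the commit) of what the updated `test_eos_in_input` expectations reflect: with special tokens split, the literal `</s>` typed in the source text is encoded as ordinary pieces, and only the EOS token appended by the tokenizer itself keeps id 1 at the end. The ids are the ones from the test above and are specific to that checkpoint.

```python
from transformers import SiglipTokenizer

# Illustrative only; ids are specific to the checkpoint used by the test suite.
tokenizer = SiglipTokenizer.from_pretrained("nielsr/siglip-base-patch16-224")

ids = tokenizer("A long paragraph for summarization. </s>")["input_ids"]
# Expected shape per the updated test: the "</s>" written in the text becomes
# regular pieces (262, 264 above) and the appended EOS is the final 1.
print(ids)
```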
tests/test_tokenization_common.py (24 additions, 1 deletion)
@@ -391,8 +391,13 @@ def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences)
     # TODO: this test can be combined with `test_sentencepiece_tokenize_and_convert_tokens_to_string` after the latter is extended to all tokenizers.
     def test_tokenize_special_tokens(self):
         """Test `tokenize` with special tokens."""
+
         tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
         for tokenizer in tokenizers:
+            if tokenizer.split_special_tokens:
+                self.skipTest("Skipping since tokenizer splits special tokens.")
+                return
+
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 SPECIAL_TOKEN_1 = "[SPECIAL_TOKEN_1]"
                 SPECIAL_TOKEN_2 = "[SPECIAL_TOKEN_2]"
@@ -823,6 +828,10 @@ def test_added_tokens_do_lower_case(self):
                 if not hasattr(tokenizer, "do_lower_case") or not tokenizer.do_lower_case:
                     continue

+                if tokenizer.split_special_tokens:
+                    self.skipTest("Skipping since tokenizer splits special tokens.")
+                    return
+
                 special_token = tokenizer.all_special_tokens[0]

                 text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
@@ -887,6 +896,9 @@ def test_added_tokens_do_lower_case(self):
     def test_add_tokens_tokenizer(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
+            if tokenizer.split_special_tokens:
+                self.skipTest("Skipping since tokenizer splits special tokens.")
+                return
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 vocab_size = tokenizer.vocab_size
                 all_size = len(tokenizer)
@@ -953,7 +965,9 @@ def test_add_special_tokens(self):
                 tokenizer.add_special_tokens({"cls_token": special_token})
                 special_token = str(special_token)
                 encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
-                self.assertEqual(len(encoded_special_token), 1)
+
+                if not tokenizer.split_special_tokens:
+                    self.assertEqual(len(encoded_special_token), 1)

                 text = tokenizer.decode(ids + encoded_special_token, clean_up_tokenization_spaces=False)
                 encoded = tokenizer.encode(text, add_special_tokens=False)
@@ -987,6 +1001,9 @@ def test_internal_consistency(self):
     def test_encode_decode_with_spaces(self):
         tokenizers = self.get_tokenizers(do_lower_case=False, fast=False)
         for tokenizer in tokenizers:
+            if tokenizer.split_special_tokens:
+                self.skipTest("Skipping since tokenizer splits special tokens.")
+                return
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 new_toks = [
                     # These are added tokens, they will be normalized....
@@ -2167,7 +2184,13 @@ def test_added_token_are_matched_longest_first(self):
         if not self.test_slow_tokenizer:
             self.skipTest("This test is only for slow tokenizers")
             return
+
         tokenizers = self.get_tokenizers(fast=False)
+
+        if tokenizers[0].split_special_tokens:
+            self.skipTest("This test is only relevant in case the tokenizer doesn't split special tokens")
+            return
+
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 try:
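The common-test changes all follow the same guard: tests whose assertions only hold when special tokens survive tokenization intact are skipped for tokenizers that split them. A hypothetical sketch of the pattern follows; the `DummyTokenizerTest` class and its `get_tokenizer` helper are placeholders, not part of the commit.

```python
import unittest


class DummyTokenizerTest(unittest.TestCase):
    # Placeholder helper standing in for however the real suite builds tokenizers.
    def get_tokenizer(self):
        raise NotImplementedError

    def test_assumes_intact_special_tokens(self):
        tokenizer = self.get_tokenizer()
        # Same guard as the hunks above: a tokenizer that splits special tokens
        # (e.g. SiglipTokenizer with its new default) cannot satisfy assertions
        # that expect a special token to encode to exactly one id.
        if getattr(tokenizer, "split_special_tokens", False):
            self.skipTest("Skipping since tokenizer splits special tokens.")
        encoded = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)
        self.assertEqual(len(encoded), 1)
```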
