From 780ddfa6a91b2e71434012b204fb0a43931d941d Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Tue, 2 Apr 2024 21:52:30 -0700
Subject: [PATCH] Embeddings fix for truncation without room for begin/end and for batch truncation

* Attempting truncate_input_tokens=2 (or 1) was producing a confusing error
  (or misbehaving) because it takes at least 3 tokens ([CLS] TOK [SEP]) to
  get meaningful results.

* The truncate value now generally means the number of tokens, not counting
  the begin/end tokens.

* At the max end, the 2 special tokens are allowed to consume 2 tokens from
  the limit.

* Batch embedding processing was returning odd/misordered results when
  combined with truncation. Added a second tokenizer() call to avoid sending
  the overflow tokens as features to be processed.

Signed-off-by: Mark Sturdevant
---
 .../modules/text_embedding/embedding.py      | 22 ++++-
 .../modules/text_embedding/test_embedding.py | 88 +++++++++++++++----
 2 files changed, 90 insertions(+), 20 deletions(-)

diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 1c7c45aa..d4447ebe 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -818,7 +818,9 @@ def _truncate_input_tokens(
             max_length = max_tokens
         elif 0 < truncate_input_tokens <= max_tokens:
             okay_to_truncate = True
-            max_length = truncate_input_tokens
+            # Add 2 for begin/end tokens, but don't go higher than model's max_tokens
+            max_length = min(truncate_input_tokens + 2, max_tokens)
+
         else:
             okay_to_truncate = not implicit_truncation_errors
             max_length = max_tokens
@@ -859,7 +861,9 @@ def _truncate_input_tokens(
                 padding=True,
             )
 
-            tokens = sum_token_count(tokenized, truncate_only=False)
+            tokens = 0
+            for encoding in tokenized.encodings:
+                tokens = max(sum(encoding.attention_mask), tokens)
             error.log_raise(
                 "",
                 ValueError(
@@ -870,6 +874,20 @@ def _truncate_input_tokens(
 
         input_token_count = sum_token_count(tokenized, truncate_only=True)
 
+        # Tokenize without overflow for batching and truncation to work together.
+        tokenized = self.tokenizer(
+            *to_tokenize,
+            return_attention_mask=True,
+            return_token_type_ids=False,
+            return_overflowing_tokens=False,
+            return_offsets_mapping=False,
+            return_length=False,
+            return_tensors="pt",
+            truncation=True,
+            padding=True,
+            max_length=max_length,
+        )
+
         return TruncatedTokensTuple(tokenized, input_token_count)
 
     def encode(
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index 99c6bebe..bb12b989 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -681,12 +681,14 @@ def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model):
     )
 
 
-@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 512])
+@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512])
 def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
     """truncate_input_tokens prevents these endpoints from raising an error when too many tokens.
 
     Test with -1 which lets the model do truncation instead of raising an error.
-    Test with 99 (< 512) which causes our code to do the truncation instead of raising an error.
+    Test with 99 (< 512 - 2) which causes our code to do the truncation instead of raising an error.
+    Test with 510 (512 - 2) which causes our code to do the truncation instead of raising an error.
+    511 and 512 also behave like 510. The value is allowed, but begin/end tokens will take space.
""" model_max = loaded_model.model.max_seq_length @@ -749,20 +751,27 @@ def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_m ) -@pytest.mark.parametrize("truncate_input_tokens", [99, 512, -1]) +@pytest.mark.parametrize( + "truncate_input_tokens", [1, 2, 3, 4, 99, 100, 101, 510, 511, 512, -1] +) def test_embeddings_with_truncation(truncate_input_tokens, loaded_model): """verify that results are as expected with truncation""" + max_len = loaded_model.model.max_seq_length - 2 if truncate_input_tokens is None or truncate_input_tokens < 0: - # For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length - repeat = loaded_model.model.max_seq_length + # For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length - 2 + repeat = max_len else: - repeat = truncate_input_tokens + repeat = min( + truncate_input_tokens, max_len + ) # max_len is used when we need -2 for begin/end # Build a text like "x x x.. x " with room for one more token repeat = repeat - 1 # space for the final x or y token to show difference - base = "x " * repeat # A bunch of "x" tokens + base = "" + if repeat > 0: + base = "x " * repeat # A bunch of "x" tokens x = base + "x" # One last "x" that will not get truncated y = base + "y" # A different last character "y" not truncated z = y + "z" # Add token "z" after "y". This should get truncated. @@ -770,8 +779,23 @@ def test_embeddings_with_truncation(truncate_input_tokens, loaded_model): res = loaded_model.run_embeddings( texts=[base, x, y, z], truncate_input_tokens=truncate_input_tokens ) + vectors = res.results.vectors # vectors from batch embeddings - vectors = res.results.vectors + # Compare with results from individual embedding calls in a loop + loop_res = [] + for t in [base, x, y, z]: + r = loaded_model.run_embedding( + text=t, truncate_input_tokens=truncate_input_tokens + ) + loop_res.append(r) + loop_vectors = [ + r.result for r in loop_res + ] # vectors from loop of single embedding calls + + assert len(vectors) == len(loop_vectors), "expected the same length vectors" + # compare the vectors from batch with the single calls + for i, e in enumerate(vectors): + assert np.allclose(e.data.values, loop_vectors[i].data.values) # x...xyz is the same as x...xy because that is exactly where truncation worked assert len(vectors[2].data.values) == len(vectors[3].data.values) @@ -892,9 +916,18 @@ def test_env_val_to_int(): ], ) def test_sum_token_count_no_truncation(texts, expected_count, loaded_model): - tokenized, _ = loaded_model.model._truncate_input_tokens( - truncate_input_tokens=-1, # don't truncate. Model's truncation can still apply. 
-        texts=texts,
+
+    tokenized = loaded_model.model.tokenizer(
+        texts,
+        return_attention_mask=True,
+        return_token_type_ids=False,
+        return_overflowing_tokens=True,
+        return_offsets_mapping=True,
+        return_length=True,
+        return_tensors="pt",
+        truncation=True,
+        padding=True,
+        max_length=loaded_model.model.max_seq_length,
     )
     token_count = sum_token_count(
         tokenized,
@@ -924,9 +957,17 @@ def test_sum_token_count_no_truncation(texts, expected_count, loaded_model):
     ],
 )
 def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded_model):
-    tokenized, _ = loaded_model.model._truncate_input_tokens(
-        truncate_input_tokens=truncate,
-        texts=texts,
+    tokenized = loaded_model.model.tokenizer(
+        texts,
+        return_attention_mask=True,
+        return_token_type_ids=False,
+        return_overflowing_tokens=True,
+        return_offsets_mapping=True,
+        return_length=True,
+        return_tensors="pt",
+        truncation=True,
+        padding=True,
+        max_length=truncate,
     )
     token_count = sum_token_count(
         tokenized,
@@ -936,10 +977,18 @@ def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded
     assert token_count == expected_count
 
 
-def test_encoding_order(loaded_model: EmbeddingModule):
+@pytest.mark.parametrize(
+    "truncate_input_tokens", [0, 1, 2, 3, 4, 99, 100, 101, 510, 511, 512, 513, -1]
+)
+def test_encoding_order(loaded_model: EmbeddingModule, truncate_input_tokens):
     """Confirm that encoding doesn't modify the original sort order"""
-    separate_embeddings = [loaded_model.run_embedding(text=i) for i in MANY_INPUTS]
-    combined_embeddings = loaded_model.run_embeddings(texts=MANY_INPUTS)
+    separate_embeddings = [
+        loaded_model.run_embedding(text=i, truncate_input_tokens=truncate_input_tokens)
+        for i in MANY_INPUTS
+    ]
+    combined_embeddings = loaded_model.run_embeddings(
+        texts=MANY_INPUTS, truncate_input_tokens=truncate_input_tokens
+    )
 
     separate_vectors = [
         e.to_dict()["result"]["data"]["values"] for e in separate_embeddings
@@ -954,7 +1003,7 @@ def test_encoding_order(loaded_model: EmbeddingModule):
 
     # test order by comparing value of individual embeddings in sequence
     for i, e in enumerate(separate_vectors):
-        assert approx(e) == combined_vectors[i]
+        assert np.allclose(e, combined_vectors[i])
 
     # test expected failure case by reordering
     shifted_separate_vectors = separate_vectors[1:] + [separate_vectors[0]]
@@ -964,6 +1013,9 @@ def test_encoding_order(loaded_model: EmbeddingModule):
         assert (
             not approx(e) == combined_vectors[i]
         ), "expected altered order to not match combined vectors"
+        assert not np.allclose(
+            e, combined_vectors[i]
+        ), "expected altered order to not match combined"
 
 
 @pytest.mark.parametrize(
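
The sketch below illustrates, outside of caikit-nlp, the two rules this patch encodes:
the requested truncate_input_tokens value counts only "real" tokens (so 2 is added for
the begin/end specials, capped at the model's maximum), and the final tokenization is
done without overflowing tokens so each input text yields exactly one row of features.
It is not part of the patch; it assumes Hugging Face transformers is installed, and the
checkpoint name, texts, and truncate_input_tokens value are illustrative only.

    # Minimal standalone sketch (not part of the diff): truncation-limit handling plus
    # an overflow-free tokenization pass, mirroring the approach in _truncate_input_tokens.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    max_tokens = tokenizer.model_max_length  # 512 for this checkpoint
    truncate_input_tokens = 2  # requested limit, counted without [CLS]/[SEP]
    texts = ["short text", "much longer text " * 200]  # second one needs truncation

    # The requested limit excludes the 2 special tokens, so allow 2 more,
    # but never exceed the model's own maximum sequence length.
    max_length = min(truncate_input_tokens + 2, max_tokens)

    # Tokenize WITHOUT overflow tokens so each input maps to exactly one feature row;
    # overflow rows are what scrambled batch results before this fix.
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_overflowing_tokens=False,
    )

    assert len(tokenized["input_ids"]) == len(texts)  # one row per input text
    assert all(len(ids) <= max_length for ids in tokenized["input_ids"])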