Merge pull request #343 from markstur/embed_trunc_fix
Embeddings fix for truncation without room for begin/end tokens, and for batch truncation
evaline-ju authored Apr 3, 2024
2 parents 42c3075 + 780ddfa commit d34987a
Showing 2 changed files with 90 additions and 20 deletions.
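The gist of the change: when a caller passes `truncate_input_tokens`, that budget counts content tokens, so the tokenizer's `max_length` has to reserve room for the begin/end special tokens while still respecting the model's own `max_tokens` limit, and the final tokenization must run without overflow so batched inputs stay aligned one row per text. Below is a minimal standalone sketch of the clamping rule, with a hypothetical helper name; the real logic lives in `_truncate_input_tokens` in the first file.

```python
def resolve_max_length(truncate_input_tokens: int, max_tokens: int) -> int:
    """Sketch of the clamping rule only; the module's error handling and
    okay_to_truncate flag are omitted."""
    if truncate_input_tokens < 0:
        # -1 (or any negative) means: let the model truncate at its own limit.
        return max_tokens
    if 0 < truncate_input_tokens <= max_tokens:
        # Reserve 2 positions for the begin/end tokens, but never exceed the
        # model's maximum sequence length.
        return min(truncate_input_tokens + 2, max_tokens)
    return max_tokens


assert resolve_max_length(99, 512) == 101   # room for begin/end added
assert resolve_max_length(510, 512) == 512  # 510 + 2 fits exactly
assert resolve_max_length(512, 512) == 512  # clamped to the model limit
assert resolve_max_length(-1, 512) == 512   # model-side truncation
```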
22 changes: 20 additions & 2 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -818,7 +818,9 @@ def _truncate_input_tokens(
max_length = max_tokens
elif 0 < truncate_input_tokens <= max_tokens:
okay_to_truncate = True
max_length = truncate_input_tokens
# Add 2 for begin/end tokens, but don't go higher than model's max_tokens
max_length = min(truncate_input_tokens + 2, max_tokens)

else:
okay_to_truncate = not implicit_truncation_errors
max_length = max_tokens
@@ -859,7 +861,9 @@ def _truncate_input_tokens(
padding=True,
)

tokens = sum_token_count(tokenized, truncate_only=False)
tokens = 0
for encoding in tokenized.encodings:
tokens = max(sum(encoding.attention_mask), tokens)
error.log_raise(
"<NLP08391926E>",
ValueError(
@@ -870,6 +874,20 @@

input_token_count = sum_token_count(tokenized, truncate_only=True)

# Tokenize without overflow for batching and truncation to work together.
tokenized = self.tokenizer(
*to_tokenize,
return_attention_mask=True,
return_token_type_ids=False,
return_overflowing_tokens=False,
return_offsets_mapping=False,
return_length=False,
return_tensors="pt",
truncation=True,
padding=True,
max_length=max_length,
)

return TruncatedTokensTuple(tokenized, input_token_count)

def encode(
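Why the second tokenizer call added above drops `return_overflowing_tokens`: with overflow enabled, a fast tokenizer emits one extra row per overflowing chunk, so the batch can contain more rows than input texts and no longer lines up one row per text for batched embedding. A rough illustration, assuming the Hugging Face `transformers` tokenizer and an arbitrary small model name chosen for the example (not necessarily the model this module loads):

```python
from transformers import AutoTokenizer

# Assumed model for illustration; any fast tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

texts = ["x " * 600, "short text"]  # the first text overflows a 128-token limit

with_overflow = tokenizer(
    texts, truncation=True, padding=True, max_length=128,
    return_overflowing_tokens=True, return_tensors="pt",
)
without_overflow = tokenizer(
    texts, truncation=True, padding=True, max_length=128,
    return_tensors="pt",
)

# Overflow chunks become extra rows, breaking one-row-per-text alignment;
# without overflow the batch dimension matches the number of input texts.
assert with_overflow["input_ids"].shape[0] > len(texts)
assert without_overflow["input_ids"].shape[0] == len(texts)
```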
88 changes: 70 additions & 18 deletions tests/modules/text_embedding/test_embedding.py
@@ -681,12 +681,14 @@ def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model):
)


@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 512])
@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512])
def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
"""truncate_input_tokens prevents these endpoints from raising an error when too many tokens.
Test with -1 which lets the model do truncation instead of raising an error.
Test with 99 (< 512) which causes our code to do the truncation instead of raising an error.
Test with 99 (< 512 -2) which causes our code to do the truncation instead of raising an error.
Test with 510 (512 -2) which causes our code to do the truncation instead of raising an error.
511 and 512 also behave like 510. The value is allowed, but begin/end tokens will take space.
"""

model_max = loaded_model.model.max_seq_length
@@ -749,29 +751,51 @@ def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
)


@pytest.mark.parametrize("truncate_input_tokens", [99, 512, -1])
@pytest.mark.parametrize(
"truncate_input_tokens", [1, 2, 3, 4, 99, 100, 101, 510, 511, 512, -1]
)
def test_embeddings_with_truncation(truncate_input_tokens, loaded_model):
"""verify that results are as expected with truncation"""

max_len = loaded_model.model.max_seq_length - 2
if truncate_input_tokens is None or truncate_input_tokens < 0:
# For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length
repeat = loaded_model.model.max_seq_length
# For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length - 2
repeat = max_len
else:
repeat = truncate_input_tokens
repeat = min(
truncate_input_tokens, max_len
) # max_len is used when we need -2 for begin/end

# Build a text like "x x x.. x " with room for one more token
repeat = repeat - 1 # space for the final x or y token to show difference

base = "x " * repeat # A bunch of "x" tokens
base = ""
if repeat > 0:
base = "x " * repeat # A bunch of "x" tokens
x = base + "x" # One last "x" that will not get truncated
y = base + "y" # A different last character "y" not truncated
z = y + "z" # Add token "z" after "y". This should get truncated.

res = loaded_model.run_embeddings(
texts=[base, x, y, z], truncate_input_tokens=truncate_input_tokens
)
vectors = res.results.vectors # vectors from batch embeddings

vectors = res.results.vectors
# Compare with results from individual embedding calls in a loop
loop_res = []
for t in [base, x, y, z]:
r = loaded_model.run_embedding(
text=t, truncate_input_tokens=truncate_input_tokens
)
loop_res.append(r)
loop_vectors = [
r.result for r in loop_res
] # vectors from loop of single embedding calls

assert len(vectors) == len(loop_vectors), "expected the same length vectors"
# compare the vectors from batch with the single calls
for i, e in enumerate(vectors):
assert np.allclose(e.data.values, loop_vectors[i].data.values)

# x...xyz is the same as x...xy because that is exactly where truncation worked
assert len(vectors[2].data.values) == len(vectors[3].data.values)
@@ -892,9 +916,18 @@ def test_env_val_to_int():
],
)
def test_sum_token_count_no_truncation(texts, expected_count, loaded_model):
tokenized, _ = loaded_model.model._truncate_input_tokens(
truncate_input_tokens=-1, # don't truncate. Model's truncation can still apply.
texts=texts,

tokenized = loaded_model.model.tokenizer(
texts,
return_attention_mask=True,
return_token_type_ids=False,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_length=True,
return_tensors="pt",
truncation=True,
padding=True,
max_length=loaded_model.model.max_seq_length,
)
token_count = sum_token_count(
tokenized,
@@ -924,9 +957,17 @@ def test_sum_token_count_no_truncation(texts, expected_count, loaded_model):
],
)
def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded_model):
tokenized, _ = loaded_model.model._truncate_input_tokens(
truncate_input_tokens=truncate,
texts=texts,
tokenized = loaded_model.model.tokenizer(
texts,
return_attention_mask=True,
return_token_type_ids=False,
return_overflowing_tokens=True,
return_offsets_mapping=True,
return_length=True,
return_tensors="pt",
truncation=True,
padding=True,
max_length=truncate,
)
token_count = sum_token_count(
tokenized,
@@ -936,10 +977,18 @@ def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded_model):
assert token_count == expected_count


def test_encoding_order(loaded_model: EmbeddingModule):
@pytest.mark.parametrize(
"truncate_input_tokens", [0, 1, 2, 3, 4, 99, 100, 101, 510, 511, 512, 513, -1]
)
def test_encoding_order(loaded_model: EmbeddingModule, truncate_input_tokens):
"""Confirm that encoding doesn't modify the original sort order"""
separate_embeddings = [loaded_model.run_embedding(text=i) for i in MANY_INPUTS]
combined_embeddings = loaded_model.run_embeddings(texts=MANY_INPUTS)
separate_embeddings = [
loaded_model.run_embedding(text=i, truncate_input_tokens=truncate_input_tokens)
for i in MANY_INPUTS
]
combined_embeddings = loaded_model.run_embeddings(
texts=MANY_INPUTS, truncate_input_tokens=truncate_input_tokens
)

separate_vectors = [
e.to_dict()["result"]["data"]["values"] for e in separate_embeddings
@@ -954,7 +1003,7 @@

# test order by comparing value of individual embeddings in sequence
for i, e in enumerate(separate_vectors):
assert approx(e) == combined_vectors[i]
assert np.allclose(e, combined_vectors[i])

# test expected failure case by reordering
shifted_separate_vectors = separate_vectors[1:] + [separate_vectors[0]]
@@ -964,6 +1013,9 @@
assert (
not approx(e) == combined_vectors[i]
), "expected altered order to not match combined vectors"
assert not np.allclose(
e, combined_vectors[i]
), "expected altered order to not match combined"


@pytest.mark.parametrize(
