Merge pull request caikit#291 from markstur/no_decode
Text Embedding: Fix concurrency errors
evaline-ju authored Dec 11, 2023
2 parents fd0509c + a87cc64 commit 8b7366a
Showing 2 changed files with 302 additions and 137 deletions.
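For context, the concurrency failure this PR addresses shows up when one loaded embedding module serves many requests from Python threads: tokenizing for truncation and sentence-transformers encode() can collide on the shared tokenizer (see the comments added below) and raise RuntimeError('Already borrowed'), alongside other sporadic encode/tokenize errors. A minimal sketch of that kind of load, with a hypothetical model path (the import path and paths here are illustrative, not part of this diff):

from concurrent.futures import ThreadPoolExecutor

from caikit_nlp.modules.text_embedding.embedding import EmbeddingModule

# Hypothetical path to a saved sentence-transformers embedding model
module = EmbeddingModule.load("models/my-embedding-model")

texts = [f"sentence number {i}" for i in range(32)]
with ThreadPoolExecutor(max_workers=8) as pool:
    # Each request tokenizes for truncation and then encodes; before this fix,
    # doing that from many threads could hit the errors above.
    results = list(pool.map(lambda t: module.run_embedding(text=t), texts))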
141 changes: 111 additions & 30 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -13,9 +13,11 @@
# limitations under the License.

# Standard
from copy import deepcopy
from typing import List, Optional
import importlib
import os
import time

# Third Party
from torch.backends import mps
@@ -76,6 +78,22 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
FALSY = ("no", "n", "false", "0", "f", "off")


def env_var_to_int(name, default):
"""Returns the integer value of name env var or default value if None or invalid integer"""
s = os.getenv(name, default)
try:
return int(s)
except (TypeError, ValueError):
return default


# Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0)

# Retry count for sporadic encode() or tokenize() errors (in case they come back)
RETRY_COUNT = env_var_to_int("RETRY_COUNT", default=5)
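A quick illustration of how env_var_to_int resolves these settings, assuming the function above is in scope (the values below are made up). Note that BATCH_SIZE and RETRY_COUNT are read once at import time, so the env vars must be set before caikit_nlp is imported:

import os

os.environ["RETRY_COUNT"] = "3"          # valid integer -> 3
os.environ["BATCH_SIZE"] = "not-an-int"  # invalid -> default (0: use library default)

assert env_var_to_int("RETRY_COUNT", default=5) == 3
assert env_var_to_int("BATCH_SIZE", default=0) == 0
assert env_var_to_int("SOME_UNSET_VAR", default=7) == 7  # unset -> default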


@module(
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
"EmbeddingModule",
@@ -101,6 +119,12 @@ def __init__(
super().__init__()
self.model = model

# Separate copy of the tokenizer, used only by _truncate_input_tokens().
# This avoids the otherwise frequent RuntimeError('Already borrowed') when
# Python threads tokenize for truncation followed by sentence-transformers
# tokenize/encode.
self._tokenizer = deepcopy(self.model.tokenizer)

@classmethod
def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
"""Load model
@@ -214,6 +238,30 @@ def _optimize(model, ipex, device)
logger.warning(warn_msg, exc_info=True)
return model

@staticmethod
def _with_retry(fn, *args, **kwargs):
retries = max(RETRY_COUNT, 0)
for count in range(1 + retries): # try once plus retries (if needed)
try:
return fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
warn_msg = f"Retry {fn} due to: {e}"
logger.warning(warn_msg, exc_info=True)
time.sleep(0.1 * (count * 2))
error.log_raise("<NLP31069292E>", RuntimeError(f"Too many retries of fn={fn}"))

def _encode_with_retry(self, *args, **kwargs):
"""All encode calls should use this for consistent param adding and retry loop"""

# Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
if BATCH_SIZE > 0:
if kwargs is None:
kwargs = {}
if "batch_size" not in kwargs:
kwargs["batch_size"] = BATCH_SIZE

return self._with_retry(self.model.encode, *args, **kwargs)

def _truncate_input_tokens(
self, truncate_input_tokens, texts: List[str]
) -> List[str]:
@@ -257,32 +305,33 @@ def _truncate_input_tokens(
max_length = max_tokens
ret = texts # will not alter texts when truncation is not allowed

-        for text in texts:
-            tokenized = self.model.tokenizer(
-                text,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                return_overflowing_tokens=True,
-                return_length=True,
-                truncation=True,
-                max_length=max_length,
-            )
+        tokenized = self._with_retry(
+            self._tokenizer,
+            texts,
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            return_overflowing_tokens=True,
+            return_offsets_mapping=True,
+            return_length=True,
+            truncation=True,
+            max_length=max_length,
+        )

-            lengths = tokenized["length"]
-            was_truncated = len(lengths) > 1  # multiple lengths when truncated
+        texts_map = tokenized["overflow_to_sample_mapping"]

-            if okay_to_truncate and was_truncated:
-                # decode the truncated input tokens back to text to be returned
-                ret.append(
-                    self.model.tokenizer.decode(
-                        tokenized.input_ids[0], skip_special_tokens=False
-                    )
-                )
+        for text_number, text in enumerate(texts):
+            # positions: the positions (in lengths and offsets arrays) that belong to this text
+            positions = [
+                position
+                for position, sample_number in enumerate(texts_map)
+                if sample_number == text_number
+            ]
+            lengths = [tokenized["length"][pos] for pos in positions]

-            elif okay_to_truncate and not was_truncated:
-                ret.append(text)  # return original text
+            was_truncated = len(lengths) > 1  # multiple lengths when truncated

-            elif was_truncated:
+            if not okay_to_truncate and was_truncated:
# Raise error. We don't allow silent truncation in this case.
tokens = sum(lengths) # add up total tokens for error message
error.log_raise(
"<NLP08391926E>",
@@ -292,6 +341,34 @@
),
)

elif okay_to_truncate and not was_truncated:
ret.append(text) # collect original text to return

elif okay_to_truncate and was_truncated:
# Truncate the text that maps to the truncated tokens.
# The offset_mapping describes the text position for each token.
# Added tokens were not in the text, so they show up as (0, 0).

# Get the text offsets for the tokens that are to be kept after truncation.
# Take the first set of offsets that mapped to this text's positions.
# This first set represents what will be kept after truncation.
# Each offset tells us which chars in the original text map to this token.
offsets = next(tokenized["offset_mapping"][pos] for pos in positions)

# Find the first offset that is not empty (0, 0) to avoid added tokens
start = next(offset for offset in offsets if offset != (0, 0))

# Find the last offset that is not empty (0, 0) to avoid added tokens
end = next(
offset for offset in reversed(list(offsets)) if offset != (0, 0)
)

# Slice the text from the start offset's beginning to the end offset's ending
# i.e. if start=(0,5) and end=(72,78) then we want slice [0:78]
truncated_text = text[start[0] : end[1]]

ret.append(truncated_text) # return the truncated text for this one

return ret
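The branch above maps the kept tokens back to character offsets and slices the original text, instead of decoding token IDs back to text as the removed code did. A standalone sketch of that offset-mapping technique with a generic fast tokenizer (the model name and max_length are illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any "fast" tokenizer

texts = ["a short text", "a much longer text " * 50]
tokenized = tokenizer(
    texts,
    return_overflowing_tokens=True,  # extra rows per input signal truncation
    return_offsets_mapping=True,     # (char_start, char_end) for each token
    truncation=True,
    max_length=16,
)

texts_map = tokenized["overflow_to_sample_mapping"]
for text_number, text in enumerate(texts):
    positions = [p for p, s in enumerate(texts_map) if s == text_number]
    if len(positions) > 1:  # this text overflowed, i.e. was truncated
        offsets = tokenized["offset_mapping"][positions[0]]  # first row = kept tokens
        start = next(o for o in offsets if o != (0, 0))  # skip added special tokens
        end = next(o for o in reversed(offsets) if o != (0, 0))
        print(repr(text[start[0] : end[1]]))  # truncated text, sliced by characters
    else:
        print(repr(text))  # fits within max_length; returned unchanged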

@EmbeddingTask.taskmethod()
Expand All @@ -318,7 +395,7 @@ def run_embedding(

text = self._truncate_input_tokens(truncate_input_tokens, [text])[0]
return EmbeddingResult(
result=Vector1D.from_vector(self.model.encode(text)),
result=Vector1D.from_vector(self._encode_with_retry(text)),
producer_id=self.PRODUCER_ID,
)

@@ -350,7 +427,7 @@ def run_embeddings(

texts = self._truncate_input_tokens(truncate_input_tokens, texts)

embeddings = self.model.encode(texts)
embeddings = self._encode_with_retry(texts)
vectors = [Vector1D.from_vector(e) for e in embeddings]
return EmbeddingResults(
results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID
@@ -384,8 +461,8 @@ def run_sentence_similarity(
)[0]
sentences = self._truncate_input_tokens(truncate_input_tokens, sentences)

source_embedding = self.model.encode(source_sentence)
embeddings = self.model.encode(sentences)
source_embedding = self._encode_with_retry(source_sentence)
embeddings = self._encode_with_retry(sentences)

res = cos_sim(source_embedding, embeddings)
return SentenceSimilarityResult(
@@ -422,8 +499,8 @@ def run_sentence_similarities(
)
sentences = self._truncate_input_tokens(truncate_input_tokens, sentences)

source_embedding = self.model.encode(source_sentences)
embeddings = self.model.encode(sentences)
source_embedding = self._encode_with_retry(source_sentences)
embeddings = self._encode_with_retry(sentences)

res = cos_sim(source_embedding, embeddings)
float_list_list = res.tolist()
@@ -579,11 +656,15 @@ def get_text(doc):
queries = self._truncate_input_tokens(truncate_input_tokens, queries)

doc_embeddings = normalize_embeddings(
self.model.encode(doc_texts, convert_to_tensor=True).to(self.model.device)
self._encode_with_retry(doc_texts, convert_to_tensor=True).to(
self.model.device
)
)

query_embeddings = normalize_embeddings(
self.model.encode(queries, convert_to_tensor=True).to(self.model.device)
self._encode_with_retry(queries, convert_to_tensor=True).to(
self.model.device
)
)

res = semantic_search(