diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 4df9791e..c2ef5045 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 # Standard
+from copy import deepcopy
 from typing import List, Optional
 import importlib
 import os
+import time
 
 # Third Party
 from torch.backends import mps
@@ -76,6 +78,22 @@ def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
 FALSY = ("no", "n", "false", "0", "f", "off")
 
 
+def env_var_to_int(name, default):
+    """Return the integer value of the named env var, or the default if unset or invalid"""
+    s = os.getenv(name, default)
+    try:
+        return int(s)
+    except (TypeError, ValueError):
+        return default
+
+
+# Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
+BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0)
+
+# Retry count for sporadic encode() or tokenize() errors (in case they come back)
+RETRY_COUNT = env_var_to_int("RETRY_COUNT", default=5)
+
+
 @module(
     "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
     "EmbeddingModule",
@@ -101,6 +119,12 @@ def __init__(
         super().__init__()
         self.model = model
 
+        # Separate copy of the tokenizer just for _truncate_input_tokens()
+        # This avoids RuntimeError('Already borrowed'), which otherwise occurs
+        # frequently when Python threads tokenize for truncation and then
+        # sentence-transformers tokenizes/encodes with the same tokenizer.
+        self._tokenizer = deepcopy(self.model.tokenizer)
+
     @classmethod
     def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
         """Load model
@@ -214,6 +238,30 @@ def _optimize(model, ipex, device):
             logger.warning(warn_msg, exc_info=True)
         return model
 
+    @staticmethod
+    def _with_retry(fn, *args, **kwargs):
+        retries = max(RETRY_COUNT, 0)
+        for count in range(1 + retries):  # try once plus retries (if needed)
+            try:
+                return fn(*args, **kwargs)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                warn_msg = f"Retry {fn} due to: {e}"
+                logger.warning(warn_msg, exc_info=True)
+                time.sleep(0.1 * (count * 2))
+        error.log_raise("", RuntimeError(f"Too many retries of fn={fn}"))
+
+    def _encode_with_retry(self, *args, **kwargs):
+        """All encode calls should use this for consistent batch_size handling and retries"""
+
+        # Add the batch_size kwarg if it was not passed in and BATCH_SIZE is usable
+        if BATCH_SIZE > 0:
+            if kwargs is None:
+                kwargs = {}
+            if "batch_size" not in kwargs:
+                kwargs["batch_size"] = BATCH_SIZE
+
+        return self._with_retry(self.model.encode, *args, **kwargs)
+
     def _truncate_input_tokens(
         self, truncate_input_tokens, texts: List[str]
     ) -> List[str]:
@@ -257,32 +305,33 @@ def _truncate_input_tokens(
             max_length = max_tokens
 
         ret = texts  # will not alter texts when truncation is not allowed
-        for text in texts:
-            tokenized = self.model.tokenizer(
-                text,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                return_overflowing_tokens=True,
-                return_length=True,
-                truncation=True,
-                max_length=max_length,
-            )
+        tokenized = self._with_retry(
+            self._tokenizer,
+            texts,
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            return_overflowing_tokens=True,
+            return_offsets_mapping=True,
+            return_length=True,
+            truncation=True,
+            max_length=max_length,
+        )
 
-            lengths = tokenized["length"]
-            was_truncated = len(lengths) > 1  # multiple lengths when truncated
+        texts_map = tokenized["overflow_to_sample_mapping"]
 
-            if okay_to_truncate and was_truncated:
-                # decode the truncated input tokens back to text to be returned
-                ret.append(
-                    self.model.tokenizer.decode(
-                        tokenized.input_ids[0], skip_special_tokens=False
-                    )
-                )
+        for text_number, text in enumerate(texts):
+            # positions: the indexes (in the lengths and offsets arrays) that belong to this text
+            positions = [
+                position
+                for position, sample_number in enumerate(texts_map)
+                if sample_number == text_number
+            ]
+            lengths = [tokenized["length"][pos] for pos in positions]
 
-            elif okay_to_truncate and not was_truncated:
-                ret.append(text)  # return original text
+            was_truncated = len(lengths) > 1  # multiple lengths when truncated
 
-            elif was_truncated:
+            if not okay_to_truncate and was_truncated:
+                # Raise an error. We don't allow silent truncation in this case.
                 tokens = sum(lengths)  # add up total tokens for error message
                 error.log_raise(
                     "",
@@ -292,6 +341,34 @@ def _truncate_input_tokens(
                     ),
                 )
 
+            elif okay_to_truncate and not was_truncated:
+                ret.append(text)  # collect original text to return
+
+            elif okay_to_truncate and was_truncated:
+                # Truncate the text that maps to the truncated tokens.
+                # The offset_mapping describes the text position for each token.
+                # Added tokens were not in the text, so they show up as (0, 0).
+
+                # Get the text offsets for the tokens that are to be kept after truncation.
+                # Take the first set of offsets that mapped to this text's positions.
+                # This first set represents what will be kept after truncation.
+                # Each offset tells us which chars in the original text map to this token.
+                offsets = next(tokenized["offset_mapping"][pos] for pos in positions)
+
+                # Find the first offset that is not empty (0, 0) to avoid added tokens
+                start = next(offset for offset in offsets if offset != (0, 0))
+
+                # Find the last offset that is not empty (0, 0) to avoid added tokens
+                end = next(
+                    offset for offset in reversed(list(offsets)) if offset != (0, 0)
+                )
+
+                # Slice using the start of the first kept offset and the end of the last,
+                # i.e. if start=(0,5) and end=(72,78) then we want slice [0:78]
+                truncated_text = text[start[0] : end[1]]
+
+                ret.append(truncated_text)  # return the truncated text for this one
+
         return ret
 
     @EmbeddingTask.taskmethod()
@@ -318,7 +395,7 @@ def run_embedding(
         text = self._truncate_input_tokens(truncate_input_tokens, [text])[0]
 
         return EmbeddingResult(
-            result=Vector1D.from_vector(self.model.encode(text)),
+            result=Vector1D.from_vector(self._encode_with_retry(text)),
             producer_id=self.PRODUCER_ID,
         )
 
@@ -350,7 +427,7 @@ def run_embeddings(
 
         texts = self._truncate_input_tokens(truncate_input_tokens, texts)
 
-        embeddings = self.model.encode(texts)
+        embeddings = self._encode_with_retry(texts)
         vectors = [Vector1D.from_vector(e) for e in embeddings]
         return EmbeddingResults(
             results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID
@@ -384,8 +461,8 @@ def run_sentence_similarity(
         )[0]
         sentences = self._truncate_input_tokens(truncate_input_tokens, sentences)
 
-        source_embedding = self.model.encode(source_sentence)
-        embeddings = self.model.encode(sentences)
+        source_embedding = self._encode_with_retry(source_sentence)
+        embeddings = self._encode_with_retry(sentences)
 
         res = cos_sim(source_embedding, embeddings)
         return SentenceSimilarityResult(
@@ -422,8 +499,8 @@ def run_sentence_similarities(
         )
         sentences = self._truncate_input_tokens(truncate_input_tokens, sentences)
 
-        source_embedding = self.model.encode(source_sentences)
-        embeddings = self.model.encode(sentences)
+        source_embedding = self._encode_with_retry(source_sentences)
+        embeddings = self._encode_with_retry(sentences)
 
         res = cos_sim(source_embedding, embeddings)
         float_list_list = res.tolist()
@@ -579,11 +656,15 @@ def get_text(doc):
 
         queries = self._truncate_input_tokens(truncate_input_tokens, queries)
         doc_embeddings = normalize_embeddings(
-            self.model.encode(doc_texts, convert_to_tensor=True).to(self.model.device)
+            self._encode_with_retry(doc_texts, convert_to_tensor=True).to(
+                self.model.device
+            )
         )
 
         query_embeddings = normalize_embeddings(
-            self.model.encode(queries, convert_to_tensor=True).to(self.model.device)
+            self._encode_with_retry(queries, convert_to_tensor=True).to(
+                self.model.device
+            )
         )
 
         res = semantic_search(
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index 57feec1a..fa386b3e 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -73,6 +73,15 @@
 ## Tests ########################################################################
 
 
+@pytest.fixture(scope="module")
+def loaded_model(tmp_path_factory):
+    models_dir = tmp_path_factory.mktemp("models")
+    model_path = str(models_dir / "model_id")
+    BOOTSTRAPPED_MODEL.save(model_path)
+    model = EmbeddingModule.load(model_path)
+    return model
+
+
 def _assert_is_expected_vector(vector):
     assert isinstance(vector.data.values[0], np.float32)
     assert len(vector.data.values) == 32
@@ -132,8 +141,8 @@ def _assert_valid_scores(scores, type_tests={}):
     return type_tests
 
 
-def test_bootstrap():
-    assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap error"
+def test_bootstrap_reuse():
+    assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap reuse error"
 
 
 def test_save_load_and_run():
@@ -192,31 +201,27 @@ def test_load_without_artifacts():
         EmbeddingModule.load(ModuleConfig({}))
 
 
-def test_run_embedding_type_check():
+def test_run_embedding_type_check(loaded_model):
     """Input cannot be a list"""
-    model = BOOTSTRAPPED_MODEL
     with pytest.raises(TypeError):
-        model.run_embedding([INPUT])
+        loaded_model.run_embedding([INPUT])
         pytest.fail("Should not reach here")
 
 
-def test_run_embedding():
-    model = BOOTSTRAPPED_MODEL
-    res = model.run_embedding(text=INPUT)
+def test_run_embedding(loaded_model):
+    res = loaded_model.run_embedding(text=INPUT)
     _assert_is_expected_embedding_result(res)
 
 
-def test_run_embeddings_str_type():
+def test_run_embeddings_str_type(loaded_model):
     """Supposed to be a list, gets fixed automatically."""
-    model = BOOTSTRAPPED_MODEL
-    res = model.run_embeddings(texts=INPUT)
+    res = loaded_model.run_embeddings(texts=INPUT)
     assert isinstance(res.results.vectors, list)
     assert len(res.results.vectors) == 1
 
 
-def test_run_embeddings():
-    model = BOOTSTRAPPED_MODEL
-    res = model.run_embeddings(texts=[INPUT])
+def test_run_embeddings(loaded_model):
+    res = loaded_model.run_embeddings(texts=[INPUT])
     assert isinstance(res.results.vectors, list)
     _assert_is_expected_embeddings_results(res.results)
 
@@ -231,16 +236,16 @@ def test_run_embeddings():
         (QUERY, DOCS, "topN string is not an integer or None"),
     ],
 )
-def test_run_rerank_query_type_error(query, docs, top_n):
+def test_run_rerank_query_type_error(query, docs, top_n, loaded_model):
     """test for type checks matching task/run signature"""
     with pytest.raises(TypeError):
-        BOOTSTRAPPED_MODEL.run_rerank_query(query=query, documents=docs, top_n=top_n)
+        loaded_model.run_rerank_query(query=query, documents=docs, top_n=top_n)
         pytest.fail("Should not reach here.")
 
 
-def test_run_rerank_query_no_type_error():
+def test_run_rerank_query_no_type_error(loaded_model):
     """no type error with list of string queries and list of dict documents"""
-    BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
+    loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
 
 
 @pytest.mark.parametrize(
@@ -254,25 +259,25 @@ def test_run_rerank_query_no_type_error():
         (9999, len(DOCS)),
     ],
 )
-def test_run_rerank_query_top_n(top_n, expected):
-    res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
+def test_run_rerank_query_top_n(top_n, expected, loaded_model):
+    res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
     assert isinstance(res, RerankResult)
     assert len(res.result.scores) == expected
 
 
-def test_run_rerank_query_no_query():
+def test_run_rerank_query_no_query(loaded_model):
     with pytest.raises(TypeError):
-        BOOTSTRAPPED_MODEL.run_rerank_query(query=None, documents=DOCS, top_n=99)
+        loaded_model.run_rerank_query(query=None, documents=DOCS, top_n=99)
 
 
-def test_run_rerank_query_zero_docs():
+def test_run_rerank_query_zero_docs(loaded_model):
     """No empty doc list therefore result is zero result scores"""
     with pytest.raises(ValueError):
-        BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99)
+        loaded_model.run_rerank_query(query=QUERY, documents=[], top_n=99)
 
 
-def test_run_rerank_query():
-    res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS)
+def test_run_rerank_query(loaded_model):
+    res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS)
     assert isinstance(res, RerankResult)
 
     scores = res.result.scores
@@ -286,16 +291,16 @@
 @pytest.mark.parametrize(
     "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})]
 )
-def test_run_rerank_queries_type_error(queries, docs):
+def test_run_rerank_queries_type_error(queries, docs, loaded_model):
     """type error check ensures params are lists and not just 1 string or just one doc (for
example)""" with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs) + loaded_model.run_rerank_queries(queries=queries, documents=docs) pytest.fail("Should not reach here.") -def test_run_rerank_queries_no_type_error(): +def test_run_rerank_queries_no_type_error(loaded_model): """no type error with list of string queries and list of dict documents""" - BOOTSTRAPPED_MODEL.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) @pytest.mark.parametrize( @@ -309,11 +314,9 @@ def test_run_rerank_queries_no_type_error(): (9999, len(DOCS)), ], ) -def test_run_rerank_queries_top_n(top_n, expected): +def test_run_rerank_queries_top_n(top_n, expected, loaded_model): """no type error with list of string queries and list of dict documents""" - res = BOOTSTRAPPED_MODEL.run_rerank_queries( - queries=QUERIES, documents=DOCS, top_n=top_n - ) + res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=top_n) assert isinstance(res, RerankResults) assert len(res.results) == len(QUERIES) for result in res.results: @@ -329,16 +332,16 @@ def test_run_rerank_queries_top_n(top_n, expected): ], ids=["no queries", "no docs", "no queries and no docs"], ) -def test_run_rerank_queries_no_queries_or_no_docs(queries, docs): +def test_run_rerank_queries_no_queries_or_no_docs(queries, docs, loaded_model): """No queries and/or no docs therefore result is zero results""" with pytest.raises(ValueError): - BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs, top_n=9) + loaded_model.run_rerank_queries(queries=queries, documents=docs, top_n=9) -def test_run_rerank_queries(): +def test_run_rerank_queries(loaded_model): top_n = 2 - rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries( + rerank_result = loaded_model.run_rerank_queries( queries=QUERIES, documents=DOCS, top_n=top_n ) assert isinstance(rerank_result, RerankResults) @@ -360,18 +363,20 @@ def test_run_rerank_queries(): _assert_types_found(types_found) -def test_run_sentence_similarity(): - model = BOOTSTRAPPED_MODEL - res = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES) +def test_run_sentence_similarity(loaded_model): + res = loaded_model.run_sentence_similarity( + source_sentence=QUERY, sentences=SENTENCES + ) scores = res.result.scores assert len(scores) == len(SENTENCES) for score in scores: assert isinstance(score, float) -def test_run_sentence_similarities(): - model = BOOTSTRAPPED_MODEL - res = model.run_sentence_similarities(source_sentences=QUERIES, sentences=SENTENCES) +def test_run_sentence_similarities(loaded_model): + res = loaded_model.run_sentence_similarities( + source_sentences=QUERIES, sentences=SENTENCES + ) results = res.results assert len(results) == len(QUERIES) for result in results: @@ -444,136 +449,144 @@ def test__optimize(monkeypatch): assert fake == EmbeddingModule._optimize(fake, False, "bogus") -@pytest.mark.parametrize( - "truncate_input_tokens, expected_len", [(99, 205), (333, 673), (-1, 1022)] -) -def test__truncate_input_tokens(truncate_input_tokens, expected_len): - model = BOOTSTRAPPED_MODEL - model_max = model.model.max_seq_length +@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 10, 333]) +def test__truncate_input_tokens(truncate_input_tokens, loaded_model): - too_long = "x " * (model_max - 1) # This will go over - actual = model._truncate_input_tokens( - truncate_input_tokens=truncate_input_tokens, texts=[too_long] - )[0] + if 
+        num_xs = 500  # fill-er up
+    else:
+        num_xs = truncate_input_tokens - 4  # subtract room for begin/end tokens and (y y), but not z
 
-    assert len(actual) == expected_len
+    too_long = "x " * num_xs + "y y z "  # z will go over
+    actual = loaded_model._truncate_input_tokens(
+        truncate_input_tokens=truncate_input_tokens, texts=[too_long, too_long]
+    )
+
+    assert actual[0] == actual[1]  # they are still the same
+
+    if truncate_input_tokens < 0:
+        assert actual[0] == too_long, "expected no truncation"
+    else:
+        assert actual[0] + " z " == too_long, "expected truncation"
 
 
 @pytest.mark.parametrize("truncate_input_tokens", [0, 513])
-def test__truncate_input_tokens_raises(truncate_input_tokens):
-    model = BOOTSTRAPPED_MODEL
-    model_max = model.model.max_seq_length
+def test__truncate_input_tokens_raises(truncate_input_tokens, loaded_model):
+    model_max = loaded_model.model.max_seq_length
 
     too_long = "x " * (model_max - 1)  # This will go over
     with pytest.raises(ValueError):
-        model._truncate_input_tokens(
+        loaded_model._truncate_input_tokens(
             truncate_input_tokens=truncate_input_tokens, texts=[too_long]
         )
 
 
-def test_not_too_many_tokens():
+def test_not_too_many_tokens(loaded_model):
     """Happy path for the endpoints using text that is not too many tokens."""
-    model = BOOTSTRAPPED_MODEL
-    model_max = model.model.max_seq_length
+    model_max = loaded_model.model.max_seq_length
 
     ok = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
 
     # embedding(s)
-    model.run_embedding(text=ok)
-    model.run_embeddings(texts=[ok])
+    loaded_model.run_embedding(text=ok)
+    loaded_model.run_embeddings(texts=[ok])
 
     # sentence similarity(ies) test both source_sentence and sentences
-    model.run_sentence_similarity(source_sentence=ok, sentences=[ok])
-    model.run_sentence_similarities(source_sentences=[ok], sentences=[ok])
+    loaded_model.run_sentence_similarity(source_sentence=ok, sentences=[ok])
+    loaded_model.run_sentence_similarities(source_sentences=[ok], sentences=[ok])
 
     # reranker test both query and document text
-    model.run_rerank_query(query=ok, documents=[{"text": ok}])
-    model.run_rerank_queries(queries=[ok], documents=[{"text": ok}])
+    loaded_model.run_rerank_query(query=ok, documents=[{"text": ok}])
+    loaded_model.run_rerank_queries(queries=[ok], documents=[{"text": ok}])
 
 
-def test_too_many_tokens_default():
+def test_too_many_tokens_default(loaded_model):
     """These endpoints raise an error when truncation would happen."""
-    model = BOOTSTRAPPED_MODEL
-    model_max = model.model.max_seq_length
+    model_max = loaded_model.model.max_seq_length
 
     ok = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
     too_long = "x " * (model_max - 1)  # This will go over
 
     # embedding(s)
     with pytest.raises(ValueError):
-        model.run_embedding(text=too_long)
+        loaded_model.run_embedding(text=too_long)
     with pytest.raises(ValueError):
-        model.run_embeddings(texts=[too_long])
+        loaded_model.run_embeddings(texts=[too_long])
 
     # sentence similarity(ies) test both source_sentence and sentences
     with pytest.raises(ValueError):
-        model.run_sentence_similarity(source_sentence=too_long, sentences=[ok])
+        loaded_model.run_sentence_similarity(source_sentence=too_long, sentences=[ok])
     with pytest.raises(ValueError):
-        model.run_sentence_similarity(source_sentence=ok, sentences=[too_long])
+        loaded_model.run_sentence_similarity(source_sentence=ok, sentences=[too_long])
 
     with pytest.raises(ValueError):
-        model.run_sentence_similarities(source_sentences=[too_long], sentences=[ok])
+        loaded_model.run_sentence_similarities(
+            source_sentences=[too_long], sentences=[ok]
+        )
     with pytest.raises(ValueError):
-        model.run_sentence_similarities(source_sentences=[ok], sentences=[too_long])
+        loaded_model.run_sentence_similarities(
+            source_sentences=[ok], sentences=[too_long]
+        )
 
     # reranker test both query and document text
     with pytest.raises(ValueError):
-        model.run_rerank_query(query=too_long, documents=[{"text": ok}])
+        loaded_model.run_rerank_query(query=too_long, documents=[{"text": ok}])
     with pytest.raises(ValueError):
-        model.run_rerank_query(query=ok, documents=[{"text": too_long}])
+        loaded_model.run_rerank_query(query=ok, documents=[{"text": too_long}])
 
     with pytest.raises(ValueError):
-        model.run_rerank_queries(queries=[too_long], documents=[{"text": ok}])
+        loaded_model.run_rerank_queries(queries=[too_long], documents=[{"text": ok}])
     with pytest.raises(ValueError):
-        model.run_rerank_queries(queries=[ok], documents=[{"text": too_long}])
+        loaded_model.run_rerank_queries(queries=[ok], documents=[{"text": too_long}])
 
 
 @pytest.mark.parametrize("truncate_input_tokens", [0, 513])
-def test_too_many_tokens_error_params(truncate_input_tokens):
+def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model):
     """truncate_input_tokens does not prevent these endpoints from raising an error.
 
     Test with 0 which uses the max model len (512) to determine truncation and raise error.
     Test with 513 (> 512) which detects truncation over 512 and raises an error.
     """
-    model = BOOTSTRAPPED_MODEL
-    model_max = model.model.max_seq_length
+    model_max = loaded_model.model.max_seq_length
 
     ok = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
     too_long = "x " * (model_max - 1)  # This will go over
 
     # embedding(s)
     with pytest.raises(ValueError):
-        model.run_embedding(text=too_long, truncate_input_tokens=truncate_input_tokens)
+        loaded_model.run_embedding(
+            text=too_long, truncate_input_tokens=truncate_input_tokens
+        )
     with pytest.raises(ValueError):
-        model.run_embeddings(
+        loaded_model.run_embeddings(
             texts=[too_long], truncate_input_tokens=truncate_input_tokens
         )
 
     # sentence similarity(ies) test both source_sentence and sentences
     with pytest.raises(ValueError):
-        model.run_sentence_similarity(
+        loaded_model.run_sentence_similarity(
             source_sentence=too_long,
             sentences=[ok],
             truncate_input_tokens=truncate_input_tokens,
         )
     with pytest.raises(ValueError):
-        model.run_sentence_similarity(
+        loaded_model.run_sentence_similarity(
             source_sentence=ok,
             sentences=[too_long],
             truncate_input_tokens=truncate_input_tokens,
         )
 
     with pytest.raises(ValueError):
-        model.run_sentence_similarities(
+        loaded_model.run_sentence_similarities(
             source_sentences=[too_long],
             sentences=[ok],
             truncate_input_tokens=truncate_input_tokens,
         )
     with pytest.raises(ValueError):
-        model.run_sentence_similarities(
+        loaded_model.run_sentence_similarities(
             source_sentences=[ok],
             sentences=[too_long],
             truncate_input_tokens=truncate_input_tokens,
@@ -581,26 +594,26 @@ def test_too_many_tokens_error_params(truncate_input_tokens):
 
     # reranker test both query and document text
     with pytest.raises(ValueError):
-        model.run_rerank_query(
+        loaded_model.run_rerank_query(
             query=too_long,
             documents=[{"text": ok}],
             truncate_input_tokens=truncate_input_tokens,
         )
     with pytest.raises(ValueError):
-        model.run_rerank_query(
+        loaded_model.run_rerank_query(
             query=ok,
             documents=[{"text": too_long}],
             truncate_input_tokens=truncate_input_tokens,
         )
 
     with pytest.raises(ValueError):
-        model.run_rerank_queries(
+        loaded_model.run_rerank_queries(
             queries=[too_long],
             documents=[{"text": ok}],
             truncate_input_tokens=truncate_input_tokens,
         )
     with pytest.raises(ValueError):
-        model.run_rerank_queries(
+        loaded_model.run_rerank_queries(
             queries=[ok],
             documents=[{"text": too_long}],
             truncate_input_tokens=truncate_input_tokens,
@@ -608,65 +621,136 @@
 
 
 @pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 512])
-def test_too_many_tokens_with_truncation_working(truncate_input_tokens):
+def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
     """truncate_input_tokens prevents these endpoints from raising an error when too many tokens.
 
     Test with -1 which lets the model do truncation instead of raising an error.
     Test with 99 (< 512) which causes our code to do the truncation instead of raising an error.
     """
-    model = BOOTSTRAPPED_MODEL
-    model_max = model.model.max_seq_length
+    model_max = loaded_model.model.max_seq_length
 
     ok = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
     too_long = "x " * (model_max - 1)  # This will go over
 
     # embedding(s)
-    model.run_embedding(text=too_long, truncate_input_tokens=truncate_input_tokens)
-    model.run_embeddings(texts=[too_long], truncate_input_tokens=truncate_input_tokens)
+    loaded_model.run_embedding(
+        text=too_long, truncate_input_tokens=truncate_input_tokens
+    )
+    loaded_model.run_embeddings(
+        texts=[too_long], truncate_input_tokens=truncate_input_tokens
+    )
 
     # sentence similarity(ies) test both source_sentence and sentences
-    model.run_sentence_similarity(
+    loaded_model.run_sentence_similarity(
         source_sentence=too_long,
         sentences=[ok],
         truncate_input_tokens=truncate_input_tokens,
     )
-    model.run_sentence_similarity(
+    loaded_model.run_sentence_similarity(
         source_sentence=ok,
         sentences=[too_long],
         truncate_input_tokens=truncate_input_tokens,
     )
 
-    model.run_sentence_similarities(
+    loaded_model.run_sentence_similarities(
         source_sentences=[too_long],
         sentences=[ok],
         truncate_input_tokens=truncate_input_tokens,
     )
-    model.run_sentence_similarities(
+    loaded_model.run_sentence_similarities(
         source_sentences=[ok],
         sentences=[too_long],
         truncate_input_tokens=truncate_input_tokens,
     )
 
     # reranker test both query and document text
-    model.run_rerank_query(
+    loaded_model.run_rerank_query(
         query=too_long,
         documents=[{"text": ok}],
         truncate_input_tokens=truncate_input_tokens,
     )
-    model.run_rerank_query(
+    loaded_model.run_rerank_query(
         query=ok,
         documents=[{"text": too_long}],
         truncate_input_tokens=truncate_input_tokens,
     )
 
-    model.run_rerank_queries(
+    loaded_model.run_rerank_queries(
         queries=[too_long],
         documents=[{"text": ok}],
         truncate_input_tokens=truncate_input_tokens,
     )
-    model.run_rerank_queries(
+    loaded_model.run_rerank_queries(
         queries=[ok],
         documents=[{"text": too_long}],
         truncate_input_tokens=truncate_input_tokens,
     )
+
+
+@pytest.mark.parametrize("truncate_input_tokens", [99, 512, -1])
+def test_embeddings_with_truncation(truncate_input_tokens, loaded_model):
+    """verify that results are as expected with truncation"""
+
+    if truncate_input_tokens is None or truncate_input_tokens < 0:
+        # For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length
+        repeat = loaded_model.model.max_seq_length
+    else:
+        repeat = truncate_input_tokens
+
+    # Build a text like "x x x.. x " with room for one more token
x " with room for one more token + repeat = repeat - 2 # space for start/end tokens + repeat = repeat - 1 # space for the final x or y token to show difference + + base = "x " * repeat # A bunch of "x" tokens + x = base + "x" # One last "x" that will not get truncated + y = base + "y" # A different last character "y" not truncated + z = y + "z" # Add token "z" after "y". This should get truncated. + + res = loaded_model.run_embeddings( + texts=[base, x, y, z], truncate_input_tokens=truncate_input_tokens + ) + + vectors = res.results.vectors + + # x...xyz is the same as x...xy because that is exactly where truncation worked + assert np.allclose(vectors[2].data.values, vectors[3].data.values) + + # Make sure the base, x, y are not a match (we kept the significant last char) + assert not np.allclose(vectors[0].data.values, vectors[1].data.values) + assert not np.allclose(vectors[0].data.values, vectors[2].data.values) + assert not np.allclose(vectors[1].data.values, vectors[2].data.values) + + +def test__with_retry_happy_path(loaded_model): + """works with args/kwargs, no problems""" + loaded_model._with_retry(print, "hello", "world", sep="<:)>", end="!!!\n") + + +def test__with_retry_fail(loaded_model): + """fn never works, loops then raises RuntimeError""" + + def fn(): + assert 0 + + with pytest.raises(RuntimeError): + loaded_model._with_retry(fn) + + +def test__with_retry_fail_fail_win(loaded_model): + """fn needs a few tries, logs, loops and succeeds""" + + def generate_ints(): + yield from range(9) # More than enough for retry loop + + ints = generate_ints() + + def fail_fail_win(): + for i in ints: + if i < 2: # fail, fail + assert 0 + else: # win + return i + 1 + + # Third try did not raise an exception. Returns 3. + assert 3 == loaded_model._with_retry(fail_fail_win)