From 5b0989f2d84e0e07cfd8c570868f985f0a5d543a Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Wed, 4 Sep 2024 15:35:16 -0700
Subject: [PATCH 1/7] CrossEncoderModule with rerank API

This module is closely related to EmbeddingModule.

Cross-encoder models use Q and A pairs and are trained to return a
relevance score for rank().

The existing rerank APIs in EmbeddingModule had to encode Q and A
separately and use cosine similarity as a score. So the API is the same,
but the results are expected to be better (and slower).

Cross-encoder models do not support returning embedding vectors or
sentence-similarity.

Support for the existing tokenization and model_info endpoints was also
added.

Signed-off-by: Mark Sturdevant
---
 caikit_nlp/modules/text_embedding/__init__.py |   1 +
 .../modules/text_embedding/crossencoder.py    | 807 ++++++++++++++++++
 .../text_embedding/test_crossencoder.py       | 443 ++++++++++
 3 files changed, 1251 insertions(+)
 create mode 100644 caikit_nlp/modules/text_embedding/crossencoder.py
 create mode 100644 tests/modules/text_embedding/test_crossencoder.py

diff --git a/caikit_nlp/modules/text_embedding/__init__.py b/caikit_nlp/modules/text_embedding/__init__.py
index 2451f4a2..58e88b33 100644
--- a/caikit_nlp/modules/text_embedding/__init__.py
+++ b/caikit_nlp/modules/text_embedding/__init__.py
@@ -29,4 +29,5 @@
 """
 
 # Local
+from .crossencoder import CrossEncoderModule
 from .embedding import EmbeddingModule
diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py
new file mode 100644
index 00000000..220897da
--- /dev/null
+++ b/caikit_nlp/modules/text_embedding/crossencoder.py
@@ -0,0 +1,807 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Standard +from copy import deepcopy +from functools import partial +from typing import Any, Dict, List, NamedTuple, Optional, TypeVar, Union +import importlib +import os +import threading + +# Third Party +from sentence_transformers import CrossEncoder +from torch.backends import mps +from torch.utils.data import DataLoader +import numpy as np +import torch + +# First Party +from caikit import get_config +from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.data_model.json_dict import JsonDict +from caikit.core.exceptions import error_handler +from caikit.interfaces.nlp.data_model import ( + RerankResult, + RerankResults, + RerankScore, + RerankScores, + Token, + TokenizationResults, +) +from caikit.interfaces.nlp.tasks import RerankTask, RerankTasks, TokenizationTask +import alog + +# Local +from caikit_nlp.modules.text_embedding.utils import env_val_to_bool + +logger = alog.use_channel("TXT_EMB") +error = error_handler.get(logger) + + +RT = TypeVar("RT") # return type + + +class RerankResultTuple(NamedTuple): + """Output of modified predict()""" + + scores: list + input_token_count: int + + +class PredictResultTuple(NamedTuple): + """Output of modified predict()""" + + scores: np.ndarray + input_token_count: int + + +# pylint: disable=too-many-lines disable=duplicate-code +@module( + "1673f8f2-726f-48cb-93a1-540c81f0f3c9", + "CrossEncoderModule", + "0.0.1", + tasks=[ + RerankTask, + RerankTasks, + TokenizationTask, + ], +) +class CrossEncoderModule(ModuleBase): + + _ARTIFACTS_PATH_KEY = "artifacts_path" + _ARTIFACTS_PATH_DEFAULT = "artifacts" + + def __init__( + self, + model: "CrossEncoderWithTruncate", + ): + super().__init__() + self.model = model + + # model_max_length attribute availability might(?) vary by model/tokenizer + self.model_max_length = getattr(model.tokenizer, "model_max_length", None) + + @classmethod + def load( + cls, model_path: Union[str, ModuleConfig], *args, **kwargs + ) -> "CrossEncoderModule": + """Load model + + Args: + model_path (Union[str, ModuleConfig]): Path to saved model or + in-memory ModuleConfig + + Returns: + CrossEncoderModule + Instance of this class built from the model. + """ + + config = ModuleConfig.load(model_path) + error.dir_check("", config.model_path) + + artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) + error.value_check( + "", + artifacts_path, + ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), + ) + + artifacts_path = os.path.abspath( + os.path.join(config.model_path, artifacts_path) + ) + error.dir_check("", artifacts_path) + + # Read config/env settings that are needed at load time. + embedding_cfg = get_config().get("embedding", {}) + + trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code")) + device = cls._select_device(False, embedding_cfg.get("device", "")) + + model = CrossEncoderWithTruncate( + model_name=artifacts_path, + device=device, + trust_remote_code=trust_remote_code, + ) + + return cls(model) + + @property + def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument + """Helper property to return public metadata about a specific Model. This + function is separate from `metadata` as that contains the entire ModelConfig + which might not want to be shared/exposed. 
+ + Returns: + Dict[str, str]: A dictionary of this model's public metadata + """ + + return ( + {"max_seq_length": cls.model_max_length} + if cls.model_max_length is not None + else {} + ) + + @TokenizationTask.taskmethod() + def run_tokenizer( + self, + text: str, + ) -> TokenizationResults: + """Run tokenization task against the model + + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + result = self.model.get_tokenized([text], return_offsets_mapping=True) + + mapping = [ + interv for interv in result.offset_mapping[0] if (interv[1] - interv[0]) > 0 + ] + tokens = [Token(start=i[0], end=i[1], text=text[i[0] : i[1]]) for i in mapping] + + return TokenizationResults(token_count=len(result.input_ids[0]), results=tokens) + + @classmethod + def _get_ipex(cls, ipex_flag): + """Get IPEX optimization library if enabled and available, else return False + + Returns ipex library or False + """ + ret = False + + # Enabled by environment variable + # When IPEX is not false, attempt to import the library and use it. + if ipex_flag: + try: + ret = importlib.import_module("intel_extension_for_pytorch") + except Exception as ie: # pylint: disable=broad-exception-caught + # We don't require the module so catch, log, proceed to return False + msg = ( + f"IPEX enabled in env, but skipping ipex.optimize() because " + f"import intel_extension_for_pytorch failed with exception: {ie}" + ) + logger.warning(msg, exc_info=True) + + return ret + + @staticmethod + def _select_device(use_ipex, device): + """Use environment variables and availability to determine the device to use""" + if use_ipex: + # If enabled, use "xpu" (IPEX on GPU instead of IPEX on CPU) + if device == "xpu": + return "xpu" + elif device == "mps" and mps.is_built() and mps.is_available(): + # Never use on ipex, but otherwise use mps if enabled and available + return "mps" + + return "cuda" if torch.cuda.is_available() else None + + @staticmethod + def _get_backend(use_ipex, use_device): + """Determine the backend to use for torch compile. + + Considers global ipex if enabled first, next mps device, finally defaults. + + Returns the backend for torch.compile() + """ + if use_ipex: + return "ipex" + if use_device == "mps": + return mps + return "inductor" # default backend + + @staticmethod + def _optimize(model, ipex, device, autocast, pt2_compile): + if ipex: + if autocast: # IPEX performs best with autocast using bfloat16 + model = ipex.optimize( + model, dtype=torch.bfloat16, weights_prepack=False + ) + else: + model = ipex.optimize(model, weights_prepack=False) + + # torch.compile won't work everywhere, but when set we'll try it + if pt2_compile: + backend = CrossEncoderModule._get_backend(ipex, device) + try: + model = torch.compile(model, backend=backend, mode="max-autotune") + except Exception as e: # pylint: disable=broad-exception-caught + # Not always supported (e.g. in a python version) so catch, log, proceed. + warn_msg = ( + f"PT2_COMPILE enabled, but continuing without torch.compile() " + f"because it failed with exception: {e}" + ) + logger.warning(warn_msg, exc_info=True) + return model + + @RerankTask.taskmethod() + def run_rerank_query( + self, + query: str, + documents: List[JsonDict], + top_n: Optional[int] = None, + truncate_input_tokens: Optional[int] = 0, + return_documents: bool = True, + return_query: bool = True, + return_text: bool = True, + ) -> RerankResult: + """Rerank the documents returning the most relevant top_n in order for this query. 
+ Args: + query: str + Query is the source string to be compared to the text of the documents. + documents: List[JsonDict] + Each document is a dict. The text value is used for comparison to the query. + If there is no text key, then _text is used and finally default is "". + top_n: Optional[int] + Results for the top n most relevant documents will be returned. + If top_n is not provided or (not > 0), then all are returned. + truncate_input_tokens: int + Truncation length for input tokens. + If less than zero, this is disabled (returns texts without processing). + If zero or greater than the model's maximum, then this is a test + to see if truncation is needed. If needed, an exception is thrown. + Otherwise, we take this usable truncation limit to truncate the tokens and then + decode them to return truncated strings that can be used with this model. + return_documents: bool + Default True + Setting to False will disable returning of the input document (index is returned). + return_query: bool + Default True + Setting to False will disable returning of the query (results are in query order) + return_text: bool + Default True + Setting to False will disable returning of document text string that was used. + Returns: + RerankResult + Returns the (top_n) scores in relevance order (most relevant first). + The results always include a score and index which may be used to find the document + in the original documents list. Optionally, the results also contain the entire + document with its score (for use in chaining) and for convenience the query and + text used for comparison may be returned. + + """ + + error.type_check( + "", + int, + allow_none=True, + top_n=top_n, + ) + + error.type_check( + "", + str, + query=query, + ) + + results = self.run_rerank_queries( + queries=[query], + documents=documents, + top_n=top_n, + truncate_input_tokens=truncate_input_tokens, + return_documents=return_documents, + return_queries=return_query, + return_text=return_text, + ) + + if results.results: + return RerankResult( + result=results.results[0], + producer_id=self.PRODUCER_ID, + input_token_count=results.input_token_count, + ) + + RerankResult( + result=RerankScore( + scores=[], + query=query if return_query else None, + ), + producer_id=self.PRODUCER_ID, + input_token_count=results.input_token_count, + ) + + @RerankTasks.taskmethod() + def run_rerank_queries( + self, + queries: List[str], + documents: List[JsonDict], + top_n: Optional[int] = None, + truncate_input_tokens: Optional[int] = 0, + return_documents: bool = True, + return_queries: bool = True, + return_text: bool = True, + ) -> RerankResults: + """Rerank the documents returning the most relevant top_n in order for each of the queries. + Args: + queries: List[str] + Each of the queries will be compared to the text of each of the documents. + documents: List[JsonDict] + Each document is a dict. The text value is used for comparison to the query. + If there is no text key, then _text is used and finally default is "". + top_n: Optional[int] + Results for the top n most relevant documents will be returned. + If top_n is not provided or (not > 0), then all are returned. + truncate_input_tokens: int + Truncation length for input tokens. + If less than zero, this is disabled (returns texts without processing). + If zero or greater than the model's maximum, then this is a test + to see if truncation is needed. If needed, an exception is thrown. 
+ Otherwise, we take this usable truncation limit to truncate the tokens and then + decode them to return truncated strings that can be used with this model. + return_documents: bool + Default True + Setting to False will disable returning of the input document (index is returned). + return_queries: bool + Default True + Setting to False will disable returning of the query (results are in query order) + return_text: bool + Default True + Setting to False will disable returning of document text string that was used. + Returns: + RerankResults + For each query in queries (in the original order)... + Returns the (top_n) scores in relevance order (most relevant first). + The results always include a score and index which may be used to find the document + in the original documents list. Optionally, the results also contain the entire + document with its score (for use in chaining) and for convenience the query and + text used for comparison may be returned. + """ + + error.type_check( + "", + list, + queries=queries, + documents=documents, + ) + + error.value_check( + "", + queries and documents, + "Cannot rerank without a query and at least one document", + ) + + if top_n is None or top_n < 1: + top_n = len(documents) + + # Using input document dicts so get "text" else "_text" else default to "" + def get_text(doc): + return doc.get("text") or doc.get("_text", "") + + doc_texts = [get_text(doc) for doc in documents] + + input_token_count = 0 + results = [] + for query in queries: + scores, token_count = self.model.rank( + query=query, + documents=doc_texts, + top_k=top_n, + return_documents=False, + batch_size=32, + truncate_input_tokens=truncate_input_tokens, + ) + results.append(scores) + input_token_count += token_count + + # Fixup result dicts + for r in results: + for x in r: + x["score"] = float(x["score"].item()) + # Renaming corpus_id to index + corpus_id = x.pop("corpus_id") + x["index"] = corpus_id + # Optionally adding the original document and/or just the text that was used + if return_documents: + x["document"] = documents[corpus_id] + if return_text: + x["text"] = get_text(documents[corpus_id]) + + def add_query(q): + return queries[q] if return_queries else None + + results = [ + RerankScores( + query=add_query(q), + scores=[RerankScore(**x) for x in r], + ) + for q, r in enumerate(results) + ] + + return RerankResults( + results=results, + producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, + ) + + @classmethod + def bootstrap(cls, *args, **kwargs) -> "CrossEncoderModule": + """Bootstrap a cross-encoder model + + Args: + args/kwargs are passed to CrossEncoder + """ + + # Add ability to bootstrap with trust_remote_code using env var. + if "trust_remote_code" not in kwargs: + # Read config/env settings that are needed at bootstrap time. 
+ embedding_cfg = get_config().get("embedding", {}) + kwargs["trust_remote_code"] = env_val_to_bool( + embedding_cfg.get("trust_remote_code") + ) + + return cls(model=CrossEncoder(*args, **kwargs)) + + def save(self, model_path: str, *args, **kwargs): + """Save model using config in model_path + + Args: + model_path: str + Path to model config + """ + + error.type_check("", str, model_path=model_path) + model_config_path = model_path.strip() + error.value_check( + "", + model_config_path, + f"model_path '{model_config_path}' is invalid", + ) + + model_config_path = os.path.abspath( + model_config_path.strip() + ) # No leading/trailing spaces sneaky weirdness + + # Only allow new dirs because there are not enough controls to safely update in-place + os.makedirs(model_config_path, exist_ok=False) + + saver = ModuleSaver( + module=self, + model_path=model_config_path, + ) + artifacts_path = self._ARTIFACTS_PATH_DEFAULT + saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path}) + + # Save the model + self.model.save(os.path.join(model_config_path, artifacts_path)) + + # Save the config + ModuleConfig(saver.config).save(model_config_path) + + +class CrossEncoderWithTruncate(CrossEncoder): + def __init__( + self, + model_name: str, + num_labels: int = None, + max_length: int = None, + device: str = None, + tokenizer_args: Dict = None, + automodel_args: Dict = None, + trust_remote_code: bool = False, + revision: Optional[str] = None, + local_files_only: bool = False, + default_activation_function=None, + classifier_dropout: float = None, + ): + super().__init__( + model_name, + num_labels, + max_length, + device, + tokenizer_args, + automodel_args, + trust_remote_code, + revision, + local_files_only, + default_activation_function, + classifier_dropout, + ) + self.tokenizers = {} + + def _get_tokenizer_per_thread(self): + """Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID).""" + + # Keep copies of tokenizer per thread (in each wrapped model instance) + thread_id = threading.get_ident() + tokenizer = ( + self.tokenizers[thread_id] + if thread_id in self.tokenizers + else self.tokenizers.setdefault(thread_id, deepcopy(self.tokenizer)) + ) + + return tokenizer + + def get_tokenized(self, texts, **kwargs): + """Intentionally always call tokenizer the same way to avoid thread issues. + + Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID). + + Avoid changing the max length, truncation, and padding to avoid the + "Already borrowed" errors that come with concurrent threads attempting to use + the fast tokenizer with different truncation settings. 
+ """ + + max_len = kwargs.get("truncate_input_tokens", self.tokenizer.model_max_length) + max_len = min(max_len, self.tokenizer.model_max_length) + if max_len <= 0: + max_len = None + + tokenizer = self._get_tokenizer_per_thread() + tokenized = tokenizer( + *texts, + return_attention_mask=True, # Used for determining token count + return_token_type_ids=False, # Needed for cross-encoders + return_overflowing_tokens=False, # DO NOT USE overflow tokens break sentence batches + return_offsets_mapping=True, # Used for truncation test + return_length=False, + return_tensors="pt", + truncation=True, + padding=True, + max_length=max_len, + ) + return tokenized + + def _truncation_needed(self, tokenized, max_length, texts): + """Check for truncation needed to meet max_length token limit + Returns: + List of indexes of the texts that need truncating ([] if none) + """ + + ret = [] # List of indexes for texts that need truncation + + if max_length is None: + max_length = self.tokenizer.model_max_length + + for i, encoding in enumerate(tokenized.encodings): + input_tokens = sum(encoding.attention_mask) + if input_tokens >= self.tokenizer.model_max_length: + # At model limit, including start/end... + # This may or may not have already been truncated at the model limit. + # Check the strlen and last offset. + # We need to know this, for "not okay_to_truncate" errors. + offsets = encoding.offsets + type_ids = encoding.type_ids + attn_mask = encoding.attention_mask + # Find the last offset by counting attn masks + # and keeping the last non-zero offset end. + token_count = 0 + index = 0 # index of longest + type_id = 0 # track type_id of longest + for n, attn in enumerate(attn_mask): + if attn == 1: + token_count += 1 + end = offsets[n][1] # Index to end character from offset + if ( + end > index + ): # Grab last non-zero end index (ensures increasing too) + type_id = type_ids[n] + index = end + if token_count >= max_length - 1: # Stop with room for an end token + break + end_index = index # longest + end_typeid = type_id # longest + + # Get position in (queries * docs) for this query or doc + if end_typeid == 0: # query + text_pos = i // len(texts[0]) + else: # doc + text_pos = i % len(texts[0]) + + if end_index < len(texts[end_typeid][text_pos]): + ret.append(i) + + return ret + + def smart_batching_collate_text_only( + self, batch, truncate_input_tokens: Optional[int] = 0 + ): + texts = [[] for _ in range(len(batch[0]))] + + for example in batch: + for idx, text in enumerate(example): + texts[idx].append(text.strip()) + + tokenized = self.get_tokenized( + texts, truncate_input_tokens=truncate_input_tokens + ) + + max_len = self.tokenizer.model_max_length + + if truncate_input_tokens == 0 or truncate_input_tokens > max_len: + # default (for zero or over max) is to error on truncation + truncated = self._truncation_needed(tokenized, max_len, texts) + + if truncated: + indexes = f"{', '.join(str(i) for i in truncated)}." + index_hint = ( + " for text at " + f"{'index' if len(truncated) == 1 else 'indexes'}: {indexes}" + ) + + error.log_raise( + "", + ValueError( + f"Token sequence length (+3 for separators) exceeds the " + f"maximum sequence length for this model ({max_len})" + f"{index_hint}" + ), + ) + + # We cannot send offset_mapping to the model with features, + # but we needed offset_mapping for other uses. 
+ if "offset_mapping" in tokenized: + del tokenized["offset_mapping"] + + for name in tokenized: + tokenized[name] = tokenized[name].to(self._target_device) + + return tokenized + + def predict( + self, + sentences: List[List[str]], + batch_size: int = 32, + show_progress_bar: bool = None, + num_workers: int = 0, + activation_fct=None, + apply_softmax=False, + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + truncate_input_tokens: Optional[int] = 0, + ) -> Union[PredictResultTuple, List[float], np.ndarray, torch.Tensor]: + """ + Performs predictions with the CrossEncoder on the given sentence pairs. + + Args: + See overriden method for details. + truncate_input_tokens: Optional[int] = 0 added for truncation + + Returns: + Uses PredictResultTuple to add input_token_count + Union[List[float], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs. + The return type depends on the `convert_to_numpy` and `convert_to_tensor` parameters. + If `convert_to_tensor` is True, the output will be a torch.Tensor. + If `convert_to_numpy` is True, the output will be a numpy.ndarray. + Otherwise, the output will be a list of float values. + """ + input_was_string = False + if isinstance( + sentences[0], str + ): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True + + collate_fn = partial( + self.smart_batching_collate_text_only, + truncate_input_tokens=truncate_input_tokens, + ) + inp_dataloader = DataLoader( + sentences, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_workers, + shuffle=False, + ) + + iterator = inp_dataloader + + if activation_fct is None: + activation_fct = self.default_activation_function + + pred_scores = [] + input_token_count = 0 + self.model.eval() + self.model.to(self._target_device) + with torch.no_grad(): + for features in iterator: + + model_predictions = self.model(**features, return_dict=True) + logits = activation_fct(model_predictions.logits) + + if apply_softmax and len(logits[0]) > 1: + logits = torch.nn.functional.softmax(logits, dim=1) + pred_scores.extend(logits) + + # Sum the length of all encodings for all samples + for encoding in features.encodings: + # for mask in encoding.attention_mask: + input_token_count += sum(encoding.attention_mask) + + if self.config.num_labels == 1: + pred_scores = [score[0] for score in pred_scores] + + if convert_to_tensor: + pred_scores = torch.stack(pred_scores) + elif convert_to_numpy: + pred_scores = np.asarray( + [score.cpu().detach().numpy() for score in pred_scores] + ) + + if input_was_string: + pred_scores = pred_scores[0] + + return PredictResultTuple(pred_scores, input_token_count) + + def rank( + self, + query: str, + documents: List[str], + top_k: Optional[int] = None, + return_documents: bool = False, + batch_size: int = 32, + show_progress_bar: bool = None, + num_workers: int = 0, + activation_fct=None, + apply_softmax=False, + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + truncate_input_tokens: Optional[int] = 0, + ) -> Union[RerankResultTuple, List[Dict]]: + """ + Performs ranking with the CrossEncoder on the given query and documents. + + Returns a sorted list with the document indices and scores. + + Args: + See overridden method for argument description. + truncate_input_tokens (int, optional): Added to support truncation. 
+ Returns: + RerankResultTuple: Adds input_token_count to result + """ + query_doc_pairs = [[query, doc] for doc in documents] + scores, input_token_count = self.predict( + query_doc_pairs, + batch_size=batch_size, + show_progress_bar=show_progress_bar, + num_workers=num_workers, + activation_fct=activation_fct, + apply_softmax=apply_softmax, + convert_to_numpy=convert_to_numpy, + convert_to_tensor=convert_to_tensor, + truncate_input_tokens=truncate_input_tokens, + ) + results = [] + for i, score in enumerate(scores): + if return_documents: + results.append({"corpus_id": i, "score": score, "text": documents[i]}) + else: + results.append({"corpus_id": i, "score": score}) + + results = sorted(results, key=lambda x: x["score"], reverse=True) + return RerankResultTuple(results[:top_k], input_token_count) diff --git a/tests/modules/text_embedding/test_crossencoder.py b/tests/modules/text_embedding/test_crossencoder.py new file mode 100644 index 00000000..7000171c --- /dev/null +++ b/tests/modules/text_embedding/test_crossencoder.py @@ -0,0 +1,443 @@ +"""Tests for CrossEncoderModule""" + +# Standard +from typing import List +import os +import tempfile + +# Third Party +from pytest import approx +import pytest + +# First Party +from caikit.interfaces.nlp.data_model import ( + RerankResult, + RerankResults, + RerankScore, + RerankScores, + Token, + TokenizationResults, +) + +# Local +from caikit_nlp.modules.text_embedding import CrossEncoderModule +from tests.fixtures import SEQ_CLASS_MODEL + +## Setup ######################################################################## + +# Bootstrapped sequence classification model for reuse across tests +# .bootstrap is tested separately in the first test +# This model needs a tweak (num_labels = 1) to behave like a cross-encoder. +BOOTSTRAPPED_MODEL = CrossEncoderModule.bootstrap(SEQ_CLASS_MODEL) + +# Token counts: +# All expected token counts were calculated with reference to the +# `BertForSequenceClassification` model. Each model's tokenizer behaves differently +# which can lead to the expected token counts being invalid. + +INPUT = "The quick brown fox jumps over the lazy dog." +INPUT_TOKEN_COUNT = 36 + 2 # [CLS] Thequickbrownfoxjumpsoverthelazydog. [SEP] + +QUERY = "What is foo bar?" +QUERY_TOKEN_COUNT = 13 + 2 # [CLS] Whatisfoobar? [SEP] + +QUERIES: List[str] = [ + "Who is foo?", + "Where is the bar?", +] +QUERIES_TOKEN_COUNT = (9 + 2) + ( + 14 + 2 +) # [CLS] Whoisfoo? [SEP], [CLS] Whereisthebar? [SEP] + +DOCS = [ + { + "text": "foo", + "title": "title or whatever", + "str_test": "test string", + "int_test": 1, + "float_test": 1.234, + "score": 99999, + "nested_dict_test": {"deep1": 1, "deep string": "just testing"}, + }, + { + "_text": "bar", + "title": "title 2", + }, + { + "text": "foo and bar", + }, + { + "_text": "Where is the bar", + "another": "something else", + }, +] + +# The `text` and `_text` keys are extracted from DOCS as input to the tokenizer +# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP] +DOCS_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2) + +# [CLS] query [SEP] text [SEP] for each text in DOCS. +# Subtract one from QUERY_TOKEN_COUNT to avoid counting +# an extra [SEP]. +QUERY_DOCS_TOKENS = (QUERY_TOKEN_COUNT - 1) * len(DOCS) + DOCS_TOKEN_COUNT + +# [CLS] query [SEP] text [SEP] for each QUERY for each text in DOCS. +# Subtract len(QUERIES) from QUERY_TOKEN_COUNT to avoid counting +# an extra [SEP]. 
+QUERIES_DOCS_TOKENS = (QUERIES_TOKEN_COUNT - len(QUERIES)) * len(DOCS) + ( + DOCS_TOKEN_COUNT * len(QUERIES) +) + + +## Tests ######################################################################## + + +@pytest.fixture(scope="module", name="loaded_model") +def fixture_loaded_model(tmp_path_factory): + models_dir = tmp_path_factory.mktemp("models") + model_path = str(models_dir / "model_id") + BOOTSTRAPPED_MODEL.save(model_path) + model = CrossEncoderModule.load(model_path) + # Make our tiny test model act more like a cross-encoder model with 1 label + model.model.config.num_labels = 1 + return model + + +def _assert_is_expected_scores(rerank_scores): + # Just testing a few values for readability + assert isinstance(rerank_scores, RerankScores) + scores = rerank_scores.scores + assert approx(scores[0].score) == -0.015608355402946472 + assert approx(scores[1].score) == -0.015612606890499592 + assert approx(scores[2].score) == -0.015648163855075836 + + +def _assert_is_expected_rerank_result(actual): + assert isinstance(actual, RerankResult) + scores = actual.result + _assert_is_expected_scores(scores) + + +def _assert_is_expected_rerank_results(actual): + assert isinstance(actual, RerankResults) + + +def test_bootstrap(): + assert isinstance( + CrossEncoderModule.bootstrap(SEQ_CLASS_MODEL), CrossEncoderModule + ), "bootstrap error" + + +def _assert_valid_scores(scores): + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.index, int) + assert isinstance(score.text, str) + + document = score.document + assert isinstance(document, dict) + assert document == DOCS[score.index] + + # Test document key named score (None or 9999) is independent of the result score + assert score.score != document.get( + "score" + ), "unexpected passthru score same as result score" + + +def test_bootstrap_model(loaded_model): + assert isinstance(BOOTSTRAPPED_MODEL, CrossEncoderModule), "bootstrap model type" + assert ( + BOOTSTRAPPED_MODEL.model.__class__.__name__ == "CrossEncoder" + ), "bootstrap model class name" + # worth noting that bootstrap does not wrap, but load does + assert ( + loaded_model.model.__class__.__name__ == "CrossEncoderWithTruncate" + ), "loaded model class name" + + +def test_save_load_and_run(): + """Check if we can load and run a saved model successfully""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-xe-1st") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = CrossEncoderModule.load(model_path) + + assert isinstance(new_model, CrossEncoderModule), "save and load error" + assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" + + # Make our tiny test model act more like a cross-encoder model + new_model.model.config.num_labels = 1 + + # Use run_rerank_query just to make sure this new model is usable + top_n = 3 + rerank_result = new_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) + + assert isinstance(rerank_result, RerankResult) + + result = rerank_result.result + assert isinstance(result, RerankScores) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + + _assert_valid_scores(scores) + + assert rerank_result.input_token_count == QUERY_DOCS_TOKENS + _assert_is_expected_rerank_result(rerank_result) + rerank_results = new_model.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=1 + ) + _assert_is_expected_rerank_results(rerank_results) + + +def 
test_public_model_info(): + """Check if we can get model info successfully""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-xe-mi") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = CrossEncoderModule.load(model_path) + + result = new_model.public_model_info + assert "max_seq_length" in result + assert type(result["max_seq_length"]) is int + assert new_model.model.tokenizer.model_max_length == 512 + assert result["max_seq_length"] == new_model.model.tokenizer.model_max_length + + # We only have the following key(s) in model_info right now for cross-encoders... + assert list(result.keys()) == ["max_seq_length"] + + +def test_run_tokenization(loaded_model): + res = loaded_model.run_tokenizer(text=INPUT) + assert isinstance(res, TokenizationResults) + assert isinstance(res.results, list) + assert isinstance(res.results[0], Token) + assert res.token_count == INPUT_TOKEN_COUNT + + +@pytest.mark.parametrize( + "query,docs,top_n", + [ + (["test list"], DOCS, None), + (None, DOCS, 1234), + (False, DOCS, 1234), + (QUERY, {"testdict": "not list"}, 1234), + (QUERY, DOCS, "topN string is not an integer or None"), + ], +) +def test_run_rerank_query_type_error(query, docs, top_n, loaded_model): + """test for type checks matching task/run signature""" + match = r"type check failed" + with pytest.raises(TypeError, match=match): + loaded_model.run_rerank_query(query=query, documents=docs, top_n=top_n) + pytest.fail("Should not reach here.") + + +@pytest.mark.parametrize("top_n", [1, 99, None]) +def test_run_rerank_query_no_type_error(loaded_model, top_n): + """no type error with list of string queries and list of dict documents""" + res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) + + # [CLS] query [SEP] text [SEP] for each text in DOCS. + # Subtract one from QUERY_TOKEN_COUNT to avoid counting + # an extra [SEP]. 
+ q_tokens = (QUERY_TOKEN_COUNT - 1) * len(DOCS) + expected = q_tokens + DOCS_TOKEN_COUNT + assert res.input_token_count == expected + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_rerank_query_top_n(top_n, expected, loaded_model): + res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankResult) + assert len(res.result.scores) == expected + assert res.input_token_count == QUERY_DOCS_TOKENS + + +def test_run_rerank_query_no_query(loaded_model): + with pytest.raises(TypeError): + loaded_model.run_rerank_query(query=None, documents=DOCS, top_n=99) + + +def test_run_rerank_query_zero_docs(loaded_model): + """No empty doc list therefore result is zero result scores""" + with pytest.raises(ValueError): + loaded_model.run_rerank_query(query=QUERY, documents=[], top_n=99) + + +def test_run_rerank_query(loaded_model): + res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS) + assert isinstance(res, RerankResult) + + scores = res.result.scores + assert isinstance(scores, list) + assert len(scores) == len(DOCS) + + _assert_valid_scores(scores) + assert res.input_token_count == QUERY_DOCS_TOKENS + + +@pytest.mark.parametrize( + "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] +) +def test_run_rerank_queries_type_error(queries, docs, loaded_model): + """type error check ensures params are lists and not just 1 string or just one doc (for example)""" + with pytest.raises(TypeError): + loaded_model.run_rerank_queries(queries=queries, documents=docs) + pytest.fail("Should not reach here.") + + +def test_run_rerank_queries_no_type_error(loaded_model): + """no type error with list of string queries and list of dict documents""" + res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + + assert res.input_token_count == QUERIES_DOCS_TOKENS + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_rerank_queries_top_n(top_n, expected, loaded_model): + """no type error with list of string queries and list of dict documents""" + res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankResults) + assert len(res.results) == len(QUERIES) + for result in res.results: + assert len(result.scores) == expected + assert res.input_token_count == QUERIES_DOCS_TOKENS + + +@pytest.mark.parametrize( + "queries, docs", + [ + ([], DOCS), + (QUERIES, []), + ([], []), + ], + ids=["no queries", "no docs", "no queries and no docs"], +) +def test_run_rerank_queries_no_queries_or_no_docs(queries, docs, loaded_model): + """No queries and/or no docs therefore result is zero results""" + + with pytest.raises(ValueError): + loaded_model.run_rerank_queries(queries=queries, documents=docs, top_n=9) + + +def test_run_rerank_queries(loaded_model): + top_n = 2 + rerank_result = loaded_model.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=top_n + ) + assert isinstance(rerank_result, RerankResults) + + results = rerank_result.results + assert isinstance(results, list) + assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) + + for result in results: + assert isinstance(result, RerankScores) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + 
_assert_valid_scores(scores) + + assert rerank_result.input_token_count == QUERIES_DOCS_TOKENS + + +@pytest.mark.parametrize("truncate_input_tokens", [-1, 512]) +def test_truncate_input_tokens_default(truncate_input_tokens, loaded_model): + """Test truncation using model max. + -1 means let the model truncate at its model max + 512 is more explicitly the same thing (this model's max) + """ + model_max = loaded_model.model.tokenizer.model_max_length + + too_long = "x " * (model_max - 3) # 3 for tokens (no room for a query token) + just_barely = "x " * (model_max - 4) # 3 for tokens plus room for a query token + queries = ["x"] + docs = [{"text": t} for t in ["x", too_long, just_barely, too_long, just_barely]] + + # Just testing for no errors raised for now + _res = loaded_model.run_rerank_queries( + queries=queries, documents=docs, truncate_input_tokens=truncate_input_tokens + ) + + +@pytest.mark.parametrize("truncate_input_tokens", [0, 513]) +def test_truncate_input_tokens_errors(truncate_input_tokens, loaded_model): + """Test that we get truncation errors. + 0 (the default) means we return errors when truncation would happen. + 513+ (any number above the max) is treated the same way. + """ + model_max = loaded_model.model.tokenizer.model_max_length + + too_long = "x " * (model_max - 3) # 3 for tokens (no room for a query token) + just_barely = "x " * (model_max - 4) # 3 for tokens plus room for a query token + queries = ["x"] + docs = [{"text": t} for t in ["x", too_long, just_barely, too_long, just_barely]] + + match1 = rf"exceeds the maximum sequence length for this model \({model_max}\) for text at indexes: 1, 3." + with pytest.raises(ValueError, match=match1): + loaded_model.run_rerank_queries( + queries=queries, documents=docs, truncate_input_tokens=truncate_input_tokens + ) + + +@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512]) +def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model): + """truncate_input_tokens prevents these endpoints from raising an error when too many tokens. + + Test with -1 which lets the model do truncation instead of raising an error. + Test with 99 (< 512 -2) which causes our code to do the truncation instead of raising an error. + Test with 510 (512 -2) which causes our code to do the truncation instead of raising an error. + 511 and 512 also behave like 510. The value is allowed, but begin/end tokens will take space. 
+ """ + + model_max = loaded_model.model.tokenizer.model_max_length + + ok = "x " * (model_max - 2) # Subtract 2 for begin/end tokens + too_long = "x " * (model_max - 1) # This will go over + + # reranker test both query and document text + loaded_model.run_rerank_query( + query=too_long, + documents=[{"text": ok}], + truncate_input_tokens=truncate_input_tokens, + ) + loaded_model.run_rerank_query( + query=ok, + documents=[{"text": too_long}], + truncate_input_tokens=truncate_input_tokens, + ) + + loaded_model.run_rerank_queries( + queries=[too_long], + documents=[{"text": ok}], + truncate_input_tokens=truncate_input_tokens, + ) + loaded_model.run_rerank_queries( + queries=[ok], + documents=[{"text": too_long}], + truncate_input_tokens=truncate_input_tokens, + ) From 7146ffe2874216a340727766ed03b166352dbba4 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 11 Sep 2024 22:26:55 -0700 Subject: [PATCH 2/7] Cross-encoder improvements from code review * mostly removing unnecessary code * some better clarity Signed-off-by: Mark Sturdevant --- .../modules/text_embedding/crossencoder.py | 123 ++---------------- 1 file changed, 13 insertions(+), 110 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py index 220897da..ac1f483b 100644 --- a/caikit_nlp/modules/text_embedding/crossencoder.py +++ b/caikit_nlp/modules/text_embedding/crossencoder.py @@ -15,14 +15,12 @@ # Standard from copy import deepcopy from functools import partial -from typing import Any, Dict, List, NamedTuple, Optional, TypeVar, Union -import importlib +from typing import Any, Dict, List, NamedTuple, Optional, Union import os import threading # Third Party from sentence_transformers import CrossEncoder -from torch.backends import mps from torch.utils.data import DataLoader import numpy as np import torch @@ -46,15 +44,12 @@ # Local from caikit_nlp.modules.text_embedding.utils import env_val_to_bool -logger = alog.use_channel("TXT_EMB") +logger = alog.use_channel("CROSS_ENCODER") error = error_handler.get(logger) -RT = TypeVar("RT") # return type - - class RerankResultTuple(NamedTuple): - """Output of modified predict()""" + """Output of modified rank()""" scores: list input_token_count: int @@ -115,7 +110,7 @@ def load( error.value_check( "", artifacts_path, - ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), + f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'", ) artifacts_path = os.path.abspath( @@ -127,13 +122,13 @@ def load( embedding_cfg = get_config().get("embedding", {}) trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code")) - device = cls._select_device(False, embedding_cfg.get("device", "")) model = CrossEncoderWithTruncate( model_name=artifacts_path, - device=device, trust_remote_code=trust_remote_code, ) + model.model.eval() + model.model.to(model._target_device) return cls(model) @@ -176,80 +171,6 @@ def run_tokenizer( return TokenizationResults(token_count=len(result.input_ids[0]), results=tokens) - @classmethod - def _get_ipex(cls, ipex_flag): - """Get IPEX optimization library if enabled and available, else return False - - Returns ipex library or False - """ - ret = False - - # Enabled by environment variable - # When IPEX is not false, attempt to import the library and use it. 
- if ipex_flag: - try: - ret = importlib.import_module("intel_extension_for_pytorch") - except Exception as ie: # pylint: disable=broad-exception-caught - # We don't require the module so catch, log, proceed to return False - msg = ( - f"IPEX enabled in env, but skipping ipex.optimize() because " - f"import intel_extension_for_pytorch failed with exception: {ie}" - ) - logger.warning(msg, exc_info=True) - - return ret - - @staticmethod - def _select_device(use_ipex, device): - """Use environment variables and availability to determine the device to use""" - if use_ipex: - # If enabled, use "xpu" (IPEX on GPU instead of IPEX on CPU) - if device == "xpu": - return "xpu" - elif device == "mps" and mps.is_built() and mps.is_available(): - # Never use on ipex, but otherwise use mps if enabled and available - return "mps" - - return "cuda" if torch.cuda.is_available() else None - - @staticmethod - def _get_backend(use_ipex, use_device): - """Determine the backend to use for torch compile. - - Considers global ipex if enabled first, next mps device, finally defaults. - - Returns the backend for torch.compile() - """ - if use_ipex: - return "ipex" - if use_device == "mps": - return mps - return "inductor" # default backend - - @staticmethod - def _optimize(model, ipex, device, autocast, pt2_compile): - if ipex: - if autocast: # IPEX performs best with autocast using bfloat16 - model = ipex.optimize( - model, dtype=torch.bfloat16, weights_prepack=False - ) - else: - model = ipex.optimize(model, weights_prepack=False) - - # torch.compile won't work everywhere, but when set we'll try it - if pt2_compile: - backend = CrossEncoderModule._get_backend(ipex, device) - try: - model = torch.compile(model, backend=backend, mode="max-autotune") - except Exception as e: # pylint: disable=broad-exception-caught - # Not always supported (e.g. in a python version) so catch, log, proceed. - warn_msg = ( - f"PT2_COMPILE enabled, but continuing without torch.compile() " - f"because it failed with exception: {e}" - ) - logger.warning(warn_msg, exc_info=True) - return model - @RerankTask.taskmethod() def run_rerank_query( self, @@ -320,18 +241,8 @@ def run_rerank_query( return_text=return_text, ) - if results.results: - return RerankResult( - result=results.results[0], - producer_id=self.PRODUCER_ID, - input_token_count=results.input_token_count, - ) - - RerankResult( - result=RerankScore( - scores=[], - query=query if return_query else None, - ), + return RerankResult( + result=results.results[0], producer_id=self.PRODUCER_ID, input_token_count=results.input_token_count, ) @@ -414,6 +325,7 @@ def get_text(doc): top_k=top_n, return_documents=False, batch_size=32, + convert_to_numpy=True, truncate_input_tokens=truncate_input_tokens, ) results.append(scores) @@ -684,7 +596,7 @@ def predict( convert_to_numpy: bool = True, convert_to_tensor: bool = False, truncate_input_tokens: Optional[int] = 0, - ) -> Union[PredictResultTuple, List[float], np.ndarray, torch.Tensor]: + ) -> PredictResultTuple: """ Performs predictions with the CrossEncoder on the given sentence pairs. @@ -694,11 +606,6 @@ def predict( Returns: Uses PredictResultTuple to add input_token_count - Union[List[float], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs. - The return type depends on the `convert_to_numpy` and `convert_to_tensor` parameters. - If `convert_to_tensor` is True, the output will be a torch.Tensor. - If `convert_to_numpy` is True, the output will be a numpy.ndarray. 
- Otherwise, the output will be a list of float values. """ input_was_string = False if isinstance( @@ -711,7 +618,7 @@ def predict( self.smart_batching_collate_text_only, truncate_input_tokens=truncate_input_tokens, ) - inp_dataloader = DataLoader( + iterator = DataLoader( sentences, batch_size=batch_size, collate_fn=collate_fn, @@ -719,15 +626,11 @@ def predict( shuffle=False, ) - iterator = inp_dataloader - if activation_fct is None: activation_fct = self.default_activation_function pred_scores = [] input_token_count = 0 - self.model.eval() - self.model.to(self._target_device) with torch.no_grad(): for features in iterator: @@ -750,7 +653,7 @@ def predict( pred_scores = torch.stack(pred_scores) elif convert_to_numpy: pred_scores = np.asarray( - [score.cpu().detach().numpy() for score in pred_scores] + [score.cpu().detach().float().item() for score in pred_scores] ) if input_was_string: @@ -772,7 +675,7 @@ def rank( convert_to_numpy: bool = True, convert_to_tensor: bool = False, truncate_input_tokens: Optional[int] = 0, - ) -> Union[RerankResultTuple, List[Dict]]: + ) -> RerankResultTuple: """ Performs ranking with the CrossEncoder on the given query and documents. From ac4699360c7ae1d1190d141add9596497195d623 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 11 Sep 2024 22:57:33 -0700 Subject: [PATCH 3/7] Cross-encoder docstring fix * The already borrowed errors are fixed with tokenizers per thread, so there were some misleading comments about not changing params for truncation (which we do for cross-encoder truncation). Signed-off-by: Mark Sturdevant --- caikit_nlp/modules/text_embedding/crossencoder.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py index ac1f483b..bb1905db 100644 --- a/caikit_nlp/modules/text_embedding/crossencoder.py +++ b/caikit_nlp/modules/text_embedding/crossencoder.py @@ -460,14 +460,7 @@ def _get_tokenizer_per_thread(self): return tokenizer def get_tokenized(self, texts, **kwargs): - """Intentionally always call tokenizer the same way to avoid thread issues. - - Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID). - - Avoid changing the max length, truncation, and padding to avoid the - "Already borrowed" errors that come with concurrent threads attempting to use - the fast tokenizer with different truncation settings. - """ + """Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID)""" max_len = kwargs.get("truncate_input_tokens", self.tokenizer.model_max_length) max_len = min(max_len, self.tokenizer.model_max_length) @@ -480,7 +473,7 @@ def get_tokenized(self, texts, **kwargs): return_attention_mask=True, # Used for determining token count return_token_type_ids=False, # Needed for cross-encoders return_overflowing_tokens=False, # DO NOT USE overflow tokens break sentence batches - return_offsets_mapping=True, # Used for truncation test + return_offsets_mapping=True, # Used for truncation needed error return_length=False, return_tensors="pt", truncation=True, From 4e9c5aad94cb39bc636efa9d6ef338d40abc6ef6 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 12 Sep 2024 01:09:02 -0700 Subject: [PATCH 4/7] Cross-Encoder use configurable batch size. Default is 32. Can override with embedding batch_size in config or EMBEDDING_BATCH_SIZE env var. 
Signed-off-by: Mark Sturdevant --- caikit_nlp/modules/text_embedding/crossencoder.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py index bb1905db..73f20a2c 100644 --- a/caikit_nlp/modules/text_embedding/crossencoder.py +++ b/caikit_nlp/modules/text_embedding/crossencoder.py @@ -88,6 +88,14 @@ def __init__( # model_max_length attribute availability might(?) vary by model/tokenizer self.model_max_length = getattr(model.tokenizer, "model_max_length", None) + # Read config/env settings that are needed at run_* time. + embedding_cfg = get_config().get("embedding", {}) + + self.batch_size = embedding_cfg.get("batch_size", 32) + error.type_check("", int, EMBEDDING_BATCH_SIZE=self.batch_size) + if self.batch_size <= 0: + self.batch_size = 32 # 0 or negative, use the default. + @classmethod def load( cls, model_path: Union[str, ModuleConfig], *args, **kwargs @@ -324,7 +332,7 @@ def get_text(doc): documents=doc_texts, top_k=top_n, return_documents=False, - batch_size=32, + batch_size=self.batch_size, convert_to_numpy=True, truncate_input_tokens=truncate_input_tokens, ) From 211668ab53f48513b3073960aa8cd012fe8d6303 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 12 Sep 2024 12:27:44 -0700 Subject: [PATCH 5/7] Cross-encoder: Move truncation check and add tests * Moved the truncation check to a place that can determine the proper index for the error message (with batching). * Added test to validate some results after truncation. This is with a tiny model, but works for sanity. Signed-off-by: Mark Sturdevant --- .../modules/text_embedding/crossencoder.py | 159 +++++++++--------- .../text_embedding/test_crossencoder.py | 81 ++++++++- 2 files changed, 154 insertions(+), 86 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py index 73f20a2c..71e66704 100644 --- a/caikit_nlp/modules/text_embedding/crossencoder.py +++ b/caikit_nlp/modules/text_embedding/crossencoder.py @@ -473,7 +473,11 @@ def get_tokenized(self, texts, **kwargs): max_len = kwargs.get("truncate_input_tokens", self.tokenizer.model_max_length) max_len = min(max_len, self.tokenizer.model_max_length) if max_len <= 0: - max_len = None + max_len = None # Use the default + elif max_len < 5: + # 1, 2, 3 don't really work (4 might but...) + # Bare minimum is [CLS] token [SEP] token [SEP] + max_len = 5 tokenizer = self._get_tokenizer_per_thread() tokenized = tokenizer( @@ -490,56 +494,42 @@ def get_tokenized(self, texts, **kwargs): ) return tokenized - def _truncation_needed(self, tokenized, max_length, texts): + def _truncation_needed(self, encoding, texts): """Check for truncation needed to meet max_length token limit Returns: - List of indexes of the texts that need truncating ([] if none) + True if was truncated, False otherwise """ - ret = [] # List of indexes for texts that need truncation - - if max_length is None: - max_length = self.tokenizer.model_max_length - - for i, encoding in enumerate(tokenized.encodings): - input_tokens = sum(encoding.attention_mask) - if input_tokens >= self.tokenizer.model_max_length: - # At model limit, including start/end... - # This may or may not have already been truncated at the model limit. - # Check the strlen and last offset. - # We need to know this, for "not okay_to_truncate" errors. 
- offsets = encoding.offsets - type_ids = encoding.type_ids - attn_mask = encoding.attention_mask - # Find the last offset by counting attn masks - # and keeping the last non-zero offset end. - token_count = 0 - index = 0 # index of longest - type_id = 0 # track type_id of longest - for n, attn in enumerate(attn_mask): - if attn == 1: - token_count += 1 - end = offsets[n][1] # Index to end character from offset - if ( - end > index - ): # Grab last non-zero end index (ensures increasing too) - type_id = type_ids[n] - index = end - if token_count >= max_length - 1: # Stop with room for an end token - break - end_index = index # longest - end_typeid = type_id # longest - - # Get position in (queries * docs) for this query or doc - if end_typeid == 0: # query - text_pos = i // len(texts[0]) - else: # doc - text_pos = i % len(texts[0]) - - if end_index < len(texts[end_typeid][text_pos]): - ret.append(i) - - return ret + input_tokens = sum(encoding.attention_mask) + if input_tokens < self.tokenizer.model_max_length: + return False + + # At model limit, including start/end... + # This may or may not have already been truncated at the model limit. + # Check the strlen and last offset. + # We need to know this, for default implementation of throwing error. + offsets = encoding.offsets + type_ids = encoding.type_ids + attn_mask = encoding.attention_mask + + # Find the last offset by counting attn masks + # and keeping the last non-zero offset end. + token_count = 0 + index = 0 # index of longest + type_id = 0 # track type_id of longest + + for n, attn in enumerate(attn_mask): + if attn == 1: + token_count += 1 + end = offsets[n][1] # Index to end character from offset + if end > index: # Grab last non-zero end index (ensures increasing too) + type_id = type_ids[n] + index = end + end_index = index # longest + end_typeid = type_id # longest + + # If last token offset is before the last char, then it was truncated + return end_index < len(texts[end_typeid].strip()) def smart_batching_collate_text_only( self, batch, truncate_input_tokens: Optional[int] = 0 @@ -554,37 +544,24 @@ def smart_batching_collate_text_only( texts, truncate_input_tokens=truncate_input_tokens ) - max_len = self.tokenizer.model_max_length + return tokenized - if truncate_input_tokens == 0 or truncate_input_tokens > max_len: - # default (for zero or over max) is to error on truncation - truncated = self._truncation_needed(tokenized, max_len, texts) - - if truncated: - indexes = f"{', '.join(str(i) for i in truncated)}." - index_hint = ( - " for text at " - f"{'index' if len(truncated) == 1 else 'indexes'}: {indexes}" - ) - - error.log_raise( - "", - ValueError( - f"Token sequence length (+3 for separators) exceeds the " - f"maximum sequence length for this model ({max_len})" - f"{index_hint}" - ), - ) - - # We cannot send offset_mapping to the model with features, - # but we needed offset_mapping for other uses. - if "offset_mapping" in tokenized: - del tokenized["offset_mapping"] - - for name in tokenized: - tokenized[name] = tokenized[name].to(self._target_device) + @staticmethod + def raise_truncation_error(max_len, truncation_needed_indexes): - return tokenized + indexes = f"{', '.join(str(i) for i in truncation_needed_indexes)}." 
+        index_hint = (
+            " for text at "
+            f"{'index' if len(truncation_needed_indexes) == 1 else 'indexes'}: {indexes}"
+        )
+        error.log_raise(
+            "",
+            ValueError(
+                f"Token sequence length (+3 for separators) exceeds the "
+                f"maximum sequence length for this model ({max_len})"
+                f"{index_hint}"
+            ),
+        )
 
     def predict(
         self,
@@ -630,10 +607,35 @@ def predict(
         if activation_fct is None:
             activation_fct = self.default_activation_function
 
+        max_len = self.tokenizer.model_max_length
         pred_scores = []
         input_token_count = 0
+        row = -1
+        truncation_needed_indexes = []
         with torch.no_grad():
             for features in iterator:
+                # Sum the length of all encodings for all samples
+                for encoding in features.encodings:
+                    row += 1
+
+                    # for mask in encoding.attention_mask:
+                    input_token_count += sum(encoding.attention_mask)
+
+                    if truncate_input_tokens == 0 or truncate_input_tokens > max_len:
+                        # default (for zero or over max) is to error on truncation
+                        if self._truncation_needed(encoding, sentences[row]):
+                            truncation_needed_indexes.append(row)
+
+                if truncation_needed_indexes:
+                    self.raise_truncation_error(max_len, truncation_needed_indexes)
+
+                # # We cannot send offset_mapping to the model with features,
+                # # but we needed offset_mapping for other uses.
+                if "offset_mapping" in features:
+                    del features["offset_mapping"]
+
+                for name in features:
+                    features[name] = features[name].to(self._target_device)
 
                 model_predictions = self.model(**features, return_dict=True)
                 logits = activation_fct(model_predictions.logits)
@@ -642,11 +644,6 @@ def predict(
                     logits = torch.nn.functional.softmax(logits, dim=1)
                 pred_scores.extend(logits)
 
-                # Sum the length of all encodings for all samples
-                for encoding in features.encodings:
-                    # for mask in encoding.attention_mask:
-                    input_token_count += sum(encoding.attention_mask)
-
         if self.config.num_labels == 1:
             pred_scores = [score[0] for score in pred_scores]
 
diff --git a/tests/modules/text_embedding/test_crossencoder.py b/tests/modules/text_embedding/test_crossencoder.py
index 7000171c..0a62e658 100644
--- a/tests/modules/text_embedding/test_crossencoder.py
+++ b/tests/modules/text_embedding/test_crossencoder.py
@@ -7,6 +7,7 @@
 
 # Third Party
 from pytest import approx
+import numpy as np
 import pytest
 
 # First Party
@@ -392,12 +393,17 @@ def test_truncate_input_tokens_errors(truncate_input_tokens, loaded_model):
     """
     model_max = loaded_model.model.tokenizer.model_max_length
 
-    too_long = "x " * (model_max - 3)  # 3 for tokens (no room for a query token)
-    just_barely = "x " * (model_max - 4)  # 3 for tokens plus room for a query token
-    queries = ["x"]
-    docs = [{"text": t} for t in ["x", too_long, just_barely, too_long, just_barely]]
+    too_long = "a " * (model_max - 3)  # 3 for tokens (no room for a query token)
+    just_barely = "a " * (model_max - 4)  # 3 for tokens plus room for a query token
+    queries = ["q"]
+
+    # Add 50 of these little ones to get past the first batch(es)
+    # to verify that this error message index is for the input
+    # position and not just an index into some internal batch.
+    docs = [{"text": "a"}] * 50
+    docs.extend([{"text": t} for t in [too_long, just_barely, too_long, just_barely]])
+
+    match1 = rf"exceeds the maximum sequence length for this model \({model_max}\) for text at indexes: 50, 52."
    with pytest.raises(ValueError, match=match1):
         loaded_model.run_rerank_queries(
             queries=queries, documents=docs, truncate_input_tokens=truncate_input_tokens
         )
@@ -441,3 +447,68 @@ def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_m
             documents=[{"text": too_long}],
             truncate_input_tokens=truncate_input_tokens,
         )
+
+
+@pytest.mark.parametrize(
+    "truncate_input_tokens", [1, 2, 3, 4, 5, 6, 99, 100, 101, 510, 511, 512, -1]
+)
+def test_truncation(truncate_input_tokens, loaded_model):
+    """verify that results are as expected with truncation"""
+
+    max_len = loaded_model.model.tokenizer.model_max_length
+
+    if truncate_input_tokens is None or truncate_input_tokens < 0:
+        # For -1 we don't truncate, but the model will
+        repeat = max_len
+    else:
+        repeat = min(
+            truncate_input_tokens, max_len
+        )  # max_len is used when we need -4 for begin/"q"/sep/end
+
+    # Build a text like "x x x.. x " with room for one more token
+    repeat = repeat - 4  # room for separators and a single-token query
+    repeat = repeat - 1  # space for the final x or y token to show difference
+
+    base = ""
+    if repeat > 0:
+        base = "x " * repeat  # A bunch of "x" tokens
+    x = base + "x"  # One last "x" that will not get truncated
+    y = base + "y"  # A different last character "y" not truncated
+    z = y + " z"  # Add token "z" after "y". This should get truncated.
+
+    # Multiple queries to test query-loop vs queries
+    # Query for the significant added chars to affect score.
+    queries = ["y", "z"]
+    docs = [{"text": t} for t in [x, y, z]]
+    res = loaded_model.run_rerank_queries(
+        queries=queries,
+        documents=docs,
+        truncate_input_tokens=truncate_input_tokens,
+    )
+    queries_results = res.results
+
+    # Compare with results from individual rerank query calls in a loop
+    query_results = []
+    for query in queries:
+        r = loaded_model.run_rerank_query(
+            query=query,
+            documents=docs,
+            truncate_input_tokens=truncate_input_tokens,
+        )
+        query_results.append(r.result)
+
+    assert len(queries_results) == len(
+        query_results
+    ), "expected the same length results"
+
+    # compare the scores (queries call vs query call in a loop)
+    for i, r in enumerate(queries_results):
+        queries_scores = [x.score for x in r.scores]
+        query_scores = [x.score for x in query_results[i].scores]
+        assert np.array_equal(queries_scores, query_scores)
+
+    # x...xyz is the same as x...xy because that is exactly where truncation worked
+    assert query_scores[0] == query_scores[1]
+
+    # Make sure the base, x, y are not a match (we kept the significant last char)
+    assert query_scores[1] != query_scores[2]

From 2cb6183aeef9bba65200950209d10682f23da9ca Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Thu, 12 Sep 2024 13:38:28 -0700
Subject: [PATCH 6/7] Cross-encoder: fix truncation test

The part of the test that really verifies a token was truncated was wrong.
* It was backwards and only passing because the scores are sorted by rank
* Now uses the index to get scores in the order of the inputs
* Now correctly xx != xy but xy == xyz (truncated z)

Signed-off-by: Mark Sturdevant
---
 tests/modules/text_embedding/test_crossencoder.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/modules/text_embedding/test_crossencoder.py b/tests/modules/text_embedding/test_crossencoder.py
index 0a62e658..b928daf1 100644
--- a/tests/modules/text_embedding/test_crossencoder.py
+++ b/tests/modules/text_embedding/test_crossencoder.py
@@ -502,13 +502,17 @@ def test_truncation(truncate_input_tokens, loaded_model):
     ), "expected the same length results"
 
     # compare the scores (queries call vs query call in a loop)
+    # order is the same
     for i, r in enumerate(queries_results):
         queries_scores = [x.score for x in r.scores]
         query_scores = [x.score for x in query_results[i].scores]
         assert np.array_equal(queries_scores, query_scores)
 
-    # x...xyz is the same as x...xy because that is exactly where truncation worked
-    assert query_scores[0] == query_scores[1]
+        # To compare scores based on the inputs, we need to use the index too
+        indexed_query_scores = {s.index: s.score for s in query_results[i].scores}
 
-    # Make sure the base, x, y are not a match (we kept the significant last char)
-    assert query_scores[1] != query_scores[2]
+        # Make sure the x...xx, x...xy are not a match (we kept the significant last token)
+        assert indexed_query_scores[0] != indexed_query_scores[1]
+
+        # x...xy is the same as x...xyz because we truncated the z token -- it worked!
+        assert indexed_query_scores[1] == indexed_query_scores[2]

From 8fa67ccefb189473d2badfc2fb17fda79dbd9078 Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Thu, 12 Sep 2024 13:53:17 -0700
Subject: [PATCH 7/7] Cross-encoder: remove unused code and tidy up comments

Signed-off-by: Mark Sturdevant
---
 caikit_nlp/modules/text_embedding/crossencoder.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py
index 71e66704..7097411f 100644
--- a/caikit_nlp/modules/text_embedding/crossencoder.py
+++ b/caikit_nlp/modules/text_embedding/crossencoder.py
@@ -514,19 +514,17 @@ def _truncation_needed(self, encoding, texts):
 
         # Find the last offset by counting attn masks
         # and keeping the last non-zero offset end.
-        token_count = 0
         index = 0  # index of longest
         type_id = 0  # track type_id of longest
 
         for n, attn in enumerate(attn_mask):
             if attn == 1:
-                token_count += 1
                 end = offsets[n][1]  # Index to end character from offset
                 if end > index:  # Grab last non-zero end index (ensures increasing too)
                     type_id = type_ids[n]
                     index = end
-        end_index = index  # longest
-        end_typeid = type_id  # longest
+        end_index = index  # longest last char index
+        end_typeid = type_id  # longest type (query or text)
 
         # If last token offset is before the last char, then it was truncated
         return end_index < len(texts[end_typeid].strip())
@@ -629,8 +627,8 @@ def predict(
                if truncation_needed_indexes:
                    self.raise_truncation_error(max_len, truncation_needed_indexes)
 
-                # # We cannot send offset_mapping to the model with features,
-                # # but we needed offset_mapping for other uses.
+                # We cannot send offset_mapping to the model with features,
+                # but we needed offset_mapping for other uses.
                 if "offset_mapping" in features:
                     del features["offset_mapping"]
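
For readers reviewing the truncation logic in this series, the following is a minimal standalone sketch (not part of the patch) of the same offsets/attention-mask idea that `_truncation_needed` uses, reduced to a single text instead of query/document pairs. It assumes the `transformers` library is installed; the `cross-encoder/ms-marco-MiniLM-L-6-v2` model name is only an example, and `was_truncated` is a hypothetical helper, not an API of this module.

# Standalone sketch of truncation detection (illustrative only; not module code).
from transformers import AutoTokenizer


def was_truncated(tokenizer, text: str, max_length: int) -> bool:
    """Return True if `text` does not fit in `max_length` tokens."""
    enc = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        return_offsets_mapping=True,  # requires a fast tokenizer
    )

    # Fewer attended tokens than the limit: nothing could have been cut off.
    if sum(enc["attention_mask"]) < max_length:
        return False

    # At the limit: find the last character position covered by an attended token.
    # Special tokens report (0, 0) offsets, so they never advance `last_end`.
    last_end = 0
    for attn, (_, end) in zip(enc["attention_mask"], enc["offset_mapping"]):
        if attn == 1 and end > last_end:
            last_end = end

    # If the last covered character comes before the end of the text, it was truncated.
    return last_end < len(text.strip())


tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
print(was_truncated(tok, "x " * 600, max_length=512))     # True: too many tokens
print(was_truncated(tok, "a short text", max_length=512))  # False: fits easily

The design point mirrors the patch: merely hitting the model's token limit is not proof that anything was cut, so the character offset of the last attended token is what decides whether input characters were actually dropped, and only then is the truncation error raised.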