diff --git a/caikit_nlp/modules/text_embedding/crossencoder.py b/caikit_nlp/modules/text_embedding/crossencoder.py
index 71e66704..7097411f 100644
--- a/caikit_nlp/modules/text_embedding/crossencoder.py
+++ b/caikit_nlp/modules/text_embedding/crossencoder.py
@@ -514,19 +514,17 @@ def _truncation_needed(self, encoding, texts):
 
         # Find the last offset by counting attn masks
         # and keeping the last non-zero offset end.
-        token_count = 0
         index = 0  # index of longest
         type_id = 0  # track type_id of longest
         for n, attn in enumerate(attn_mask):
             if attn == 1:
-                token_count += 1
                 end = offsets[n][1]  # Index to end character from offset
                 if end > index:  # Grab last non-zero end index (ensures increasing too)
                     type_id = type_ids[n]
                     index = end
-        end_index = index  # longest
-        end_typeid = type_id  # longest
+        end_index = index  # longest last char index
+        end_typeid = type_id  # longest type (query or text)
 
         # If last token offset is before the last char, then it was truncated
         return end_index < len(texts[end_typeid].strip())
@@ -629,8 +627,8 @@ def predict(
         if truncation_needed_indexes:
             self.raise_truncation_error(max_len, truncation_needed_indexes)
 
-        # # We cannot send offset_mapping to the model with features,
-        # # but we needed offset_mapping for other uses.
+        # We cannot send offset_mapping to the model with features,
+        # but we needed offset_mapping for other uses.
         if "offset_mapping" in features:
             del features["offset_mapping"]
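
For reference, the check that the patched _truncation_needed performs can be reproduced outside the module: if the character offset of the last attended token ends before the last character of the segment that token belongs to, the query/text pair was truncated. The sketch below is a minimal, standalone illustration and is not part of this PR; it assumes a Hugging Face fast tokenizer, and the model name and helper name are placeholders rather than anything defined in crossencoder.py.

from transformers import AutoTokenizer

# Illustrative model name; any fast tokenizer that supports offset mapping works.
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

def truncation_was_needed(query: str, text: str, max_len: int) -> bool:
    # Hypothetical helper mirroring the offset-based check in _truncation_needed.
    enc = tokenizer(
        query,
        text,
        truncation=True,
        max_length=max_len,
        return_offsets_mapping=True,
        return_token_type_ids=True,
    )
    end_index = 0   # last character index covered by any attended token
    end_typeid = 0  # segment owning that token (0 = query, 1 = text)
    for attn, (_, end), type_id in zip(
        enc["attention_mask"], enc["offset_mapping"], enc["token_type_ids"]
    ):
        if attn == 1 and end > end_index:
            end_index = end
            end_typeid = type_id
    texts = [query, text]
    # Truncated if the last attended token stops before the segment's last char.
    return end_index < len(texts[end_typeid].strip())

# A long passage against a small max_length should report truncation.
print(truncation_was_needed("what is truncation?", "a very long passage " * 100, 64))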