
Commit

Text Embedding: Refactor and comment the truncation code
* The code was mostly hard to understand. Added comments and clearer variable names.
* Refactoring led to some simpler code.
* Moved the common easy branches (no truncation) up
  for readability.

Signed-off-by: markstur <mark.sturdevant@ibm.com>
markstur committed Dec 7, 2023
1 parent: 060869b · commit: a87cc64
Showing 1 changed file with 40 additions and 31 deletions.
caikit_nlp/modules/text_embedding/embedding.py (40 additions & 31 deletions)
@@ -317,40 +317,21 @@ def _truncate_input_tokens(
             max_length=max_length,
         )
 
-        for i, text in enumerate(texts):
-            mapping = tokenized["overflow_to_sample_mapping"]
-            lengths = [
-                tokenized["length"][idx] for idx, v in enumerate(mapping) if v == i
+        texts_map = tokenized["overflow_to_sample_mapping"]
+
+        for text_number, text in enumerate(texts):
+            # positions: the positions (in lengths and offsets arrays) that belong to this text
+            positions = [
+                position
+                for position, sample_number in enumerate(texts_map)
+                if sample_number == text_number
             ]
-            was_truncated = len(lengths) > 1  # multiple lengths when truncated
+            lengths = [tokenized["length"][pos] for pos in positions]
 
-            if okay_to_truncate and was_truncated:
-                # Get the text offsets for the tokens that are kept after truncation
-                offsets = [
-                    tokenized["offset_mapping"][idx]
-                    for idx, v in enumerate(mapping)
-                    if v == i
-                ]
-                truncated_offsets = offsets[0]
-                # Find the first token offset (i.e. skip added start token)
-                start = next(
-                    idx for idx, val in (enumerate(truncated_offsets)) if val != (0, 0)
-                )
-                # Find the last token offset (i.e. skip added end token)
-                end = next(
-                    idx
-                    for idx, val in reversed(list(enumerate(truncated_offsets)))
-                    if val != (0, 0)
-                )
-                # Use the start/end offsets to slice the text based on token truncation
-                ret.append(
-                    text[truncated_offsets[start][0] : truncated_offsets[end][1]]
-                )
+            was_truncated = len(lengths) > 1  # multiple lengths when truncated
 
-            elif okay_to_truncate and not was_truncated:
-                ret.append(text)  # return original text
-
-            elif was_truncated:
+            if not okay_to_truncate and was_truncated:
+                # Raise error. We don't allow silent truncation in this case.
                 tokens = sum(lengths)  # add up total tokens for error message
                 error.log_raise(
                     "<NLP08391926E>",
@@ -360,6 +341,34 @@
                     ),
                 )
 
+            elif okay_to_truncate and not was_truncated:
+                ret.append(text)  # collect original text to return
+
+            elif okay_to_truncate and was_truncated:
+                # Truncate the text that maps to the truncated tokens.
+                # The offset_mapping describes the text position for each token.
+                # Added tokens were not in the text, so they show up as (0, 0).
+
+                # Get the text offsets for the tokens that are to be kept after truncation.
+                # Take the first set of offsets that mapped to this text's positions.
+                # This first set represents what will be kept after truncation.
+                # Each offset tells us which chars in the original text map to this token.
+                offsets = next(tokenized["offset_mapping"][pos] for pos in positions)
+
+                # Find the first offset that is not empty (0, 0) to avoid added tokens
+                start = next(offset for offset in offsets if offset != (0, 0))
+
+                # Find the last offset that is not empty (0, 0) to avoid added tokens
+                end = next(
+                    offset for offset in reversed(list(offsets)) if offset != (0, 0)
+                )
+
+                # Use the start offset's begin and the end offset's end to slice the text,
+                # i.e. if start=(0, 5) and end=(72, 78) then we want slice [0:78]
+                truncated_text = text[start[0] : end[1]]
+
+                ret.append(truncated_text)  # return the truncated text for this one
+
         return ret
 
     @EmbeddingTask.taskmethod()
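Note on the technique: the refactored code leans on two Hugging Face fast-tokenizer outputs, overflow_to_sample_mapping (which overflow chunk came from which input text) and offset_mapping (which characters of the source text each token covers). Below is a minimal self-contained sketch of the same offset-slicing idea; the checkpoint name, sample texts, and max_length are illustrative assumptions, not values from this commit.

# Sketch: character-offset-based truncation with a Hugging Face fast tokenizer.
# The checkpoint, texts, and max_length below are illustrative assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

texts = ["short enough to keep whole", "repeat until it overflows " * 20]

tokenized = tokenizer(
    texts,
    truncation=True,
    max_length=16,  # small limit so the demo actually truncates
    return_overflowing_tokens=True,  # emit overflow chunks instead of dropping text
    return_offsets_mapping=True,  # (char_start, char_end) for each token
)

# One entry per chunk: the index of the input text that chunk came from.
texts_map = tokenized["overflow_to_sample_mapping"]

truncated_texts = []
for text_number, text in enumerate(texts):
    # Chunks belonging to this text; more than one chunk means it was truncated.
    positions = [p for p, sample in enumerate(texts_map) if sample == text_number]
    if len(positions) == 1:
        truncated_texts.append(text)  # fits within max_length; keep as-is
        continue

    # The first chunk holds exactly the tokens that survive truncation.
    offsets = tokenized["offset_mapping"][positions[0]]
    # Added special tokens (e.g. [CLS], [SEP]) were never in the text: offset (0, 0).
    real_offsets = [offset for offset in offsets if offset != (0, 0)]
    start, end = real_offsets[0], real_offsets[-1]
    truncated_texts.append(text[start[0] : end[1]])

print(truncated_texts)

Because the first chunk contains exactly the tokens that fit, slicing the original string from the first real offset's start to the last real offset's end yields truncated text aligned to token boundaries, with no need to decode token IDs back into text.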
