Skip to content

Commit

Permalink
✅ Based on the PR comments, changed test case to check for an expected number instead of checking if length is non-zero; added `return_attention_mask=True` in the `run_tokenizer` method
Browse files Browse the repository at this point in the history

Signed-off-by: m-misiura <mmisiura@redhat.com>
  • Loading branch information
m-misiura committed Dec 4, 2024
1 parent 261e1a3 commit f629248
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def run_tokenizer(
The token count
"""
error.type_check("<NLP48137045E>", str, text=text)
tokenized_output = self.model.tokenizer(text)
tokenized_output = self.model.tokenizer(text, return_attention_mask=True)
return TokenizationResults(
token_count=len(tokenized_output["input_ids"]),
)
Expand Down
4 changes: 2 additions & 2 deletions tests/modules/text_generation/test_text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
short_text = "This is a test sentence."
short_result = model.run_tokenizer(short_text)
assert isinstance(short_result, TokenizationResults)
assert short_result.token_count > 0
assert short_result.token_count == len(model.model.tokenizer.encode(short_text))

# Edge case: Long input
long_text = "This is a test sentence. " * 1000
long_result = model.run_tokenizer(long_text)
assert isinstance(long_result, TokenizationResults)
assert long_result.token_count > 0
assert long_result.token_count == len(model.model.tokenizer.encode(long_text))

0 comments on commit f629248

Please sign in to comment.