From f629248b689b0c041ea4546e78826f0139c7602d Mon Sep 17 00:00:00 2001
From: m-misiura
Date: Wed, 4 Dec 2024 20:16:17 +0000
Subject: [PATCH] :white_check_mark: based on the PR comments, changed test
 case to check for an expected number instead of checking if length is
 non-zero; added `return_attention_mask=True` in the `run_tokenizer` method

Signed-off-by: m-misiura
---
 caikit_nlp/modules/text_generation/text_generation_local.py | 2 +-
 tests/modules/text_generation/test_text_generation_local.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index 885028f4..ba19a585 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -592,7 +592,7 @@ def run_tokenizer(
             The token count
         """
         error.type_check("", str, text=text)
-        tokenized_output = self.model.tokenizer(text)
+        tokenized_output = self.model.tokenizer(text, return_attention_mask=True)
         return TokenizationResults(
             token_count=len(tokenized_output["input_ids"]),
         )
diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py
index 8338a163..5e91bea0 100644
--- a/tests/modules/text_generation/test_text_generation_local.py
+++ b/tests/modules/text_generation/test_text_generation_local.py
@@ -228,10 +228,10 @@ def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
     short_text = "This is a test sentence."
     short_result = model.run_tokenizer(short_text)
     assert isinstance(short_result, TokenizationResults)
-    assert short_result.token_count > 0
+    assert short_result.token_count == len(model.model.tokenizer.encode(short_text))
 
     # Edge case: Long input
     long_text = "This is a test sentence. " * 1000
     long_result = model.run_tokenizer(long_text)
     assert isinstance(long_result, TokenizationResults)
-    assert long_result.token_count > 0
+    assert long_result.token_count == len(model.model.tokenizer.encode(long_text))