From 4f5bbe855766129b8eef983b5f68281abd274dae Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Tue, 16 Apr 2024 13:30:27 -0700
Subject: [PATCH] add word list format refstring (see PR 8865)

Signed-off-by: Jimmy Zhang
---
 nemo/export/trt_llm/nemo_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nemo/export/trt_llm/nemo_utils.py b/nemo/export/trt_llm/nemo_utils.py
index 1c8fa1b5b913..b92ee9044076 100644
--- a/nemo/export/trt_llm/nemo_utils.py
+++ b/nemo/export/trt_llm/nemo_utils.py
@@ -197,7 +197,7 @@ def nemo_llm_to_model_config(
     return model_configs, tokenizer
 
 
-def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
+def to_word_list_format(word_dict: List[List[str]], tokenizer=None, ref_string='<extra_id_1>'):
     '''
     format of word_dict
         len(word_dict) should be same to batch_size
@@ -214,7 +214,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
     # We use a similar trick as in NeMo to deal with the fact that the encoding of a single word
     # can't always be trusted. See
     # https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
-    ids_ref = tokenizer.encode("<extra_id_1>")
+    ids_ref = tokenizer.encode(ref_string)
     for word_dict_item in word_dict:
         item_flat_ids = []
         item_offsets = []
@@ -224,7 +224,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
         words = list(csv.reader(word_dict_item))[0]
 
         for word in words:
-            ids = tokenizer.encode(f"<extra_id_1>{word}")
+            ids = tokenizer.encode(f"{ref_string}{word}")
             if ids[0 : len(ids_ref)] == ids_ref:
                 # It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
                 ids = ids[len(ids_ref) :]
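
Note: for reviewers unfamiliar with the reference-string trick this patch
parameterizes: many tokenizers encode a word differently at the start of a
string than they do mid-sentence, so the word is encoded behind a known
prefix and the prefix tokens are stripped afterwards. A minimal standalone
sketch of the same logic, assuming a Hugging Face tokenizer ("gpt2" below is
only an illustration, not part of this patch):

    # Sketch of the prefix-strip trick used by to_word_list_format.
    # Assumption: any Hugging Face tokenizer; "gpt2" is illustrative only.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    ref_string = "<extra_id_1>"
    word = "happy"

    # Encoding `word` on its own may not match how it is tokenized after
    # other text, so encode it behind a known reference prefix instead.
    ids_ref = tokenizer.encode(ref_string)
    ids = tokenizer.encode(f"{ref_string}{word}")

    if ids[: len(ids_ref)] == ids_ref:
        # The prefix tokens are intact: strip them to get the tokens
        # that represent `word` in context.
        word_ids = ids[len(ids_ref):]
    else:
        # Rare case: the prefix merged with `word` during tokenization;
        # fall back to encoding the word directly.
        word_ids = tokenizer.encode(word)

    print(word_ids)

The new ref_string parameter only replaces the previously hardcoded prefix,
keeping it as the default, so default behavior should be unchanged; making
it configurable lets callers choose a reference string suited to their
tokenizer's vocabulary.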