Skip to content

Commit

Permalink
add word list format refstring (see PR 8865)
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Apr 16, 2024
1 parent 735d705 commit 4f5bbe8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions nemo/export/trt_llm/nemo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def nemo_llm_to_model_config(
return model_configs, tokenizer


def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
def to_word_list_format(word_dict: List[List[str]], tokenizer=None, ref_string='<extra_id_1>'):
'''
format of word_dict
len(word_dict) should be same to batch_size
Expand All @@ -214,7 +214,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
# We use a similar trick as in NeMo to deal with the fact that the encoding of a single word
# can't always be trusted. See
# https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
ids_ref = tokenizer.encode("<extra_id_1>")
ids_ref = tokenizer.encode(ref_string)
for word_dict_item in word_dict:
item_flat_ids = []
item_offsets = []
Expand All @@ -224,7 +224,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):

words = list(csv.reader(word_dict_item))[0]
for word in words:
ids = tokenizer.encode(f"<extra_id_1>{word}")
ids = tokenizer.encode(f"{ref_string}{word}")
if ids[0 : len(ids_ref)] == ids_ref:
# It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
ids = ids[len(ids_ref) :]
Expand Down

0 comments on commit 4f5bbe8

Please sign in to comment.