Skip to content

Commit

Permalink
dehardcode test string
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Apr 9, 2024
1 parent b273fa6 commit 845dc17
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions nemo/export/trt_llm/nemo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,11 @@ def nemo_llm_to_model_config(
return model_configs, tokenizer


def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
def to_word_list_format(
word_dict: List[List[str]],
tokenizer=None,
ref_str="<extra_id_1>",
):
'''
format of word_dict
len(word_dict) should be same to batch_size
Expand All @@ -213,7 +217,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):
# We use a similar trick as in NeMo to deal with the fact that the encoding of a single word
# can't always be trusted. See
# https://github.com/NVIDIA/NeMo/blob/bb575b72fd0be51ae10cc77d9f89ddb9e9d3b96d/nemo/collections/nlp/modules/common/text_generation_strategy.py#L229
ids_ref = tokenizer.encode("<extra_id_1>")
ids_ref = tokenizer.encode(ref_str)
for word_dict_item in word_dict:
item_flat_ids = []
item_offsets = []
Expand All @@ -223,7 +227,7 @@ def to_word_list_format(word_dict: List[List[str]], tokenizer=None):

words = list(csv.reader(word_dict_item))[0]
for word in words:
ids = tokenizer.encode(f"<extra_id_1>{word}")
ids = tokenizer.encode(f"{ref_str}{word}")
if ids[0 : len(ids_ref)] == ids_ref:
# It worked! We can obtain the token(s) associated to `word` by stripping the prefix tokens.
ids = ids[len(ids_ref) :]
Expand Down

0 comments on commit 845dc17

Please sign in to comment.