parallel embedding
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 17, 2024
1 parent 45a6822 commit df36d00
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -514,6 +514,10 @@ def remove_vocab_padding(tensor):
world_embed = model_level_params.get(get_layer_name("word_embedding", prefix), None)
if tp_size > 1 and pp_is_first:
world_embed = remove_vocab_padding(world_embed)
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
tokenizer_vocab_size, tp_rank, tp_size)
world_embed = world_embed[vocab_start_index:vocab_end_index]

save_pp_weight(
world_embed,
"vocab_embedding.weight",
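The added lines shard the (unpadded) word embedding by vocabulary range so that each tensor-parallel rank keeps only its slice of the table. Below is a minimal, self-contained sketch of that slicing; vocab_range_from_global_vocab_size is a simplified stand-in for Megatron's VocabUtility helper (an assumption), and the sizes are illustrative.

import numpy as np

def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Assumed behaviour: an even split of the vocabulary across TP ranks.
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition

tokenizer_vocab_size, hidden_size, tp_size = 32000, 16, 2
world_embed = np.random.rand(tokenizer_vocab_size, hidden_size)  # vocab padding already removed

shards = []
for tp_rank in range(tp_size):
    vocab_start_index, vocab_end_index = vocab_range_from_global_vocab_size(
        tokenizer_vocab_size, tp_rank, tp_size)
    shards.append(world_embed[vocab_start_index:vocab_end_index])  # rows owned by this rank

# Every vocabulary row ends up in exactly one shard.
assert sum(s.shape[0] for s in shards) == tokenizer_vocab_size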
1 change: 1 addition & 0 deletions nemo/export/trt_llm/nemo_utils.py
@@ -303,6 +303,7 @@ def nemo_llm_model_to_model_config(
world_size=mapping.world_size,
tp_size=mapping.tp_size,
pp_size=mapping.pp_size,
use_parallel_embedding = True,
quantization = {
'quant_algo': None,
'kv_cache_quant_algo': None,
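For context, a minimal sketch of what the exported model config carries once this flag is set; the field names are taken from the surrounding diff, the dict layout and values are illustrative, not the exact config object built here. With use_parallel_embedding enabled, each rank expects only its vocab shard in vocab_embedding.weight.

model_config = {
    "world_size": 2,
    "tp_size": 2,
    "pp_size": 1,
    "use_parallel_embedding": True,  # added by this commit
    "quantization": {
        "quant_algo": None,
        "kv_cache_quant_algo": None,
    },
}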
3 changes: 1 addition & 2 deletions nemo/export/trt_llm/tensorrt_llm_build.py
@@ -413,10 +413,9 @@ def build_and_save_engine(
# build_config.lora_config = lora_config

model = model_cls.from_config(model_config)
# use_parallel_embedding=True,

model = optimize_model(
model,
use_parallel_embedding=model_config.use_parallel_embedding,
share_embedding_table=model_config.share_embedding_table,
)
preprocess_weights(model_weights, model_config)
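The build step now reads use_parallel_embedding from the model config instead of a hard-coded value, so checkpoint conversion and engine build stay consistent. The following self-contained sketch uses hypothetical stand-in names (ModelConfig, optimize_model_stub) to illustrate that pattern under the assumption that the config carries the flags set during conversion.

from dataclasses import dataclass

@dataclass
class ModelConfig:
    use_parallel_embedding: bool = False
    share_embedding_table: bool = False

def optimize_model_stub(model, use_parallel_embedding, share_embedding_table):
    # Stand-in for the real optimize_model call; it just records the flags it was given.
    model["parallel_embedding"] = use_parallel_embedding
    model["shared_embedding"] = share_embedding_table
    return model

config = ModelConfig(use_parallel_embedding=True)  # set during conversion, as in nemo_utils.py above
model = optimize_model_stub(
    {},
    use_parallel_embedding=config.use_parallel_embedding,  # no longer hard-coded at build time
    share_embedding_table=config.share_embedding_table,
)
assert model["parallel_embedding"] is True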