Commit

move key rename
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 9, 2024
1 parent 0ea57dd commit 9f1a00f
Showing 3 changed files with 26 additions and 17 deletions.
2 changes: 2 additions & 0 deletions nemo/export/tensorrt_llm.py
@@ -258,6 +258,7 @@ def build(
 
         self.use_refit = use_refit
         self.tokenizer = build_tokenizer(tokenizer)
+        self.trt_model_type = trt_model_type
 
         pp_size = parallel_state.get_pipeline_model_parallel_world_size()
         tp_size = parallel_state.get_tensor_model_parallel_world_size()
@@ -362,6 +363,7 @@ def refit(
         weights = convert_nemo_model(
             nemo_model=nemo_model,
             nemo_model_config=nemo_model_config,
+            trt_model_type=self.trt_model_type,
             tokenizer_vocab_size=self.tokenizer.vocab_size,
             reshard_model=self.reshard_model,
         )
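The change above is plumbing only: build() now remembers which TensorRT-LLM model class the engine was built for, and refit() forwards that value to convert_nemo_model(). Below is a minimal sketch of the pattern, not the actual exporter API in nemo/export/tensorrt_llm.py; the class name, the stub converter, and the SimpleNamespace tokenizer are made up, and only trt_model_type and the keyword names in the refit() call come from the diff.

from types import SimpleNamespace


def convert_stub(nemo_model, nemo_model_config, trt_model_type, tokenizer_vocab_size, reshard_model):
    # Stand-in for convert_nemo_model: just show what it now receives.
    print(f"convert for {trt_model_type}, vocab={tokenizer_vocab_size}")
    return {}


class ExporterSketch:
    def build(self, tokenizer, trt_model_type, use_refit=False):
        self.use_refit = use_refit
        self.tokenizer = tokenizer
        # New in this commit: keep the TRT-LLM model class name for later refits.
        self.trt_model_type = trt_model_type

    def refit(self, nemo_model, nemo_model_config):
        return convert_stub(
            nemo_model=nemo_model,
            nemo_model_config=nemo_model_config,
            trt_model_type=self.trt_model_type,       # forwarded, mirroring the diff
            tokenizer_vocab_size=self.tokenizer.vocab_size,
            reshard_model=False,
        )


exporter = ExporterSketch()
exporter.build(tokenizer=SimpleNamespace(vocab_size=50), trt_model_type="GPTForCausalLM", use_refit=True)
exporter.refit(nemo_model=None, nemo_model_config={})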
29 changes: 22 additions & 7 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -328,7 +328,14 @@ def get_layer_num(param_name):
     return int(split_key[layer_index])
 
 @torch.no_grad()
-def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, reshard_model=False, cpu=True):
+def convert_nemo_model(
+    nemo_model,
+    nemo_model_config,
+    tokenizer_vocab_size,
+    trt_model_type,
+    reshard_model=False,
+    cpu=True
+):
     from megatron.core import parallel_state
     from megatron.core.tensor_parallel.utils import VocabUtility
 
@@ -501,15 +508,13 @@ def save_pp_weight(
 
     # ----------------Convert Embeddings----------------
     def remove_vocab_padding(tensor):
-        vocab_size_per_tp = tensor.shape[0]
-        vocab_size_padded = vocab_size_per_tp*tp_size
+        vocab_size_padded = tensor.shape[0]*tp_size
         vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
-            vocab_size_padded, tp_rank, tp_size)
-
+            vocab_size_padded, tp_rank, tp_size)
         dim_size = list(tensor.size())
         dim_size[0] = vocab_size_padded
 
-        gathered_tensor = torch.zeros(dim_size, dtype=tensor.dtype).cuda()
+        gathered_tensor = torch.zeros(dim_size, dtype=tensor.dtype, device=torch.cuda.current_device())
         gathered_tensor[vocab_start_index:vocab_end_index] = tensor
         torch.distributed.all_reduce(gathered_tensor, group=tp_group)
         return gathered_tensor[:tokenizer_vocab_size]
@@ -545,7 +550,17 @@ def remove_vocab_padding(tensor):
             save_weight_torch(**starmap_arg)
     toc = time.time()
     print(f" weight save took {toc-tic}")
-    return weights_dict
+
+    renamed_weight_dict = {}
+    if trt_model_type == 'GPTForCausalLM':
+        for key, val in weights_dict.items():
+            if 'layernorm' in key:
+                new_key = key.replace("pre_mlp_layernorm", "post_layernorm")
+            else:
+                new_key = key
+            renamed_weight_dict[new_key] = val
+    return renamed_weight_dict
+
 
 
 def create_out_dir(args):
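For readers unfamiliar with the vocab-padding trick in remove_vocab_padding(): each tensor-parallel rank holds one contiguous, padded slice of the embedding table, scatters it into a zero tensor of the full padded size, and an all_reduce sums the slices back into the complete table, which is then trimmed to the tokenizer vocab size. The single-process sketch below illustrates that flow with made-up sizes; the loop and chunking stand in for the real distributed collective, and the start/end arithmetic assumes the equal contiguous ranges that VocabUtility.vocab_range_from_global_vocab_size computes.

import torch

tp_size = 4
vocab_size_per_tp = 16            # padded shard size held by each TP rank
hidden = 8
tokenizer_vocab_size = 50         # real vocab; 64 - 50 = 14 rows are padding
vocab_size_padded = vocab_size_per_tp * tp_size

full_table = torch.randn(vocab_size_padded, hidden)
shards = full_table.chunk(tp_size, dim=0)          # what each TP rank would hold

gathered = torch.zeros(vocab_size_padded, hidden)
for tp_rank, shard in enumerate(shards):
    # Contiguous per-rank range, as the Megatron utility would return it.
    start = tp_rank * vocab_size_per_tp
    end = start + vocab_size_per_tp
    local = torch.zeros(vocab_size_padded, hidden)
    local[start:end] = shard
    gathered += local                              # stands in for torch.distributed.all_reduce

unpadded = gathered[:tokenizer_vocab_size]         # drop the padding rows
assert torch.equal(unpadded, full_table[:tokenizer_vocab_size])
print(unpadded.shape)                              # torch.Size([50, 8])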
12 changes: 2 additions & 10 deletions nemo/export/trt_llm/nemo_utils.py
@@ -266,17 +266,9 @@ def nemo_llm_model_to_model_config(
         nemo_model=nemo_model,
         nemo_model_config=nemo_model_config,
         tokenizer_vocab_size=tokenizer.vocab_size,
+        trt_model_type=trt_model_type,
         reshard_model=reshard_model)
 
-    renamed_weight_dict = {}
-    if trt_model_type == 'GPTForCausalLM':
-        for key, val in weights_dict.items():
-            if 'layernorm' in key:
-                new_key = key.replace("pre_mlp_layernorm", "post_layernorm")
-            else:
-                new_key = key
-            renamed_weight_dict[new_key] = val
-
     activation = None
     if nemo_model_config['activation'] == 'fast-swiglu':
         activation = 'silu'
@@ -326,4 +318,4 @@ def nemo_llm_model_to_model_config(
         bias=False
     )
     model_config.mapping = mapping
-    return model_config, renamed_weight_dict
+    return model_config, weights_dict
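This is the other half of the move: nemo_llm_model_to_model_config() no longer renames layernorm keys itself and simply returns the weights_dict produced by convert_nemo_model(), which now applies the rename when trt_model_type is 'GPTForCausalLM'. For reference, the effect of that rename on a toy weights dict (the keys and integer values below are invented for the example; real values are tensors):

weights_dict = {
    "transformer.layers.0.pre_mlp_layernorm.weight": 1,
    "transformer.layers.0.pre_mlp_layernorm.bias": 2,
    "transformer.layers.0.attention.qkv.weight": 3,
}

renamed = {}
for key, val in weights_dict.items():
    # Same logic as the moved block: only layernorm keys are touched.
    new_key = key.replace("pre_mlp_layernorm", "post_layernorm") if "layernorm" in key else key
    renamed[new_key] = val

print(sorted(renamed))
# ['transformer.layers.0.attention.qkv.weight',
#  'transformer.layers.0.post_layernorm.bias',
#  'transformer.layers.0.post_layernorm.weight']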
