nemotron support
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 7, 2024
1 parent 48e9169 commit dbe18d0
Showing 3 changed files with 26 additions and 9 deletions.
@@ -343,7 +343,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
module.cuda(torch.cuda.current_device())
else:
self.model.cuda(torch.cuda.current_device())

self._wrap_model_for_O2()

self.enable_autocast = (
nemo/export/tensorrt_llm.py (1 change: 1 addition & 0 deletions)
@@ -309,6 +309,7 @@ def build(
nemo_model_config=nemo_model_config,
reshard_model=self.reshard_model,
mapping=mapping,
trt_model_type=trt_model_type,
)

print_mem("pre build_and_save_engine")
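For orientation, here is a minimal sketch (not part of this commit) of how a caller inside build() might choose and thread the new trt_model_type argument through to the conversion function updated in the next file (presumably nemo_llm_model_to_model_config). The helper pick_trt_model_type and its keyword heuristics are assumptions for illustration; only the parameter name and the two architecture strings ('GPTForCausalLM', 'LlamaForCausalLM') come from the diff.

# Hypothetical helper: decide which TRT-LLM architecture string to export as.
# The keyword matching below is an assumption; only the two architecture
# names themselves appear in this commit.
def pick_trt_model_type(nemo_model_config: dict) -> str:
    name = str(nemo_model_config.get("name", "")).lower()
    if "nemotron" in name or "gpt" in name:
        return "GPTForCausalLM"
    return "LlamaForCausalLM"

# Sketch of the updated call site, mirroring the keyword arguments shown in the diff:
#
#   model_config, weights_dict = nemo_llm_model_to_model_config(
#       nemo_model=nemo_model,
#       tokenizer=tokenizer,
#       nemo_model_config=nemo_model_config,
#       reshard_model=self.reshard_model,
#       mapping=mapping,
#       trt_model_type=pick_trt_model_type(nemo_model_config),
#   )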
nemo/export/trt_llm/nemo_utils.py (33 changes: 25 additions & 8 deletions)
@@ -257,7 +257,8 @@ def nemo_llm_model_to_model_config(
tokenizer,
nemo_model_config,
reshard_model,
mapping
mapping,
trt_model_type,
) -> Tuple[PretrainedConfig, dict]:
"""Converts the NEMO model object and construct the `ModelConfig` before tensorrt_llm deployment."""
from megatron.core import parallel_state
@@ -270,14 +271,29 @@
tokenizer_vocab_size=tokenizer.vocab_size,
reshard_model=reshard_model)

renamed_weight_dict = {}
if trt_model_type == 'GPTForCausalLM':
for key, val in weights_dict.items():
if 'layernorm' in key:
new_key = key.replace("pre_mlp_layernorm", "post_layernorm")
else:
new_key = key
renamed_weight_dict[new_key] = val

activation = None
if nemo_model_config['activation'] == 'fast-swiglu':
activation = 'silu'
else:
activation = nemo_model_config['activation']

if isinstance(nemo_model, list):
torch_dtype = next(iter(nemo_model[0].state_dict().values())).dtype
else:
torch_dtype = next(iter(nemo_model.state_dict().values())).dtype

str_dtype = trt_dtype_to_str(np_dtype_to_trt(torch_dtype_to_np(torch_dtype)))
model_config = PretrainedConfig(
architecture='LlamaForCausalLM',
architecture=trt_model_type,
dtype=str_dtype,
logits_dtype='float32',
vocab_size=tokenizer.vocab_size,
@@ -286,7 +302,7 @@
num_hidden_layers=nemo_model_config.get('num_layers'),
num_attention_heads=nemo_model_config.get('num_attention_heads'),
num_key_value_heads=nemo_model_config.get('num_query_groups'),
hidden_act='silu',
hidden_act=activation,
intermediate_size=nemo_model_config.get('ffn_hidden_size'),
norm_epsilon=nemo_model_config.get('layernorm_epsilon'),
position_embedding_type="rope_gpt_neox",
@@ -301,15 +317,16 @@
'pre_quant_scale': False,
'exclude_modules': None},
kv_dtype=str_dtype,
rotary_scaling=None,
moe_normalization_mode=None,
rotary_base=10000.0,
rotary_pct=nemo_model_config.get('rotary_percentage', 1.0),
rotary_base=nemo_model_config.get('rotary_base', 10000),
moe_num_experts=0,
moe_top_k=0,
moe_tp_mode=2,
attn_bias=False,
disable_weight_only_quant_plugin=False,
mlp_bias=False
attn_bias=False,
mlp_bias=False,
bias=False
)
model_config.mapping = mapping
return model_config, weights_dict
return model_config, renamed_weight_dict

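As a standalone illustration of the two conversions introduced above, the sketch below reimplements the layernorm key renaming applied for GPTForCausalLM and the 'fast-swiglu' to 'silu' activation mapping. The renaming and activation rules mirror the diff; the example weight keys are hypothetical and not taken from a real checkpoint.

def rename_layernorm_keys(weights_dict: dict, trt_model_type: str) -> dict:
    # For GPTForCausalLM (Nemotron-style) exports, NeMo's 'pre_mlp_layernorm'
    # weights are renamed to the 'post_layernorm' keys expected downstream.
    if trt_model_type != "GPTForCausalLM":
        return dict(weights_dict)
    renamed = {}
    for key, val in weights_dict.items():
        new_key = key.replace("pre_mlp_layernorm", "post_layernorm") if "layernorm" in key else key
        renamed[new_key] = val
    return renamed

def resolve_activation(nemo_activation: str) -> str:
    # NeMo reports 'fast-swiglu' for SwiGLU MLPs; the exported config uses 'silu'.
    return "silu" if nemo_activation == "fast-swiglu" else nemo_activation

# Illustrative keys only (hypothetical):
example = {
    "transformer.layers.0.pre_mlp_layernorm.weight": 1.0,
    "transformer.layers.0.attention.dense.weight": 0.5,
}
assert "transformer.layers.0.post_layernorm.weight" in rename_layernorm_keys(example, "GPTForCausalLM")
assert resolve_activation("fast-swiglu") == "silu"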