Commit 4dc69e3
rebase fixes
Signed-off-by: root <root@eos0500.eos.clusters.nvidia.com>
root committed May 16, 2024
1 parent c827d64 commit 4dc69e3
Showing 6 changed files with 191 additions and 30 deletions.
1 change: 0 additions & 1 deletion nemo/export/__init__.py
@@ -19,7 +19,6 @@


use_TensorRTLLM = True
try:
from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
LOGGER.warning("TensorRTLLM could not be imported.")
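For context, the guarded import above is the optional-dependency pattern used in this module: the package stays importable when TensorRT-LLM is absent. A minimal, self-contained sketch; the logger setup and the flag reset in the except branch are assumptions for illustration, not part of this diff:

```python
import logging

LOGGER = logging.getLogger("nemo.export")

use_TensorRTLLM = True
try:
    # TensorRT-LLM is an optional, heavy dependency at import time.
    from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
    # Keep nemo.export importable; callers can check the flag before use.
    LOGGER.warning("TensorRTLLM could not be imported: %s", e)
    use_TensorRTLLM = False  # assumption: flag reflects the import failure
```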
9 changes: 5 additions & 4 deletions nemo/export/tensorrt_llm.py
@@ -269,6 +269,7 @@ def build(
self,
nemo_model,
nemo_model_config,
trt_model_type,
tokenizer,
max_input_len: int = 1024,
max_input_tokens: int = 4096,
@@ -352,11 +353,11 @@ def build(

myrank = torch.distributed.get_rank()
cfg_path = Path(os.path.join(self.model_dir, f'config_{myrank}.json'))
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)

print(f"engine saved to {self.model_dir}")
print(self.model_dir, f'config_{myrank}.json')
if not cfg_path.exists():
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)


print_mem("post build_and_save_engine")

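The hunk above writes one engine config per rank and skips the write when the file already exists. A small sketch of that save pattern with a stand-in for the engine object; only `engine.config.to_dict()` and the `config_{rank}.json` naming come from the diff, everything else is illustrative:

```python
import json
import os
from pathlib import Path


def save_rank_config(model_dir: str, rank: int, config_dict: dict) -> Path:
    """Write config_<rank>.json once; repeated calls (retries, rebuilds) are no-ops."""
    cfg_path = Path(os.path.join(model_dir, f"config_{rank}.json"))
    if not cfg_path.exists():
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)
    return cfg_path


# Hypothetical usage, assuming torch.distributed is initialized and
# `engine` is a built TensorRT-LLM engine:
#   rank = torch.distributed.get_rank()
#   save_rank_config(self.model_dir, rank, engine.config.to_dict())
```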
6 changes: 4 additions & 2 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -32,6 +32,7 @@
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Tokenizer, LlamaConfig

from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.export.tarutils import TarPath, ZarrPathStore
from nemo.export.trt_llm.nemo.convert import save_weight_torch, split_and_save_weight
from nemo.export.trt_llm.nemo.nemo import UnpackedNemoCheckpointDir, extract_layers_with_prefix, nemo_to_llm_config
@@ -356,7 +357,8 @@ def convert_nemo_model(
state_dict = nemo_model[0].state_dict()
else:
state_dict = nemo_model.state_dict()
storage_type = next(iter(state_dict.values())).dtype

storage_type = torch_dtype_from_precision(nemo_model_config.precision)
prefix, transformer_layer_prefix = get_layer_prefix(state_dict, is_mcore)

if num_kv_heads == 0:
@@ -369,7 +371,7 @@
export_config = {
"apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p",
"tp_size": tp_size,
"split_gated_activation": "swiglu" in nemo_model_config.get("activation", "gelu"),
"split_gated_activation": "glu" in nemo_model_config.get("activation", "gelu"),
"num_attention_heads": nemo_model_config["num_attention_heads"],
"num_kv_heads": num_kv_heads,
"transpose_weights": True,
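The `split_gated_activation` line above now keys off the substring "glu" instead of "swiglu", so other gated-linear-unit variants are split as well. A quick sketch of what that predicate matches; the activation names are common NeMo config values, listed only as examples:

```python
def needs_split_gated_activation(activation: str) -> bool:
    # Gated-linear-unit variants (swiglu, fast-swiglu, geglu, reglu, ...)
    # all contain "glu"; plain gelu/relu do not.
    return "glu" in activation


for act in ["gelu", "swiglu", "fast-swiglu", "geglu", "reglu", "squared-relu"]:
    print(act, needs_split_gated_activation(act))
# gelu False, swiglu True, fast-swiglu True, geglu True, reglu True, squared-relu False
```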
171 changes: 162 additions & 9 deletions nemo/export/trt_llm/nemo_utils.py
@@ -47,6 +47,8 @@
LinearConfig,
ModelConfig,
)

from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.export.trt_llm.nemo.nemo import UnpackedNemoCheckpointDir
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer, convert_dist_checkpoint, convert_nemo_model
from nemo.export.trt_llm.tensor_utils import get_tensor_from_dict, get_tensor_parallel_group, split
@@ -268,9 +270,6 @@ def nemo_llm_model_to_model_config(
trt_model_type,
) -> Tuple[PretrainedConfig, dict]:
"""Converts the NEMO model object and construct the `ModelConfig` before tensorrt_llm deployment."""
from megatron.core import parallel_state
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from tensorrt_llm import Mapping

weights_dict = convert_nemo_model(
nemo_model=nemo_model,
@@ -285,11 +284,7 @@
else:
activation = nemo_model_config['activation']

if isinstance(nemo_model, list):
torch_dtype = next(iter(nemo_model[0].state_dict().values())).dtype
else:
torch_dtype = next(iter(nemo_model.state_dict().values())).dtype

torch_dtype = torch_dtype_from_precision(nemo_model_config.precision)
str_dtype = trt_dtype_to_str(np_dtype_to_trt(torch_dtype_to_np(torch_dtype)))
model_config = PretrainedConfig(
architecture=trt_model_type,
@@ -325,7 +320,165 @@ def nemo_llm_model_to_model_config(
disable_weight_only_quant_plugin=False,
attn_bias=False,
mlp_bias=False,
bias=False
bias=False,
gpus_per_node=8,
)
model_config.mapping = mapping
return model_config, weights_dict

def nemo_to_trtllm_config(
in_file: str,
decoder_type: str,
nemo_export_dir: Union[str, Path],
dtype: str = "bfloat16",
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
use_parallel_embedding: bool = False,
save_nemo_model_config: bool = False,
) -> Tuple[List[Dict], List[PretrainedConfig], PreTrainedTokenizer]:
"""Converts the NEMO file and construct the `PretrainedConfig` before tensorrt_llm deployment."""
dtype_str = dtype

weights_dict, nemo_model_config, tokenizer = _nemo_llm_decode(
in_file=in_file,
out_dir=nemo_export_dir,
tensor_parallelism=tensor_parallel_size,
processes=1,
storage_type=dtype_str,
use_parallel_embedding=use_parallel_embedding,
load_checkpoints_on_gpu=False,
decoder_type=decoder_type,
save_nemo_model_config=save_nemo_model_config,
)

world_size = tensor_parallel_size * pipeline_parallel_size

lm_head_weight = weights_dict["lm_head.weight"]

vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0]
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size)

if vocab_size_padded != vocab_size:
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0)

hidden_act = nemo_model_config.get('activation')
hidden_act = (
hidden_act.split("-")[-1] if nemo_model_config.get('num_moe_experts', 0) else non_gated_version(hidden_act)
)

config = {
'architecture': DECODER_MODEL_TYPE[decoder_type],
'dtype': dtype_str,
'num_hidden_layers': nemo_model_config.get('num_layers'),
'num_attention_heads': nemo_model_config.get('num_attention_heads'),
'num_key_value_heads': nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']),
'head_size': nemo_model_config.get('kv_channels'),
'hidden_size': nemo_model_config.get('hidden_size'),
'intermediate_size': nemo_model_config.get('ffn_hidden_size'),
'norm_epsilon': nemo_model_config.get('layernorm_epsilon'),
'vocab_size': vocab_size_padded,
'position_embedding_type': (
"rope_gpt_neox" if nemo_model_config.get('position_embedding_type') == "rope" else "learned_absolute"
),
'max_position_embeddings': nemo_model_config.get('max_position_embeddings'),
'hidden_act': hidden_act,
'use_parallel_embedding': use_parallel_embedding,
'embedding_sharding_dim': 0,
'share_embedding_table': False,
'quantization': {
'quant_algo': None,
'kv_cache_quant_algo': None,
},
'bias': nemo_model_config.get('bias'),
'apply_query_key_layer_scaling': False,
'rotary_pct': nemo_model_config.get('rotary_percentage', 1.0),
'rotary_base': nemo_model_config.get('rotary_base', 10000),
'moe_num_experts': nemo_model_config.get('num_moe_experts', 0),
'moe_top_k': nemo_model_config.get('moe_router_topk'),
'moe_normalization_mode': nemo_model_config.get(
'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
),
'moe_tp_mode': nemo_model_config.get('moe_tp_mode', MoeConfig.ParallelismMode.TENSOR_PARALLEL),
'logits_dtype': 'float32',
'world_size': world_size,
'tp_size': tensor_parallel_size,
'pp_size': pipeline_parallel_size,
}

model_configs = []
weights_dicts = []
num_layers = nemo_model_config.get('num_layers')
rotary_scaling = nemo_model_config.get("seq_len_interpolation_factor")

if decoder_type == "falcon":
config["new_decoder_architecture"] = False if num_layers == 32 else True
config["parallel_attention"] = True
if rotary_scaling is not None:
config["rotary_scaling"] = {"type": "linear", "factor": float(rotary_scaling)}

pp_key = {
"transformer.vocab_embedding.weight",
"transformer.position_embedding.weight",
"lm_head.weight",
"transformer.ln_f.weight",
"transformer.ln_f.bias",
}

for i in range(world_size):
mapping = tensorrt_llm.Mapping(
world_size=world_size, rank=i, tp_size=tensor_parallel_size, pp_size=pipeline_parallel_size
)
layers_range = mapping.pp_layers(num_layers)

weights_dict_local = {}
for k, v in weights_dict.items():
if k in pp_key:
continue
new_key = k
if new_key.endswith(".bin"): # TP split
if new_key.endswith(f"{mapping.tp_rank}.bin"):
new_key = new_key.replace(f".{mapping.tp_rank}.bin", "")
if "layers" in new_key: # PP
layer_num = int(new_key.split(".")[2])
if layer_num in layers_range:
new_key = new_key.replace(f"layers.{layer_num}", f"layers.{layer_num-layers_range[0]}")
if config.get("new_decoder_architecture", False) and "post_layernorm" in new_key:
new_key = new_key.replace("post_layernorm", "mlp_layernorm")
weights_dict_local[new_key] = v

if mapping.is_first_pp_rank():
embedding_weight = (
np.ascontiguousarray(
split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank)
)
if use_parallel_embedding
else weights_dict["transformer.vocab_embedding.weight"]
)

weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight

pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight")
if pos_embedding_weight is not None:
if use_parallel_embedding:
pos_embedding_weight = np.ascontiguousarray(
split(pos_embedding_weight, mapping.tp_size, mapping.tp_rank)
)
weights_dict_local["transformer.position_embedding.weight"] = pos_embedding_weight

if mapping.is_last_pp_rank():
weights_dict_local["lm_head.weight"] = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank)
)
weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"]

ln_f_bias = weights_dict.get("transformer.ln_f.bias")
if ln_f_bias is not None:
weights_dict_local["transformer.ln_f.bias"] = ln_f_bias

model_config = PretrainedConfig(**config)
model_config.mapping = mapping
model_configs.append(model_config)
weights_dicts.append(weights_dict_local)

return weights_dicts, model_configs, tokenizer
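Two steps in `nemo_to_trtllm_config` are easy to check in isolation: padding the vocabulary (and `lm_head.weight`) to a multiple of the tensor-parallel size, and stripping the `.{tp_rank}.bin` suffix when selecting a rank's TP shard. A numpy sketch under those assumptions; `pad_vocab_size` below is a local stand-in for the tensorrt_llm helper used above, and the weight key is hypothetical:

```python
import numpy as np


def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Round up to the next multiple of tp_size so every TP rank gets equal rows.
    return ((vocab_size + tp_size - 1) // tp_size) * tp_size


vocab_size, tp_size, hidden = 50257, 4, 8
lm_head_weight = np.zeros((vocab_size, hidden), dtype=np.float32)

vocab_size_padded = pad_vocab_size(vocab_size, tp_size)  # 50260
if vocab_size_padded != vocab_size:
    pad_width = vocab_size_padded - vocab_size
    lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0)
print(lm_head_weight.shape)  # (50260, 8)

# Selecting rank 1's shard of a weight stored under a ".<tp_rank>.bin" key,
# shaped like the keys handled in the loop above:
tp_rank = 1
key = "transformer.layers.0.attention.qkv.weight.1.bin"
if key.endswith(f"{tp_rank}.bin"):
    key = key.replace(f".{tp_rank}.bin", "")
print(key)  # transformer.layers.0.attention.qkv.weight
```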
20 changes: 10 additions & 10 deletions nemo/export/trt_llm/tensorrt_llm_build.py
@@ -28,7 +28,7 @@
from tensorrt_llm.builder import BuildConfig, Builder
from tensorrt_llm.commands.build import build as build_trtllm
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraBuildConfig
# from tensorrt_llm.lora_manager import LoraBuildConfig
from tensorrt_llm.models.modeling_utils import add_lora, optimize_model, preprocess_weights
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin import PluginConfig
@@ -400,15 +400,15 @@ def build_and_save_engine(
}
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)

if use_lora_plugin is not None:
build_config.plugin_config.set_lora_plugin(use_lora_plugin)
lora_config = LoraBuildConfig(
lora_dir=lora_ckpt_list,
lora_ckpt_source='nemo',
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
)
build_config.lora_config = lora_config
# if use_lora_plugin is not None:
# build_config.plugin_config.set_lora_plugin(use_lora_plugin)
# lora_config = LoraBuildConfig(
# lora_dir=lora_ckpt_list,
# lora_ckpt_source='nemo',
# max_lora_rank=max_lora_rank,
# lora_target_modules=lora_target_modules,
# )
# build_config.lora_config = lora_config

model = model_cls.from_config(model_config)
# use_parallel_embedding=True,
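For reference, the commented-out block above corresponds to the pattern below, which can be kept optional across TensorRT-LLM versions with a guarded import. A sketch under the assumption that `LoraBuildConfig` may or may not exist in the installed tensorrt_llm; only calls already shown in this hunk are used:

```python
try:
    from tensorrt_llm.lora_manager import LoraBuildConfig  # availability varies by tensorrt_llm version
except ImportError:
    LoraBuildConfig = None


def maybe_enable_lora(build_config, use_lora_plugin, lora_ckpt_list, max_lora_rank, lora_target_modules):
    """Attach a LoRA build config only when requested and supported by this tensorrt_llm."""
    if use_lora_plugin is None or LoraBuildConfig is None:
        return build_config
    build_config.plugin_config.set_lora_plugin(use_lora_plugin)
    build_config.lora_config = LoraBuildConfig(
        lora_dir=lora_ckpt_list,
        lora_ckpt_source='nemo',
        max_lora_rank=max_lora_rank,
        lora_target_modules=lora_target_modules,
    )
    return build_config
```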
14 changes: 10 additions & 4 deletions nemo/export/trt_llm/tensorrt_llm_run.py
@@ -26,9 +26,15 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, ModelRunnerCpp, SamplingConfig
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession
from transformers import PreTrainedTokenizer

from tensorrt_llm.bindings import (DataType, GenerationInput, GenerationOutput,
GptJsonConfig, GptSession, GptSessionConfig,
KvCacheConfig, PromptTuningParams, WorldConfig)
from tensorrt_llm.bindings import SamplingConfig as GptSamplingConfig

from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group
from nemo.export.trt_llm.tensorrt_llm_model import LMHeadModelBuilder

@@ -55,7 +61,7 @@ class TensorrtLLMHostContext:
class TensorrtLLMWorkerContext:
"""The MPI worker side context for TRT LLM inference."""

decoder: ModelRunnerCpp = None
decoder: ModelRunnerCppGptSession = None
sampling_config: SamplingConfig = None
lora_manager: LoraManager = None
max_batch_size: int = 0
@@ -147,7 +153,7 @@ def _load(tokenizer: PreTrainedTokenizer, engine_dir, lora_ckpt_list=None, num_b

runtime_rank = tensorrt_llm.mpi_rank()

decoder = ModelRunnerCpp.from_dir(
decoder = ModelRunnerCppGptSession.from_dir(
engine_dir=engine_dir,
lora_dir=lora_ckpt_list,
lora_ckpt_source="nemo",
@@ -367,7 +373,7 @@ def load_refit(engine_dir):
)
session = create_gpt_session(session_params, engine_data)

model_runner = ModelRunnerCpp(session,
model_runner = ModelRunnerCppGptSession(session,
lora_manager=None,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
