From 4dc69e3ea3281c7467987aa68b5002e58cf72ebe Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 May 2024 16:23:58 -0700 Subject: [PATCH] rebase fixes Signed-off-by: root --- nemo/export/__init__.py | 1 - nemo/export/tensorrt_llm.py | 9 +- nemo/export/trt_llm/nemo/nemo_ckpt_convert.py | 6 +- nemo/export/trt_llm/nemo_utils.py | 171 +++++++++++++++++- nemo/export/trt_llm/tensorrt_llm_build.py | 20 +- nemo/export/trt_llm/tensorrt_llm_run.py | 14 +- 6 files changed, 191 insertions(+), 30 deletions(-) diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index 799c64331ab5..5bf092cc2d4c 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -19,7 +19,6 @@ use_TensorRTLLM = True -try: from nemo.export.tensorrt_llm import TensorRTLLM except Exception as e: LOGGER.warning("TensorRTLLM could not be imported.") diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 3223c4a90646..50d07ac18317 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -269,6 +269,7 @@ def build( self, nemo_model, nemo_model_config, + trt_model_type, tokenizer, max_input_len: int = 1024, max_input_tokens: int = 4096, @@ -352,11 +353,11 @@ def build( myrank = torch.distributed.get_rank() cfg_path = Path(os.path.join(self.model_dir, f'config_{myrank}.json')) + with open(cfg_path, "w", encoding="utf-8") as f: + json.dump(engine.config.to_dict(), f, indent=4) + print(f"engine saved to {self.model_dir}") - print(self.model_dir, f'config_{myrank}.json') - if not cfg_path.exists(): - with open(cfg_path, "w", encoding="utf-8") as f: - json.dump(engine.config.to_dict(), f, indent=4) + print_mem("post build_and_save_engine") diff --git a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py index 59d9112061bc..3a0f6cbbaa0f 100644 --- a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py +++ b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py @@ -32,6 +32,7 @@ from tqdm import tqdm from transformers import AutoTokenizer, GPT2Tokenizer, LlamaConfig +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.export.tarutils import TarPath, ZarrPathStore from nemo.export.trt_llm.nemo.convert import save_weight_torch, split_and_save_weight from nemo.export.trt_llm.nemo.nemo import UnpackedNemoCheckpointDir, extract_layers_with_prefix, nemo_to_llm_config @@ -356,7 +357,8 @@ def convert_nemo_model( state_dict = nemo_model[0].state_dict() else: state_dict = nemo_model.state_dict() - storage_type = next(iter(state_dict.values())).dtype + + storage_type = torch_dtype_from_precision(nemo_model_config.precision) prefix, transformer_layer_prefix = get_layer_prefix(state_dict, is_mcore) if num_kv_heads == 0: @@ -369,7 +371,7 @@ def convert_nemo_model( export_config = { "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", "tp_size": tp_size, - "split_gated_activation": "swiglu" in nemo_model_config.get("activation", "gelu"), + "split_gated_activation": "glu" in nemo_model_config.get("activation", "gelu"), "num_attention_heads": nemo_model_config["num_attention_heads"], "num_kv_heads": num_kv_heads, "transpose_weights": True, diff --git a/nemo/export/trt_llm/nemo_utils.py b/nemo/export/trt_llm/nemo_utils.py index 0d36e5b823a5..d758851d68ae 100644 --- a/nemo/export/trt_llm/nemo_utils.py +++ b/nemo/export/trt_llm/nemo_utils.py @@ -47,6 +47,8 @@ LinearConfig, ModelConfig, ) + +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.export.trt_llm.nemo.nemo import UnpackedNemoCheckpointDir from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer, convert_dist_checkpoint, convert_nemo_model from nemo.export.trt_llm.tensor_utils import get_tensor_from_dict, get_tensor_parallel_group, split @@ -268,9 +270,6 @@ def nemo_llm_model_to_model_config( trt_model_type, ) -> Tuple[PretrainedConfig, dict]: """Converts the NEMO model object and construct the `ModelConfig` before tensorrt_llm deployment.""" - from megatron.core import parallel_state - from tensorrt_llm.models.modeling_utils import PretrainedConfig - from tensorrt_llm import Mapping weights_dict = convert_nemo_model( nemo_model=nemo_model, @@ -285,11 +284,7 @@ def nemo_llm_model_to_model_config( else: activation = nemo_model_config['activation'] - if isinstance(nemo_model, list): - torch_dtype = next(iter(nemo_model[0].state_dict().values())).dtype - else: - torch_dtype = next(iter(nemo_model.state_dict().values())).dtype - + torch_dtype = torch_dtype_from_precision(nemo_model_config.precision) str_dtype = trt_dtype_to_str(np_dtype_to_trt(torch_dtype_to_np(torch_dtype))) model_config = PretrainedConfig( architecture=trt_model_type, @@ -325,7 +320,165 @@ def nemo_llm_model_to_model_config( disable_weight_only_quant_plugin=False, attn_bias=False, mlp_bias=False, - bias=False + bias=False, + gpus_per_node=8, ) model_config.mapping = mapping return model_config, weights_dict + +def nemo_to_trtllm_config( + in_file: str, + decoder_type: str, + nemo_export_dir: Union[str, Path], + dtype: str = "bfloat16", + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + use_parallel_embedding: bool = False, + save_nemo_model_config: bool = False, +) -> Tuple[List[Dict], List[PretrainedConfig], PreTrainedTokenizer]: + """Converts the NEMO file and construct the `PretrainedConfig` before tensorrt_llm deployment.""" + dtype_str = dtype + + weights_dict, nemo_model_config, tokenizer = _nemo_llm_decode( + in_file=in_file, + out_dir=nemo_export_dir, + tensor_parallelism=tensor_parallel_size, + processes=1, + storage_type=dtype_str, + use_parallel_embedding=use_parallel_embedding, + load_checkpoints_on_gpu=False, + decoder_type=decoder_type, + save_nemo_model_config=save_nemo_model_config, + ) + + world_size = tensor_parallel_size * pipeline_parallel_size + + lm_head_weight = weights_dict["lm_head.weight"] + + vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) + + if vocab_size_padded != vocab_size: + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) + + hidden_act = nemo_model_config.get('activation') + hidden_act = ( + hidden_act.split("-")[-1] if nemo_model_config.get('num_moe_experts', 0) else non_gated_version(hidden_act) + ) + + config = { + 'architecture': DECODER_MODEL_TYPE[decoder_type], + 'dtype': dtype_str, + 'num_hidden_layers': nemo_model_config.get('num_layers'), + 'num_attention_heads': nemo_model_config.get('num_attention_heads'), + 'num_key_value_heads': nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + 'head_size': nemo_model_config.get('kv_channels'), + 'hidden_size': nemo_model_config.get('hidden_size'), + 'intermediate_size': nemo_model_config.get('ffn_hidden_size'), + 'norm_epsilon': nemo_model_config.get('layernorm_epsilon'), + 'vocab_size': vocab_size_padded, + 'position_embedding_type': ( + "rope_gpt_neox" if nemo_model_config.get('position_embedding_type') == "rope" else "learned_absolute" + ), + 'max_position_embeddings': nemo_model_config.get('max_position_embeddings'), + 'hidden_act': hidden_act, + 'use_parallel_embedding': use_parallel_embedding, + 'embedding_sharding_dim': 0, + 'share_embedding_table': False, + 'quantization': { + 'quant_algo': None, + 'kv_cache_quant_algo': None, + }, + 'bias': nemo_model_config.get('bias'), + 'apply_query_key_layer_scaling': False, + 'rotary_pct': nemo_model_config.get('rotary_percentage', 1.0), + 'rotary_base': nemo_model_config.get('rotary_base', 10000), + 'moe_num_experts': nemo_model_config.get('num_moe_experts', 0), + 'moe_top_k': nemo_model_config.get('moe_router_topk'), + 'moe_normalization_mode': nemo_model_config.get( + 'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + ), + 'moe_tp_mode': nemo_model_config.get('moe_tp_mode', MoeConfig.ParallelismMode.TENSOR_PARALLEL), + 'logits_dtype': 'float32', + 'world_size': world_size, + 'tp_size': tensor_parallel_size, + 'pp_size': pipeline_parallel_size, + } + + model_configs = [] + weights_dicts = [] + num_layers = nemo_model_config.get('num_layers') + rotary_scaling = nemo_model_config.get("seq_len_interpolation_factor") + + if decoder_type == "falcon": + config["new_decoder_architecture"] = False if num_layers == 32 else True + config["parallel_attention"] = True + if rotary_scaling is not None: + config["rotary_scaling"] = {"type": "linear", "factor": float(rotary_scaling)} + + pp_key = { + "transformer.vocab_embedding.weight", + "transformer.position_embedding.weight", + "lm_head.weight", + "transformer.ln_f.weight", + "transformer.ln_f.bias", + } + + for i in range(world_size): + mapping = tensorrt_llm.Mapping( + world_size=world_size, rank=i, tp_size=tensor_parallel_size, pp_size=pipeline_parallel_size + ) + layers_range = mapping.pp_layers(num_layers) + + weights_dict_local = {} + for k, v in weights_dict.items(): + if k in pp_key: + continue + new_key = k + if new_key.endswith(".bin"): # TP split + if new_key.endswith(f"{mapping.tp_rank}.bin"): + new_key = new_key.replace(f".{mapping.tp_rank}.bin", "") + if "layers" in new_key: # PP + layer_num = int(new_key.split(".")[2]) + if layer_num in layers_range: + new_key = new_key.replace(f"layers.{layer_num}", f"layers.{layer_num-layers_range[0]}") + if config.get("new_decoder_architecture", False) and "post_layernorm" in new_key: + new_key = new_key.replace("post_layernorm", "mlp_layernorm") + weights_dict_local[new_key] = v + + if mapping.is_first_pp_rank(): + embedding_weight = ( + np.ascontiguousarray( + split(weights_dict["transformer.vocab_embedding.weight"], mapping.tp_size, mapping.tp_rank) + ) + if use_parallel_embedding + else weights_dict["transformer.vocab_embedding.weight"] + ) + + weights_dict_local["transformer.vocab_embedding.weight"] = embedding_weight + + pos_embedding_weight = weights_dict.get("transformer.position_embedding.weight") + if pos_embedding_weight is not None: + if use_parallel_embedding: + pos_embedding_weight = np.ascontiguousarray( + split(pos_embedding_weight, mapping.tp_size, mapping.tp_rank) + ) + weights_dict_local["transformer.position_embedding.weight"] = pos_embedding_weight + + if mapping.is_last_pp_rank(): + weights_dict_local["lm_head.weight"] = np.ascontiguousarray( + split(lm_head_weight, mapping.tp_size, mapping.tp_rank) + ) + weights_dict_local["transformer.ln_f.weight"] = weights_dict["transformer.ln_f.weight"] + + ln_f_bias = weights_dict.get("transformer.ln_f.bias") + if ln_f_bias is not None: + weights_dict_local["transformer.ln_f.bias"] = ln_f_bias + + model_config = PretrainedConfig(**config) + model_config.mapping = mapping + model_configs.append(model_config) + weights_dicts.append(weights_dict_local) + + return weights_dicts, model_configs, tokenizer \ No newline at end of file diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 17eb13d3a999..290045537d08 100644 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -28,7 +28,7 @@ from tensorrt_llm.builder import BuildConfig, Builder from tensorrt_llm.commands.build import build as build_trtllm from tensorrt_llm.logger import logger -from tensorrt_llm.lora_manager import LoraBuildConfig +# from tensorrt_llm.lora_manager import LoraBuildConfig from tensorrt_llm.models.modeling_utils import add_lora, optimize_model, preprocess_weights from tensorrt_llm.network import net_guard from tensorrt_llm.plugin import PluginConfig @@ -400,15 +400,15 @@ def build_and_save_engine( } build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config) - if use_lora_plugin is not None: - build_config.plugin_config.set_lora_plugin(use_lora_plugin) - lora_config = LoraBuildConfig( - lora_dir=lora_ckpt_list, - lora_ckpt_source='nemo', - max_lora_rank=max_lora_rank, - lora_target_modules=lora_target_modules, - ) - build_config.lora_config = lora_config + # if use_lora_plugin is not None: + # build_config.plugin_config.set_lora_plugin(use_lora_plugin) + # lora_config = LoraBuildConfig( + # lora_dir=lora_ckpt_list, + # lora_ckpt_source='nemo', + # max_lora_rank=max_lora_rank, + # lora_target_modules=lora_target_modules, + # ) + # build_config.lora_config = lora_config model = model_cls.from_config(model_config) # use_parallel_embedding=True, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 6d9cce99bdbb..f1766046c1ff 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -26,9 +26,15 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import ModelConfig, ModelRunnerCpp, SamplingConfig +from tensorrt_llm.runtime import ModelConfig, SamplingConfig +from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession from transformers import PreTrainedTokenizer +from tensorrt_llm.bindings import (DataType, GenerationInput, GenerationOutput, + GptJsonConfig, GptSession, GptSessionConfig, + KvCacheConfig, PromptTuningParams, WorldConfig) +from tensorrt_llm.bindings import SamplingConfig as GptSamplingConfig + from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group from nemo.export.trt_llm.tensorrt_llm_model import LMHeadModelBuilder @@ -55,7 +61,7 @@ class TensorrtLLMHostContext: class TensorrtLLMWorkerContext: """The MPI worker side context for TRT LLM inference.""" - decoder: ModelRunnerCpp = None + decoder: ModelRunnerCppGptSession = None sampling_config: SamplingConfig = None lora_manager: LoraManager = None max_batch_size: int = 0 @@ -147,7 +153,7 @@ def _load(tokenizer: PreTrainedTokenizer, engine_dir, lora_ckpt_list=None, num_b runtime_rank = tensorrt_llm.mpi_rank() - decoder = ModelRunnerCpp.from_dir( + decoder = ModelRunnerCppGptSession.from_dir( engine_dir=engine_dir, lora_dir=lora_ckpt_list, lora_ckpt_source="nemo", @@ -367,7 +373,7 @@ def load_refit(engine_dir): ) session = create_gpt_session(session_params, engine_data) - model_runner = ModelRunnerCpp(session, + model_runner = ModelRunnerCppGptSession(session, lora_manager=None, max_batch_size=max_batch_size, max_input_len=max_input_len,