From 0142ee74b12d8af89bb03ae6e8801129e6b6b99a Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Fri, 6 Sep 2024 00:36:46 +0000
Subject: [PATCH] [feat] Upgrade nemo-export path for aligner to TRTLLM-v12 and
 use python runtime

Signed-off-by: Terry Kong
---
 nemo/export/tensorrt_llm.py               |   7 +-
 nemo/export/trt_llm/converter/utils.py    |  10 +-
 nemo/export/trt_llm/tensorrt_llm_build.py |   2 +
 nemo/export/trt_llm/tensorrt_llm_run.py   | 124 +++++++++++++++-------
 4 files changed, 100 insertions(+), 43 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 2a89b76cc099..6573a6c896f8 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -45,7 +45,7 @@
 from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
 from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
 from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
-from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit
+from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit, unload_engine
 
 use_deploy = True
 try:
@@ -490,12 +490,12 @@ def build(
         engine = build_and_save_engine(
             max_input_len=max_input_len,
             max_output_len=max_output_len,
+            max_seq_len=max_input_len + max_output_len,
             max_batch_size=max_batch_size,
             model_config=model_config[0],
             model_weights=weights[0],
             model_dir=self.model_dir,
             model_type=model_type,
-            custom_all_reduce=False,
             use_refit=use_refit,
         )
         torch.distributed.barrier()
@@ -968,3 +968,6 @@ def _load(self):
                 "model needs to be exported again. "
                 "Error message: " + repr(error)
             ) from error
+
+    def unload_engine(self):
+        unload_engine()
diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py
index eab17167cbd5..57633f7d925c 100755
--- a/nemo/export/trt_llm/converter/utils.py
+++ b/nemo/export/trt_llm/converter/utils.py
@@ -16,7 +16,7 @@
 import numpy as np
 import tensorrt_llm
 import torch
-from tensorrt_llm._utils import torch_to_numpy
+from tensorrt_llm._utils import torch_to_numpy, mpi_comm
 
 # A global dicts to store exported weights.
 # This is set to be a global variable to avoid extra code modification from tensorrt_llm.
@@ -492,6 +492,12 @@ def init_model_parallel_from_nemo(reshard_model):
         pp_size = 1
 
     mp_rank = tp_size * pp_rank + tp_rank
 
+    # Need to split cpp MPI World Comm because TensorRT-LLM NCCL plugins refer to the locally split comm.
+    # High level call structure is: MpiComm::split -> MpiComm::setSession -> LOCAL_COMM_SESSION (used in allReducePlugin.cpp)
+    tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
+    # Also split the python MPI communicator and set the global world one to the local split one
+    new_comm = mpi_comm().Split(color=dp_rank, key=mp_rank)
+    from mpi4py import MPI
+    MPI.COMM_WORLD = new_comm
-    return mp_rank, dp_rank, tp_size, pp_size, dp_size
+    return mp_rank, dp_rank, tp_size, pp_size, dp_size
\ No newline at end of file
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
index e37c3ba1c845..aaf51b957638 100755
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -53,6 +53,7 @@ def build_and_save_engine(
     multiple_profiles: bool = False,
     gpt_attention_plugin: str = "auto",
     gemm_plugin: str = "auto",
+    reduce_fusion: bool = False,
 ):
     architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture
     try:
@@ -71,6 +72,7 @@ def build_and_save_engine(
     plugin_config.remove_input_padding = remove_input_padding
     plugin_config.use_paged_context_fmha = paged_context_fmha
     plugin_config.multiple_profiles = multiple_profiles
+    plugin_config.reduce_fusion = reduce_fusion
 
     max_num_tokens, opt_num_tokens = check_max_num_tokens(
         max_num_tokens=max_num_tokens,
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 852eddc6a468..2ef8bb803748 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -28,7 +28,11 @@
 from mpi4py.futures import MPIPoolExecutor
 from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
+from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig, GenerationSession
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.builder import Engine
+from tensorrt_llm._utils import mpi_comm
+import tensorrt as trt
 
 from transformers import PreTrainedTokenizer
@@ -36,16 +40,10 @@
 use_trtllm_bindings = True
 try:
-    from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig
+    from tensorrt_llm.bindings import GptJsonConfig, KvCacheConfig, WorldConfig
 except Exception as e:
     use_trtllm_bindings = False
 
-use_cpp_gpt_session = True
-try:
-    from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession
-except Exception as e:
-    use_cpp_gpt_session = False
-
 
 @dataclass
 class TensorrtLLMHostContext:
@@ -63,7 +61,7 @@ class TensorrtLLMWorkerContext:
     """The MPI worker side context for TRT LLM inference."""
 
-    decoder: ModelRunner = None
+    decoder: ModelRunner | ModelRunnerCpp = None
     sampling_config: SamplingConfig = None
     lora_manager: LoraManager = None
     max_batch_size: int = 0
@@ -123,7 +121,6 @@ def _read_config(config_path: Path):
         lora_plugin=config["plugin_config"]["lora_plugin"],
         lora_target_modules=config["builder_config"]["lora_target_modules"],
         quant_mode=quant_mode,
-        use_custom_all_reduce=config["plugin_config"]["use_custom_all_reduce"],
         use_context_fmha_for_generation=config["plugin_config"]["use_context_fmha_for_generation"],
         gather_context_logits=config["builder_config"]["gather_context_logits"],
         gather_generation_logits=config["builder_config"]["gather_generation_logits"],
@@ -456,7 +453,7 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
     this function creates a custom mapping of device_id to WorldConfig
     """
     global tensorrt_llm_worker_context
-    if isinstance(tensorrt_llm_worker_context.decoder, ModelRunnerCppGptSession):
+    if isinstance(tensorrt_llm_worker_context.decoder, ModelRunner):
         return
 
     config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json"
@@ -480,46 +477,95 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
     device_ids = [i for i in range(gpus_per_node)]
     for _ in range(offset):
         device_ids.append(device_ids.pop(0))
-    world_config = WorldConfig.mpi(
-        gpus_per_node=gpus_per_node, tensor_parallelism=tp_size, pipeline_parallelism=pp_size, device_ids=device_ids
-    )
-    engine_filename = json_config.engine_filename(world_config)
+    engine_index = model_parallel_rank
+    mpi_rank = mpi_comm().Get_rank()
+    # TODO: copied from worldConfig.h (getDevice())
+    mpi_device = mpi_rank % gpus_per_node
+    # TODO: check if an API exists for this (copied from gptJsonConfig.cpp)
+    # https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478
+    engine_filename = f"rank{engine_index}.engine"
     serialize_path = Path(engine_dir) / engine_filename
-    assert torch.cuda.current_device() == world_config.device
-
-    session_config = GptSessionConfig(
-        max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_sequence_length=max_seq_len
-    )
-    session_config.gen_micro_batch_size = max_batch_size
-    session_config.ctx_micro_batch_size = max_batch_size
-    session_config.kv_cache_config = KvCacheConfig(
-        max_tokens=max_seq_len * max_batch_size, max_attention_window=max_seq_len
-    )
-
+    # assert torch.cuda.current_device() == mpi_device
     with open(serialize_path, "rb") as f:
         engine_data = bytearray(f.read())
-    session = GptSession(session_config, model_config, world_config, engine_data)
-    decoder = ModelRunnerCppGptSession(
-        session,
-        lora_manager=None,
-        max_batch_size=max_batch_size,
-        max_input_len=max_input_len,
-        max_seq_len=max_seq_len,
-        max_beam_width=max_beam_width,
+    with open(config_path) as f:
+        json_config_str = f.read()
+
+    engine = Engine.from_buffer(
+        engine_buffer=engine_data,
+        json_config_str=json_config_str,
+        rank=model_parallel_rank)
+    decoder = ModelRunner.from_engine(
+        engine=engine,
+        # rank=world_config.rank,
+        # We want the engine to keep the mp_rank, but the python runtime should not reassign the device
+        # of the current process, so pass the current device instead.
+        rank=torch.cuda.current_device(),
     )
 
     tensorrt_llm_worker_context.decoder = decoder
     tensorrt_llm_worker_context.max_batch_size = max_batch_size
     tensorrt_llm_worker_context.max_input_len = max_input_len
-    # Save the model config in case for refit
-    tensorrt_llm_worker_context.model_config = model_config
 
 
-def refit(weights_dict):
+def maybe_cast_to_trt_dtype(dtype):
+    if isinstance(dtype, trt.DataType):
+        return dtype
+    elif isinstance(dtype, torch.dtype):
+        return tensorrt_llm._utils.torch_dtype_to_trt(dtype)
+    else:
+        raise NotImplementedError(f"Expects the type to be a tensorrt.DataType or torch.dtype, but got {type(dtype)=}")
+
+
+def refit(weights_dict: dict):
     global tensorrt_llm_worker_context
-    dtype = tensorrt_llm_worker_context.model_config.data_type
-    tensorrt_llm_worker_context.decoder.session.refit_engine(weights_dict, dtype)
+    decoder = tensorrt_llm_worker_context.decoder
+    if not isinstance(decoder, ModelRunner):
+        raise ValueError(f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
+
+    engine = decoder.session.runtime.engine
+    # The session dtype is plumbed through from the model_config's dtype
+    model_dtype = maybe_cast_to_trt_dtype(decoder.session.dtype)
+    assert engine.refittable, "Tried refitting engine without refit enabled"
+
+    refitter = trt.Refitter(engine=engine, logger=trt.Logger(trt.Logger.ERROR))
+    remaining_refit_weights = set(refitter.get_all_weights())
+    skipped_weights = []
+    for trt_name, weight in weights_dict.items():
+        if trt_name not in remaining_refit_weights:
+            skipped_weights.append(trt_name)
+            continue
+        trt_weight = trt.Weights(model_dtype, weight.data_ptr(), torch.numel(weight))
+        trt_wt_location = (
+            trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
+        )
+        assert model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"
+
+        assert refitter.set_named_weights(trt_name, trt_weight, trt_wt_location), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
+        remaining_refit_weights.remove(trt_name)
+    if skipped_weights:
+        logging.warning(f"These weights were ignored during refit since they are not present in engine: {skipped_weights}")
+    if remaining_refit_weights:
+        logging.warning(f"Weights dict did not contain weights for these named TRT weights: {remaining_refit_weights}")
+
+    if not refitter.refit_cuda_engine():
+        raise ValueError("Refit failed!")
+
+
+def unload_engine():
+    """
+    Deletes the ModelRunner, which should free up device memory.
+    """
+    global tensorrt_llm_worker_context
+    decoder = tensorrt_llm_worker_context.decoder
+    if not isinstance(decoder, ModelRunner):
+        raise ValueError(f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
+
+    logging.info("Unloading engine...")
+    del tensorrt_llm_worker_context.decoder
+    tensorrt_llm_worker_context.decoder = None
+    logging.info("Engine unloaded!")
 
 
 def prepare_input_tensors(
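
Note for reviewers (not part of the diff): a minimal sketch of how the new python-runtime path is expected to be driven from each model-parallel worker, using only the functions touched by this patch. The engine directory and the contents of weights_dict below are illustrative placeholders, not values defined by the patch.

    import torch
    from nemo.export.trt_llm.tensorrt_llm_run import load_distributed, refit, unload_engine

    # Load the per-rank engine (rank{N}.engine) with the python ModelRunner
    # instead of the removed GptSession path. Assumes torch.distributed and the
    # MPI communicators (see init_model_parallel_from_nemo) are already set up.
    engine_dir = "/path/to/trt_llm_engine"  # placeholder path
    load_distributed(engine_dir, model_parallel_rank=0, gpus_per_node=torch.cuda.device_count())

    # After a training step, push updated weights into the engine in place.
    # Keys must match the engine's named TRT weights; values are torch tensors on host or device.
    weights_dict = {}  # placeholder: populate with {trt_weight_name: torch.Tensor}
    refit(weights_dict)

    # ... run generation through the existing generate() path ...

    # Release device memory held by the runner once generation is finished.
    unload_engine()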