diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 6573a6c896f8..9a4e05553d63 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -45,7 +45,14 @@
 from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
 from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
 from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
-from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit, unload_engine
+from nemo.export.trt_llm.tensorrt_llm_run import (
+    generate,
+    generate_streaming,
+    load,
+    load_distributed,
+    refit,
+    unload_engine,
+)
 
 use_deploy = True
 try:
@@ -490,7 +497,7 @@ def build(
         engine = build_and_save_engine(
             max_input_len=max_input_len,
             max_output_len=max_output_len,
-            max_seq_len=max_input_len+max_output_len,
+            max_seq_len=max_input_len + max_output_len,
             max_batch_size=max_batch_size,
             model_config=model_config[0],
             model_weights=weights[0],
@@ -968,6 +975,6 @@ def _load(self):
                 "model needs to be exported again. "
                 "Error message: " + repr(error)
             ) from error
-        
+
     def unload_engine(self):
         unload_engine()
diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py
index 57633f7d925c..34c3cf0127c5 100755
--- a/nemo/export/trt_llm/converter/utils.py
+++ b/nemo/export/trt_llm/converter/utils.py
@@ -16,7 +16,7 @@
 import numpy as np
 import tensorrt_llm
 import torch
-from tensorrt_llm._utils import torch_to_numpy, mpi_comm
+from tensorrt_llm._utils import mpi_comm, torch_to_numpy
 
 # A global dicts to store exported weights.
 # This is set to be a global variable to avoid extra code modification from tensorrt_llm.
@@ -498,6 +498,7 @@ def init_model_parallel_from_nemo(reshard_model):
     # Also split the python mpi communicator and set the global world one to the local split one
     new_comm = mpi_comm().Split(color=dp_rank, key=mp_rank)
     from mpi4py import MPI
+
     MPI.COMM_WORLD = new_comm
 
-    return mp_rank, dp_rank, tp_size, pp_size, dp_size
\ No newline at end of file
+    return mp_rank, dp_rank, tp_size, pp_size, dp_size
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 6adf83954f01..3b1dcff72938 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -23,17 +23,16 @@
 from typing import List, Optional
 
 import numpy as np
+import tensorrt as trt
 import tensorrt_llm
 import torch
 from mpi4py.futures import MPIPoolExecutor
+from tensorrt_llm._utils import mpi_comm
+from tensorrt_llm.builder import Engine
 from tensorrt_llm.lora_manager import LoraManager
-from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig, GenerationSession
 from tensorrt_llm.mapping import Mapping
-from tensorrt_llm.builder import Engine
-from tensorrt_llm._utils import mpi_comm
-import tensorrt as trt
-
+from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
 from transformers import PreTrainedTokenizer
 
 LOGGER = logging.getLogger("NeMo")
@@ -485,20 +484,17 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
     # https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478
     engine_filename = f"rank{engine_index}.engine"
     serialize_path = Path(engine_dir) / engine_filename
-    #$#$#$assert torch.cuda.current_device() == mpi_device
+    # $#$#$assert torch.cuda.current_device() == mpi_device
 
     with open(serialize_path, "rb") as f:
         engine_data = bytearray(f.read())
     with open(config_path) as f:
         json_config_str = f.read()
 
-    engine = Engine.from_buffer(
-        engine_buffer=engine_data,
-        json_config_str=json_config_str,
-        rank=model_parallel_rank)
+    engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)
     decoder = ModelRunner.from_engine(
         engine=engine,
-        #rank=world_config.rank,
+        # rank=world_config.rank,
         # We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process
         # So we will set it to the current
         rank=torch.cuda.current_device(),
@@ -523,7 +519,9 @@ def refit(weights_dict: dict):
     global tensorrt_llm_worker_context
     decoder = tensorrt_llm_worker_context.decoder
     if not isinstance(decoder, ModelRunner):
-        raise ValueError(f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
+        raise ValueError(
+            f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
+        )
 
     engine = decoder.session.runtime.engine
     # The session dtype plumbs the model_config's dtype
@@ -538,15 +536,19 @@ def refit(weights_dict: dict):
             skipped_weights.append(trt_name)
             continue
         trt_weight = trt.Weights(model_dtype, weight.data_ptr(), torch.numel(weight))
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
-        )
-        assert model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"
-
-        refitter.set_named_weights(trt_name, trt_weight, trt_wt_location), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
+        trt_wt_location = trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
+        assert (
+            model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype)
+        ), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"
+
+        refitter.set_named_weights(
+            trt_name, trt_weight, trt_wt_location
+        ), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
         remaining_refit_weights.remove(trt_name)
 
     if skipped_weights:
-        logging.warning(f"These weights were ignored during refit since they are not present in engine: {skipped_weights}")
+        logging.warning(
+            f"These weights were ignored during refit since they are not present in engine: {skipped_weights}"
+        )
     if remaining_refit_weights:
         logging.warning(f"Weights dict did not contain weights for these named TRT weights: {remaining_refit_weights}")
 
@@ -561,7 +563,9 @@ def unload_engine():
     global tensorrt_llm_worker_context
     decoder = tensorrt_llm_worker_context.decoder
     if not isinstance(decoder, ModelRunner):
-        raise ValueError(f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
+        raise ValueError(
+            f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
+        )
 
     logging.info("Unloading engine...")
     del tensorrt_llm_worker_context.decoder