diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index d987eca2a7a8..11d5f30956ab 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -241,19 +241,18 @@ def build(
         self,
         nemo_model,
         nemo_model_config,
+        trt_model_type,
         tokenizer,
         max_input_len: int = 256,
         max_output_len: int = 256,
         max_batch_size: int = 8,
         gpus_per_node: int = 8,
-        use_refit: bool = False,
+        use_refit: bool = True,
         reshard_model: bool = False,
     ):
         from megatron.core import parallel_state

         assert tensorrt_llm.mpi_rank() == torch.distributed.get_rank()
-        gpus_per_node = 8
-
         self.use_refit = use_refit

         self.tokenizer = build_tokenizer(tokenizer)
@@ -312,13 +311,15 @@ def build(
         )

         print_mem("pre build_and_save_engine")
-        self.engine = build_and_save_engine(
+        build_and_save_engine(
             max_input_len=max_input_len,
             max_output_len=max_output_len,
             max_batch_size=max_batch_size,
             model_config=model_config,
             model_weights=weights,
             model_dir=self.model_dir,
+            use_refit=self.use_refit,
+            trt_model_type=trt_model_type,
         )
         torch.distributed.barrier()
         print_mem("post build_and_save_engine")
@@ -352,7 +353,7 @@ def refit(
         nemo_model_config,
     ):
         assert self.use_refit, "TRT-LLM model must be built() with refit=True"
-        assert self.engine, "TRT-LLM model must be loaded with build() prior to refitting"
+        assert self.model_runner, "TRT-LLM model must be loaded with build() prior to refitting"

         from .trt_llm.nemo.nemo_ckpt_convert import convert_nemo_model
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
index 90760b4f1c46..70c185394438 100644
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -14,6 +14,7 @@

 import argparse
+from importlib.machinery import SourceFileLoader
 import logging
 import os
 import time
@@ -24,13 +25,19 @@
 import tensorrt_llm
 import torch
 from tensorrt_llm import str_dtype_to_trt
-from tensorrt_llm._utils import np_dtype_to_trt
-from tensorrt_llm.builder import Builder
+from tensorrt_llm._utils import np_dtype_to_trt, torch_dtype_to_np, trt_dtype_to_str
+from tensorrt_llm.builder import Builder, BuildConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import add_lora
 from tensorrt_llm.network import net_guard
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.commands.build import build_model, build as build_trtllm
+from tensorrt_llm.plugin import PluginConfig
+from tensorrt_llm.models.llama.model import LLaMAForCausalLM
+from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
+
+
 MODEL_NAME = "NeMo"
@@ -356,15 +363,15 @@ def build_and_save_engine(
     model_dir=None,
     model_weights=None,
     model_config=None,
+    use_refit=True,
+    trt_model_type='LLaMAForCausalLM',
 ):
     '''Minimum implementation of TRTLLM 0.9's unified builder api'''

-    from tensorrt_llm.commands.build import build_model, build as build_trtllm
-    from tensorrt_llm.plugin import PluginConfig
-    from tensorrt_llm.builder import BuildConfig
-    from tensorrt_llm._utils import torch_dtype_to_np, np_dtype_to_trt, trt_dtype_to_str
-    from tensorrt_llm.models.llama.model import LLaMAForCausalLM
-    from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
+    try:
+        model_cls = getattr(tensorrt_llm.models, trt_model_type)
+    except AttributeError as err:
+        raise AttributeError(f"Could not find TRTLLM model type: {trt_model_type}!") from err

     str_dtype = model_config.dtype
     plugin_config = PluginConfig()
@@ -384,11 +391,11 @@ def build_and_save_engine(
         'gather_generation_logits': False,
         'strongly_typed': False,
         'builder_opt': None,
-        'use_refit': True,
+        'use_refit': use_refit,
     }
     build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
-
-    model = LLaMAForCausalLM.from_config(model_config)
+
+    model = model_cls.from_config(model_config)
     model = optimize_model(
         model,
         use_parallel_embedding=model_config.use_parallel_embedding,
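
For context, below is a minimal sketch of how the reworked build-then-refit flow is intended to be driven. The exporter class name TensorRTLLM, its model_dir argument, and the pre-loaded nemo_model / nemo_model_config / tokenizer objects are assumptions for illustration; the build() and refit() signatures come from this diff.

    # Sketch only: assumes TensorRTLLM is the exporter class defined in
    # nemo/export/tensorrt_llm.py and that nemo_model, nemo_model_config,
    # and tokenizer were already loaded by the caller.
    from nemo.export.tensorrt_llm import TensorRTLLM

    exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")  # hypothetical path

    # Build once with refit enabled. trt_model_type is resolved via
    # getattr(tensorrt_llm.models, trt_model_type), so any model class
    # exported by the tensorrt_llm.models package is a valid value.
    exporter.build(
        nemo_model=nemo_model,
        nemo_model_config=nemo_model_config,
        trt_model_type="LLaMAForCausalLM",
        tokenizer=tokenizer,
        max_input_len=256,
        max_output_len=256,
        max_batch_size=8,
        use_refit=True,
    )

    # Later (e.g. after a training step), push updated weights into the
    # existing engine without rebuilding; refit() asserts that use_refit
    # was set at build time.
    exporter.refit(
        nemo_model=nemo_model,
        nemo_model_config=nemo_model_config,
    )

Note that build() asserts tensorrt_llm.mpi_rank() == torch.distributed.get_rank(), so the sketch above is expected to run on every rank of an initialized torch.distributed job.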