From 45a6822ab93b3e48d11e8d046e661e614a657ae7 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Thu, 16 May 2024 17:42:43 -0700 Subject: [PATCH] fix refitting Signed-off-by: Jimmy Zhang --- nemo/export/__init__.py | 1 + nemo/export/tensorrt_llm.py | 20 +++++++------------- nemo/export/trt_llm/tensorrt_llm_build.py | 2 ++ 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index 5bf092cc2d4c..799c64331ab5 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -19,6 +19,7 @@ use_TensorRTLLM = True +try: from nemo.export.tensorrt_llm import TensorRTLLM except Exception as e: LOGGER.warning("TensorRTLLM could not be imported.") diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 50d07ac18317..e62c237f5674 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -348,6 +348,7 @@ def build( model_config=model_config, model_weights=weights, model_dir=self.model_dir, + use_refit=use_refit, ) torch.distributed.barrier() @@ -356,28 +357,22 @@ def build( with open(cfg_path, "w", encoding="utf-8") as f: json.dump(engine.config.to_dict(), f, indent=4) - print(f"engine saved to {self.model_dir}") - - print_mem("post build_and_save_engine") - self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir) print_mem("post load_refit") - print(f"device: {origdev} {torch.cuda.current_device()}") + print(f"engine saved to {self.model_dir} device: {origdev} {torch.cuda.current_device()}") def refit( self, nemo_model, nemo_model_config, ): - assert self.use_refit, "TRT-LLM model must be built() with refit=True" - assert self.engine, "TRT-LLM model must be loaded with build() prior to refitting" - from .trt_llm.nemo.nemo_ckpt_convert import convert_nemo_model - + from .trt_llm.tensorrt_llm_run import create_gpt_session + assert self.use_refit, "TRT-LLM model must be built() with refit=True" + print_mem("pre refit") - import time tic = time.time() @@ -389,7 +384,6 @@ def refit( tokenizer_vocab_size=self.tokenizer.vocab_size, reshard_model=self.reshard_model, ) - toc = time.time() print_mem("post nemo_model_to_model_config") print(f" nemo_model_to_model_config took {toc-tic}") @@ -398,13 +392,13 @@ def refit( tic = time.time() self.model_runner.session = create_gpt_session(self.session_params) toc = time.time() - print(f" session load took f{toc-tic}") + print(f" session load took {toc-tic}") tic = time.time() session = self.model_runner.session session.refit_engine(weights, self.session_params.model_config.data_type) toc = time.time() - print(f"refit_runtime_engine took f{toc-tic}") + print(f"refit_runtime_engine took {toc-tic}") print_mem("post refit") diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 290045537d08..434d465ff62b 100644 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -371,6 +371,7 @@ def build_and_save_engine( lora_target_modules=None, max_prompt_embedding_table_size=0, enable_multi_block_mode: bool = False, + use_refit: bool = False, ): try: model_cls = getattr(tensorrt_llm.models, model_config.architecture) @@ -397,6 +398,7 @@ def build_and_save_engine( 'gather_generation_logits': False, 'strongly_typed': False, 'builder_opt': None, + 'use_refit': use_refit, } build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)