diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 50d07ac18317..cdbf331e0ad2 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -356,28 +356,21 @@ def build( with open(cfg_path, "w", encoding="utf-8") as f: json.dump(engine.config.to_dict(), f, indent=4) - print(f"engine saved to {self.model_dir}") - - print_mem("post build_and_save_engine") - self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir) print_mem("post load_refit") - print(f"device: {origdev} {torch.cuda.current_device()}") + print(f"engine saved to {self.model_dir} device: {origdev} {torch.cuda.current_device()}") def refit( self, nemo_model, nemo_model_config, ): - assert self.use_refit, "TRT-LLM model must be built() with refit=True" - assert self.engine, "TRT-LLM model must be loaded with build() prior to refitting" - from .trt_llm.nemo.nemo_ckpt_convert import convert_nemo_model - + assert self.use_refit, "TRT-LLM model must be built() with refit=True" + print_mem("pre refit") - import time tic = time.time() @@ -389,7 +382,6 @@ def refit( tokenizer_vocab_size=self.tokenizer.vocab_size, reshard_model=self.reshard_model, ) - toc = time.time() print_mem("post nemo_model_to_model_config") print(f" nemo_model_to_model_config took {toc-tic}") @@ -398,13 +390,13 @@ def refit( tic = time.time() self.model_runner.session = create_gpt_session(self.session_params) toc = time.time() - print(f" session load took f{toc-tic}") + print(f" session load took {toc-tic}") tic = time.time() session = self.model_runner.session session.refit_engine(weights, self.session_params.model_config.data_type) toc = time.time() - print(f"refit_runtime_engine took f{toc-tic}") + print(f"refit_runtime_engine took {toc-tic}") print_mem("post refit")