Skip to content

Commit

Permalink
fix refitting
Browse files: browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed May 17, 2024
1 parent 4dc69e3 commit 45a6822
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
1 change: 1 addition & 0 deletions nemo/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


use_TensorRTLLM = True
try:
from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
LOGGER.warning("TensorRTLLM could not be imported.")
Expand Down
20 changes: 7 additions & 13 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def build(
model_config=model_config,
model_weights=weights,
model_dir=self.model_dir,
use_refit=use_refit,
)
torch.distributed.barrier()

Expand All @@ -356,28 +357,22 @@ def build(
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)

print(f"engine saved to {self.model_dir}")


print_mem("post build_and_save_engine")

self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir)
print_mem("post load_refit")

print(f"device: {origdev} {torch.cuda.current_device()}")
print(f"engine saved to {self.model_dir} device: {origdev} {torch.cuda.current_device()}")

def refit(
self,
nemo_model,
nemo_model_config,
):
assert self.use_refit, "TRT-LLM model must be built() with refit=True"
assert self.engine, "TRT-LLM model must be loaded with build() prior to refitting"

from .trt_llm.nemo.nemo_ckpt_convert import convert_nemo_model

from .trt_llm.tensorrt_llm_run import create_gpt_session
assert self.use_refit, "TRT-LLM model must be built() with refit=True"

print_mem("pre refit")

import time
tic = time.time()

Expand All @@ -389,7 +384,6 @@ def refit(
tokenizer_vocab_size=self.tokenizer.vocab_size,
reshard_model=self.reshard_model,
)

toc = time.time()
print_mem("post nemo_model_to_model_config")
print(f" nemo_model_to_model_config took {toc-tic}")
Expand All @@ -398,13 +392,13 @@ def refit(
tic = time.time()
self.model_runner.session = create_gpt_session(self.session_params)
toc = time.time()
print(f" session load took f{toc-tic}")
print(f" session load took {toc-tic}")

tic = time.time()
session = self.model_runner.session
session.refit_engine(weights, self.session_params.model_config.data_type)
toc = time.time()
print(f"refit_runtime_engine took f{toc-tic}")
print(f"refit_runtime_engine took {toc-tic}")

print_mem("post refit")

Expand Down
2 changes: 2 additions & 0 deletions nemo/export/trt_llm/tensorrt_llm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def build_and_save_engine(
lora_target_modules=None,
max_prompt_embedding_table_size=0,
enable_multi_block_mode: bool = False,
use_refit: bool = False,
):
try:
model_cls = getattr(tensorrt_llm.models, model_config.architecture)
Expand All @@ -397,6 +398,7 @@ def build_and_save_engine(
'gather_generation_logits': False,
'strongly_typed': False,
'builder_opt': None,
'use_refit': use_refit,
}
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)

Expand Down

0 comments on commit 45a6822

Please sign in to comment.