Skip to content

Commit

Permalink
rebase aligner main
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Apr 16, 2024
1 parent 1790ce9 commit 10d8318
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
11 changes: 6 additions & 5 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,18 @@ def build(
self,
nemo_model,
nemo_model_config,
trt_model_type,
tokenizer,
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 8,
gpus_per_node: int = 8,
use_refit: bool = False,
use_refit: bool = True,
reshard_model: bool = False,
):
from megatron.core import parallel_state
assert tensorrt_llm.mpi_rank() == torch.distributed.get_rank()

gpus_per_node = 8

self.use_refit = use_refit
self.tokenizer = build_tokenizer(tokenizer)

Expand Down Expand Up @@ -312,13 +311,15 @@ def build(
)

print_mem("pre build_and_save_engine")
self.engine = build_and_save_engine(
build_and_save_engine(
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
model_config=model_config,
model_weights=weights,
model_dir=self.model_dir,
use_refit=self.use_refit,
trt_model_type=trt_model_type
)
torch.distributed.barrier()
print_mem("post build_and_save_engine")
Expand Down Expand Up @@ -352,7 +353,7 @@ def refit(
nemo_model_config,
):
assert self.use_refit, "TRT-LLM model must be built() with refit=True"
assert self.engine, "TRT-LLM model must be loaded with build() prior to refitting"
assert self.model_runner, "TRT-LLM model must be loaded with build() prior to refitting"

from .trt_llm.nemo.nemo_ckpt_convert import convert_nemo_model

Expand Down
29 changes: 18 additions & 11 deletions nemo/export/trt_llm/tensorrt_llm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import argparse
from importlib.machinery import SourceFileLoader
import logging
import os
import time
Expand All @@ -24,13 +25,19 @@
import tensorrt_llm
import torch
from tensorrt_llm import str_dtype_to_trt
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm._utils import np_dtype_to_trt, torch_dtype_to_np, np_dtype_to_trt, trt_dtype_to_str
from tensorrt_llm.builder import Builder, BuildConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import add_lora
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.commands.build import build_model, build as build_trtllm
from tensorrt_llm.plugin import PluginConfig
from tensorrt_llm.models.llama.model import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights



MODEL_NAME = "NeMo"

Expand Down Expand Up @@ -356,15 +363,15 @@ def build_and_save_engine(
model_dir=None,
model_weights=None,
model_config=None,
use_refit=True,
trt_model_type='LLaMAForCausalLM'
):
'''Minimum implementation of TRTLLM 0.9's unified builder api'''

from tensorrt_llm.commands.build import build_model, build as build_trtllm
from tensorrt_llm.plugin import PluginConfig
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm._utils import torch_dtype_to_np, np_dtype_to_trt, trt_dtype_to_str
from tensorrt_llm.models.llama.model import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
try:
model_cls = getattr(tensorrt_llm.models, trt_model_type)
except:
raise AttributeError(f"Could not find TRTLLM model type: {trt_model_type}!")

str_dtype = model_config.dtype
plugin_config = PluginConfig()
Expand All @@ -384,11 +391,11 @@ def build_and_save_engine(
'gather_generation_logits': False,
'strongly_typed': False,
'builder_opt': None,
'use_refit': True,
'use_refit': use_refit,
}
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)

model = LLaMAForCausalLM.from_config(model_config)
model = model_cls.from_config(model_config)
model = optimize_model(
model,
use_parallel_embedding=model_config.use_parallel_embedding,
Expand Down

0 comments on commit 10d8318

Please sign in to comment.