Apply isort and black reformatting
Signed-off-by: terrykong <terrykong@users.noreply.github.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
terrykong committed Oct 2, 2024
1 parent 148543d commit f3550aa
Showing 3 changed files with 38 additions and 26 deletions.
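
For context, a pass like this can usually be reproduced locally. A minimal sketch, assuming isort and black are installed and pick up the repository's own configuration (tool versions and settings are not recorded in this commit):

# Minimal sketch: rerun the formatters over the three files touched here.
# Assumes isort and black are on PATH and read the repo's own config;
# exact versions/settings are an assumption, not part of this commit.
import subprocess

files = [
    "nemo/export/tensorrt_llm.py",
    "nemo/export/trt_llm/converter/utils.py",
    "nemo/export/trt_llm/tensorrt_llm_run.py",
]
subprocess.run(["isort", *files], check=True)  # sort and group imports
subprocess.run(["black", *files], check=True)  # apply black's wrapping and spacing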
13 changes: 10 additions & 3 deletions nemo/export/tensorrt_llm.py
@@ -45,7 +45,14 @@
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit, unload_engine
from nemo.export.trt_llm.tensorrt_llm_run import (
generate,
generate_streaming,
load,
load_distributed,
refit,
unload_engine,
)

use_deploy = True
try:
@@ -490,7 +497,7 @@ def build(
engine = build_and_save_engine(
max_input_len=max_input_len,
max_output_len=max_output_len,
max_seq_len=max_input_len+max_output_len,
max_seq_len=max_input_len + max_output_len,
max_batch_size=max_batch_size,
model_config=model_config[0],
model_weights=weights[0],
@@ -968,6 +975,6 @@ def _load(self):
"model needs to be exported again. "
"Error message: " + repr(error)
) from error

def unload_engine(self):
unload_engine()
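
The new unload_engine method simply delegates to the module-level function of the same name. A hypothetical usage sketch, assuming a TensorRTLLM exporter with an engine already loaded (the instance name and path are illustrative):

# Hypothetical usage of the new wrapper method; `exporter` and the
# model_dir path are illustrative, not from this commit.
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")
# ... build or load an engine, run inference ...
exporter.unload_engine()  # delegates to tensorrt_llm_run.unload_engine()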
5 changes: 3 additions & 2 deletions nemo/export/trt_llm/converter/utils.py
@@ -16,7 +16,7 @@
import numpy as np
import tensorrt_llm
import torch
from tensorrt_llm._utils import torch_to_numpy, mpi_comm
from tensorrt_llm._utils import mpi_comm, torch_to_numpy

# Global dicts to store exported weights.
# This is set to be a global variable to avoid extra code modification from tensorrt_llm.
@@ -498,6 +498,7 @@ def init_model_parallel_from_nemo(reshard_model):
# Also split the python mpi communicator and set the global world one to the local split one
new_comm = mpi_comm().Split(color=dp_rank, key=mp_rank)
from mpi4py import MPI

MPI.COMM_WORLD = new_comm

return mp_rank, dp_rank, tp_size, pp_size, dp_size
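
The Split call above carves MPI.COMM_WORLD into one sub-communicator per data-parallel group: ranks sharing a color end up together, and key orders them within the group. A standalone mpi4py sketch of the same pattern (the sizes and the rank-to-group mapping here are illustrative):

# Standalone sketch of mpi_comm().Split(color=dp_rank, key=mp_rank).
# Illustrative sizes; run under MPI, e.g. `mpirun -np 4 python demo.py`.
from mpi4py import MPI

world = MPI.COMM_WORLD
rank = world.Get_rank()

mp_size = 2                # illustrative model-parallel size
dp_rank = rank // mp_size  # which data-parallel replica this rank belongs to
mp_rank = rank % mp_size   # position inside that replica

sub = world.Split(color=dp_rank, key=mp_rank)
print(f"world rank {rank}: dp_rank={dp_rank}, local rank {sub.Get_rank()} of {sub.Get_size()}")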
46 changes: 25 additions & 21 deletions nemo/export/trt_llm/tensorrt_llm_run.py
@@ -23,17 +23,16 @@
from typing import List, Optional

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
from mpi4py.futures import MPIPoolExecutor
from tensorrt_llm._utils import mpi_comm
from tensorrt_llm.builder import Engine
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig, GenerationSession
from tensorrt_llm.mapping import Mapping

Check notice (Code scanning / CodeQL): Unused import. Import of 'Mapping' is not used.
from tensorrt_llm.builder import Engine
from tensorrt_llm._utils import mpi_comm
import tensorrt as trt

from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig

Check notice (Code scanning / CodeQL): Unused import. Import of 'GenerationSession' is not used.
from transformers import PreTrainedTokenizer

LOGGER = logging.getLogger("NeMo")
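
The two CodeQL notices above flag Mapping and GenerationSession as imported but unused; isort only reorders imports, it does not remove them. A quick local check, assuming pyflakes is installed (the tool choice is an assumption; CodeQL itself runs in CI):

# Hypothetical local check for the unused imports CodeQL flagged above.
# Assumes pyflakes is installed (pip install pyflakes).
import subprocess

result = subprocess.run(
    ["python", "-m", "pyflakes", "nemo/export/trt_llm/tensorrt_llm_run.py"],
    capture_output=True,
    text=True,
)
print(result.stdout)  # e.g. "... 'Mapping' imported but unused"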
@@ -485,20 +484,17 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
# https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478
engine_filename = f"rank{engine_index}.engine"
serialize_path = Path(engine_dir) / engine_filename
#$#$#$assert torch.cuda.current_device() == mpi_device
# $#$#$assert torch.cuda.current_device() == mpi_device
with open(serialize_path, "rb") as f:
engine_data = bytearray(f.read())

with open(config_path) as f:
json_config_str = f.read()

engine = Engine.from_buffer(
engine_buffer=engine_data,
json_config_str=json_config_str,
rank=model_parallel_rank)
engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)
decoder = ModelRunner.from_engine(
engine=engine,
#rank=world_config.rank,
# rank=world_config.rank,
# We want the engine to have the mp_rank, but the python runtime to not reassign the device of the current process
# So we will set it to the current
rank=torch.cuda.current_device(),
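
A condensed sketch of the load path above: read the serialized per-rank engine, rebuild an Engine from the buffer, and wrap it in a ModelRunner pinned to the current CUDA device. The config.json filename and the bare two-argument from_engine call are assumptions; the diff passes further arguments hidden behind the fold:

# Condensed sketch of load_distributed's engine-loading sequence.
# `engine_index` is computed upstream in the real code; config.json
# is an assumed filename.
from pathlib import Path

import torch
from tensorrt_llm.builder import Engine
from tensorrt_llm.runtime import ModelRunner

def load_rank_engine(engine_dir: str, model_parallel_rank: int, engine_index: int) -> ModelRunner:
    engine_data = bytearray((Path(engine_dir) / f"rank{engine_index}.engine").read_bytes())
    json_config_str = (Path(engine_dir) / "config.json").read_text()
    engine = Engine.from_buffer(
        engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank
    )
    # rank is the current device so the runtime does not reassign the
    # device of this process (see the comment in the diff above).
    return ModelRunner.from_engine(engine=engine, rank=torch.cuda.current_device())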
@@ -523,7 +519,9 @@ def refit(weights_dict: dict):
global tensorrt_llm_worker_context
decoder = tensorrt_llm_worker_context.decoder
if not isinstance(decoder, ModelRunner):
raise ValueError(f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
raise ValueError(
f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
)

engine = decoder.session.runtime.engine
# The session dtype plumbs the model_config's dtype
@@ -538,15 +536,19 @@
skipped_weights.append(trt_name)
continue
trt_weight = trt.Weights(model_dtype, weight.data_ptr(), torch.numel(weight))
trt_wt_location = (
trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
)
assert model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"

refitter.set_named_weights(trt_name, trt_weight, trt_wt_location), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
trt_wt_location = trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
assert (
model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype)
), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}"

refitter.set_named_weights(
trt_name, trt_weight, trt_wt_location
), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}"
remaining_refit_weights.remove(trt_name)
if skipped_weights:
logging.warning(f"These weights were ignored during refit since they are not present in engine: {skipped_weights}")
logging.warning(
f"These weights were ignored during refit since they are not present in engine: {skipped_weights}"
)
if remaining_refit_weights:
logging.warning(f"Weights dict did not contain weights for these named TRT weights: {remaining_refit_weights}")
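
For reference, the refit loop above follows the standard TensorRT refitter pattern: wrap the built engine in a Refitter, point each named weight at an in-memory tensor, and finalize. A condensed sketch mirroring the calls in the diff (the Refitter construction and the final refit_cuda_engine call sit outside the visible hunk, so their placement here is an assumption; the engine must have been built refittable):

# Condensed sketch of the refit pattern used above. Assumes `engine` is
# a refittable trt.ICudaEngine and `weights_dict` maps TRT weight names
# to torch tensors already cast to the engine's dtype.
import tensorrt as trt
import torch

def refit_engine(engine, weights_dict):
    refitter = trt.Refitter(engine, trt.Logger(trt.Logger.ERROR))
    for name in refitter.get_all_weights():
        weight = weights_dict.get(name)
        if weight is None:
            continue  # not provided; logged as remaining in the code above
        trt_weight = trt.Weights(
            refitter.get_weights_prototype(name).dtype, weight.data_ptr(), torch.numel(weight)
        )
        location = trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST
        refitter.set_named_weights(name, trt_weight, location)
    assert refitter.refit_cuda_engine(), "refit failed"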

@@ -561,7 +563,9 @@ def unload_engine():
global tensorrt_llm_worker_context
decoder = tensorrt_llm_worker_context.decoder
if not isinstance(decoder, ModelRunner):
raise ValueError(f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}")
raise ValueError(
f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}"
)

logging.info("Unloading engine...")
del tensorrt_llm_worker_context.decoder
