Commit cd3c6e5

add missing code
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed Apr 16, 2024
1 parent 315aad3 commit cd3c6e5
Showing 5 changed files with 56 additions and 7 deletions.
1 change: 1 addition & 0 deletions nemo/export/__init__.py
@@ -23,3 +23,4 @@
from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
LOGGER.warning("TensorRTLLM could not be imported.")
print(e)
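
The only change in this file is printing the captured exception alongside the warning, so import failures are no longer silent. A minimal sketch of how the guarded import reads after this change (the `logging.getLogger` line is a placeholder; the real module defines its own `LOGGER`):

```python
import logging

LOGGER = logging.getLogger("NeMo")  # placeholder; nemo/export/__init__.py has its own LOGGER

try:
    from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
    # Surface the underlying import error instead of hiding it behind the warning.
    LOGGER.warning("TensorRTLLM could not be imported.")
    print(e)
```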
17 changes: 12 additions & 5 deletions nemo/export/tensorrt_llm.py
@@ -30,6 +30,7 @@
from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit
from nemo.export.trt_llm.utils import is_nemo_file, unpack_nemo_ckpt

@@ -39,6 +40,13 @@
except Exception:
use_deploy = False

def print_mem(prefix):
torch.cuda.empty_cache()
pyt = torch.cuda.memory_allocated() / (1024**3)
el = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / (1024**3)
print(f"Mem Usage | {prefix} | {pyt} {el} | {el-pyt}")



@wrapt.decorator
def noop_decorator(func):
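
The new `print_mem` helper reports both the PyTorch allocator's usage and the device-wide usage from `torch.cuda.mem_get_info()` (which returns free and total bytes), so the difference approximates memory held outside PyTorch, e.g. by the TensorRT-LLM engine. A small self-contained usage sketch under that reading; the "pre" label is illustrative, while the "post build_and_save_engine" call mirrors the one later in this diff:

```python
import torch

def print_mem(prefix):
    # Release cached allocator blocks so the numbers reflect live tensors only.
    torch.cuda.empty_cache()
    pyt = torch.cuda.memory_allocated() / (1024**3)            # GiB allocated by PyTorch tensors
    free, total = torch.cuda.mem_get_info()
    el = (total - free) / (1024**3)                             # GiB in use on the whole device
    print(f"Mem Usage | {prefix} | {pyt} {el} | {el - pyt}")    # last term ~ non-PyTorch usage (e.g. TRT engine)

print_mem("pre build_and_save_engine")
# ... build the TensorRT-LLM engine here ...
print_mem("post build_and_save_engine")
```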
@@ -240,11 +248,10 @@ def build(
use_refit: bool = False,
reshard_model: bool = False,
):
from tensorrt_llm.bindings import MpiComm
from megatron.core import parallel_state
assert tensorrt_llm.mpi_rank() == torch.distributed.get_rank()

gpus_per_node = 8
logger.set_level('info')

self.use_refit = use_refit
self.tokenizer = build_tokenizer(tokenizer)
@@ -267,7 +274,7 @@ def build(

model_parallel_size = tp_size*pp_size
model_parallel_rank = tp_size*pp_rank + tp_rank
MpiComm.split(dp_rank, model_parallel_rank)
tensorrt_llm.bindings.MpiComm.split(dp_rank, model_parallel_rank)

# Get the model parallel group using global ids from NeMo
tp_groups = [[j*tp_size+i for i in range(tp_size)] for j in range(pp_size*dp_size)]
@@ -276,7 +283,7 @@ def build(
mp_group+=tp_groups[dp_rank + idx*dp_size]
device_ids = [i % gpus_per_node for i in mp_group]

mapping = Mapping(
mapping = tensorrt_llm.Mapping(
world_size = tp_size*pp_size,
rank = tp_size*pp_rank + tp_rank,
gpus_per_node = gpus_per_node,
@@ -315,7 +322,7 @@ def build(
torch.distributed.barrier()
print_mem("post build_and_save_engine")

self.model_runner, self.session_params = load_dataparallel(
self.model_runner, self.session_params = load_refit(
engine_dir=self.model_dir,
device_ids=device_ids,
)
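
The build path above splits the MPI communicator by data-parallel rank and derives the model-parallel group and device ids from NeMo's global rank layout. A standalone sketch of that index arithmetic (pure Python, no TensorRT-LLM needed), using illustrative sizes tp=2, pp=2, dp=2 and assuming the loop hidden by the diff iterates `idx` over the pipeline stages:

```python
# Illustrative sizes; the real values come from megatron.core.parallel_state.
tp_size, pp_size, dp_size = 2, 2, 2
gpus_per_node = 8

def mp_layout(tp_rank, pp_rank, dp_rank):
    model_parallel_rank = tp_size * pp_rank + tp_rank
    # Global ranks grouped by tensor parallelism, as in the diff above.
    tp_groups = [[j * tp_size + i for i in range(tp_size)] for j in range(pp_size * dp_size)]
    mp_group = []
    for idx in range(pp_size):
        mp_group += tp_groups[dp_rank + idx * dp_size]
    device_ids = [i % gpus_per_node for i in mp_group]
    return model_parallel_rank, mp_group, device_ids

print(mp_layout(tp_rank=1, pp_rank=0, dp_rank=0))  # (1, [0, 1, 4, 5], [0, 1, 4, 5])
print(mp_layout(tp_rank=0, pp_rank=1, dp_rank=1))  # (2, [2, 3, 6, 7], [2, 3, 6, 7])
```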
21 changes: 20 additions & 1 deletion nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -22,6 +22,7 @@
import typing
from collections import defaultdict
from pathlib import Path
from functools import cache

import numpy as np
import tensorstore # This is important even though not used. Otherwise zarr raises error.
@@ -310,6 +311,20 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
return weights_dict, llm_config, tokenizer


@cache
def rename_layer_num(param_name, layer_num):
split_key = param_name.split(".")
layer_index = int(get_layer_index(split_key))
split_key[layer_index] = str(layer_num)

return ".".join(split_key)

@cache
def get_layer_num(param_name):
split_key = param_name.split(".")
layer_index = int(get_layer_index(split_key))
return int(split_key[layer_index])

@torch.no_grad()
def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, reshard_model=False, cpu=True):
from megatron.core import parallel_state
@@ -343,7 +358,9 @@ def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, res
if num_kv_heads == 0:
num_kv_heads = 1 if multi_query_mode else num_attention_heads
reshard_model = reshard_model and pp_size > 1
weights_dict = persistent_weight_dict if cpu else {}

from nemo.export.trt_llm.nemo.convert import weights_dict as persistent_weights_dict
weights_dict = persistent_weights_dict if cpu else {}

export_config = {
"apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p",
@@ -363,11 +380,13 @@ def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, res
model_level_params = {}
starmap_args = []

import time
tic = time.time()

layers_per_pp = num_layers // pp_size
layers_per_chunk = layers_per_pp // vp_size

# ----------------Gather layers from other shards ----------------
if vp_size > 1: # consolidate params across model chunks
for idx, model_chunk in enumerate(nemo_model):
for key, val in model_chunk.state_dict().items():
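
The new `rename_layer_num` / `get_layer_num` helpers memoize their string parsing with `functools.cache`, which is safe because both are pure functions of hashable arguments. A self-contained sketch of the idea, with a hypothetical `get_layer_index` that simply finds the first all-digit path segment (the real helper lives elsewhere in this module and may differ):

```python
from functools import cache

def get_layer_index(split_key):
    # Hypothetical stand-in: index of the first purely numeric segment,
    # e.g. the "3" in "model.decoder.layers.3.mlp.weight".
    return next(i for i, part in enumerate(split_key) if part.isdigit())

@cache
def rename_layer_num(param_name, layer_num):
    split_key = param_name.split(".")
    split_key[get_layer_index(split_key)] = str(layer_num)
    return ".".join(split_key)

@cache
def get_layer_num(param_name):
    split_key = param_name.split(".")
    return int(split_key[get_layer_index(split_key)])

print(rename_layer_num("model.decoder.layers.3.mlp.weight", 7))  # model.decoder.layers.7.mlp.weight
print(get_layer_num("model.decoder.layers.3.mlp.weight"))        # 3
```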
1 change: 1 addition & 0 deletions nemo/export/trt_llm/nemo_utils.py
@@ -28,6 +28,7 @@
import numpy as np
import tensorrt_llm
from tensorrt_llm import str_dtype_to_trt
from tensorrt_llm._utils import torch_dtype_to_np, np_dtype_to_trt, trt_dtype_to_str
from transformers import AutoTokenizer, LlamaConfig, PretrainedConfig, PreTrainedTokenizer

from nemo.export.trt_llm.model_config import (
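
The added import pulls in TensorRT-LLM's dtype conversion helpers. A hedged sketch of how such helpers are typically chained (torch dtype → numpy dtype → TRT dtype → string); the exact return values are an assumption here, not something this diff shows:

```python
import torch
from tensorrt_llm._utils import torch_dtype_to_np, np_dtype_to_trt, trt_dtype_to_str

# Assumed round trip: torch.float16 -> numpy float16 -> TRT half dtype -> "float16"
np_dtype = torch_dtype_to_np(torch.float16)
trt_dtype = np_dtype_to_trt(np_dtype)
print(trt_dtype_to_str(trt_dtype))
```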
23 changes: 22 additions & 1 deletion nemo/export/trt_llm/tensorrt_llm_run.py
Expand Up @@ -26,7 +26,13 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime import ModelConfig, SamplingConfig, ModelRunnerCpp

from tensorrt_llm.bindings import (DataType, GenerationInput, GenerationOutput,
GptJsonConfig, GptSession, GptSessionConfig,
KvCacheConfig, PromptTuningParams, WorldConfig)
from tensorrt_llm.bindings import SamplingConfig as GptSamplingConfig

from transformers import PreTrainedTokenizer

from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group
@@ -309,6 +315,21 @@ def load(
add_bos=add_bos,
)

@dataclass
class GptSession_params:
session_config: GptSessionConfig
model_config: ModelConfig
world_config: WorldConfig
engine_data: bytearray

def create_gpt_session(
session_params: GptSession_params, engine_data: bytearray = None):
if engine_data is None:
engine_data = session_params.engine_data
return GptSession(session_params.session_config,
session_params.model_config,
session_params.world_config,
engine_data)

def load_refit(engine_dir, device_ids):
"""Loaded the compiled LLM model and run it.
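
The new `GptSession_params` dataclass bundles everything needed to (re)build a `GptSession`, and `create_gpt_session` lets callers swap in fresh engine data while reusing the cached configs, which is what a refit path needs. A hedged usage sketch, assuming the `GptSession_params` and `create_gpt_session` definitions above are in scope and that the caller already has refitted engine bytes:

```python
def make_sessions(session_params: GptSession_params, refit_engine_bytes: bytearray):
    """Sketch: build the initial session, then a refitted one from the same configs."""
    # First load: use the engine bytes stored on the params object.
    session = create_gpt_session(session_params)
    # Refit: reuse the cached session/model/world configs, but substitute new engine data.
    refit_session = create_gpt_session(session_params, engine_data=refit_engine_bytes)
    return session, refit_session
```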
