From 697d92073923a6e2f06c118dae59b0d60708ca9c Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Thu, 16 May 2024 09:00:52 -0700
Subject: [PATCH] fixes

Signed-off-by: Jimmy Zhang
---
 nemo/export/tensorrt_llm.py               | 1 +
 nemo/export/trt_llm/tensorrt_llm_build.py | 3 ++-
 nemo/export/trt_llm/tensorrt_llm_run.py   | 5 +++--
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 1a32c5cbcc2c..c7fe4d59daf1 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -335,6 +335,7 @@ def build(
             print_mem("post build_and_save_engine")
 
         self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir)
+        print_mem("post load_refit")
 
         print(f"device: {origdev} {torch.cuda.current_device()}")
 
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
index d9d0db6b6f44..8f6bacb87b50 100644
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -394,9 +394,10 @@ def build_and_save_engine(
     build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
 
     model = model_cls.from_config(model_config)
+    # use_parallel_embedding=True,
+
     model = optimize_model(
         model,
-        use_parallel_embedding=True,
         share_embedding_table=model_config.share_embedding_table,
     )
     preprocess_weights(model_weights, model_config)
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 4ed5bb472598..a976266ac8f2 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -347,10 +347,11 @@ def load_refit(engine_dir):
 
     # TRTLLM assumes rank < gpus_per_node but this is not true for multinode setups
     # So hack around this using an arbitrarily big gpus_per_node to avoid asserts
-    gpus_per_node = 9999
+    gpus_per_node = 64
    mp_rank = tensorrt_llm.bindings.MpiComm.getRank()
     device_ids = [
-        (i+torch.cuda.current_device()-mp_rank) for i in range(mp_size)]
+        (i+torch.cuda.current_device()-mp_rank+gpus_per_node)%gpus_per_node
+        for i in range(mp_size)]
     print(f"{torch.cuda.current_device()} device_ids {device_ids}")
 
     world_config = WorldConfig.mpi(gpus_per_node=gpus_per_node,
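
Note on the device_ids change in load_refit(): the sketch below is not part of the patch; it only walks through the new modular mapping with made-up rank/device numbers to show why the wrap-around keeps every index inside [0, gpus_per_node). The helper name compute_device_ids and all values used here are illustrative assumptions, not code from the repository.

    # Minimal, standalone sketch of the device mapping changed above.
    # All names and numbers are hypothetical; only the arithmetic mirrors
    # the patched list comprehension in load_refit().
    def compute_device_ids(current_device, mp_rank, mp_size, gpus_per_node=64):
        # Shift each model-parallel index so this rank's slot lands on its
        # local CUDA device, then wrap modulo gpus_per_node so ranks on
        # later nodes still produce indices in [0, gpus_per_node).
        return [
            (i + current_device - mp_rank + gpus_per_node) % gpus_per_node
            for i in range(mp_size)
        ]

    # Example: rank 9 sits on node 1 of an 8-GPU-per-node cluster, i.e. local
    # CUDA device 1. The old expression, i + current_device - mp_rank, would
    # give -8..-5 here; the new one gives valid local device indices.
    print(compute_device_ids(current_device=1, mp_rank=9, mp_size=4, gpus_per_node=8))
    # prints [0, 1, 2, 3]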