From bed73ec2038b3ac91b6d7731b74335133a6b7a78 Mon Sep 17 00:00:00 2001
From: Jimmy Zhang
Date: Mon, 13 May 2024 11:34:10 -0700
Subject: [PATCH] fix multinode DP

Signed-off-by: Jimmy Zhang
---
 nemo/export/tensorrt_llm.py             | 20 ++++++++++----------
 nemo/export/trt_llm/tensorrt_llm_run.py |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 4e81b3b4b41d..31b6b9c25a2a 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -279,7 +279,6 @@ def build(
             mp_group = parallel_state.get_model_parallel_group()
             mp_size = tp_size*pp_size
             mp_rank = tp_size*pp_rank + tp_rank
-
             if dp_size > 1:
                 self.model_dir = os.path.join(self.model_dir, f"dp_rank{dp_rank}")
 
@@ -288,9 +287,9 @@ def build(
             # So we manipulate TRTLLM to emulate a TP->PP single node setup
             tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
             device_ids = [
-                ((i+mp_rank-torch.cuda.current_device())+mp_size)%mp_size
+                ((i+torch.cuda.current_device()-mp_rank)+mp_size)%mp_size
                 for i in range(mp_size)]
-
+
             mapping = tensorrt_llm.Mapping(
                 world_size = mp_size,
                 rank = mp_rank,
@@ -299,7 +298,7 @@ def build(
                 pp_size = pp_size)
 
             LOGGER.info(
-                f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()} -> {mp_rank}:
+                f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()}, mp_rank {mp_rank}:
                 tp_rank {parallel_state.get_tensor_model_parallel_rank()} -> {mapping.tp_rank},
                 pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
             )
@@ -328,13 +327,14 @@ def build(
                 trt_model_type=trt_model_type
             )
             torch.distributed.barrier()
-            print(f"engine saved to {self.model_dir}")
 
-            if torch.cuda.current_device() == 0:
-                cfg_path = Path(os.path.join(self.model_dir, 'config.json'))
-                if not cfg_path.exists():
-                    with open(cfg_path, "w", encoding="utf-8") as f:
-                        json.dump(engine.config.to_dict(), f, indent=4)
+            myrank = torch.distributed.get_rank()
+            cfg_path = Path(os.path.join(self.model_dir, f'config_{myrank}.json'))
+            print(f"engine saved to {self.model_dir}")
+            print(self.model_dir, f'config_{myrank}.json')
+            if not cfg_path.exists():
+                with open(cfg_path, "w", encoding="utf-8") as f:
+                    json.dump(engine.config.to_dict(), f, indent=4)
 
         print_mem("post build_and_save_engine")
 
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 9e8def13fc5d..b0f233e8871b 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -337,7 +337,7 @@ def load_refit(engine_dir, device_ids):
     It also supports running the TRT LLM model on multi-GPU.
     """
 
-    config_path = Path(engine_dir) / "config.json"
+    config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json"
     json_config = GptJsonConfig.parse_file(config_path)
     model_config = json_config.model_config
 
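
Note on the device_ids change: the offset inside the expression is flipped from (mp_rank - current_device) to (current_device - mp_rank), so the entry this process looks up for itself, device_ids[mp_rank], lands on its own local GPU (modulo mp_size). The snippet below is a standalone sketch for illustration only; the helper names and the example values are made up, not NeMo code. It simply evaluates the removed and the added expression side by side:

def removed_device_ids(mp_rank, current_device, mp_size):
    # expression deleted by this patch
    return [((i + mp_rank - current_device) + mp_size) % mp_size
            for i in range(mp_size)]

def added_device_ids(mp_rank, current_device, mp_size):
    # expression introduced by this patch
    return [((i + current_device - mp_rank) + mp_size) % mp_size
            for i in range(mp_size)]

# Illustrative values: a process that is rank 2 of a 4-way tp*pp group
# but owns local GPU 3.
mp_rank, current_device, mp_size = 2, 3, 4
print(removed_device_ids(mp_rank, current_device, mp_size))  # [3, 0, 1, 2] -> own entry is 1, not GPU 3
print(added_device_ids(mp_rank, current_device, mp_size))    # [1, 2, 3, 0] -> own entry is 3 == current_device

The two expressions happen to coincide when mp_rank equals the local device index, which is presumably why the bug only surfaced in multinode data-parallel layouts.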
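
Note on the per-rank config files: build() now writes the engine config as config_<global rank>.json inside the (possibly dp_rank<k>-suffixed) model_dir, and load_refit() reads that same per-rank file instead of a shared config.json. A minimal sketch of the naming convention, assuming base_dir, dp_rank, global_rank and dp_size stand in for TensorRTLLM.model_dir, the Megatron data-parallel rank, torch.distributed.get_rank() and the DP world size; per_rank_config_path is a hypothetical helper, not part of NeMo:

from pathlib import Path

def per_rank_config_path(base_dir: str, dp_rank: int, global_rank: int, dp_size: int) -> Path:
    engine_dir = Path(base_dir)
    if dp_size > 1:
        # build() appends a per-replica subdirectory so DP replicas don't overwrite each other
        engine_dir = engine_dir / f"dp_rank{dp_rank}"
    # build() writes, and load_refit() later reads, a config keyed by the global rank
    return engine_dir / f"config_{global_rank}.json"

print(per_rank_config_path("/tmp/trt_llm_engine", dp_rank=1, global_rank=6, dp_size=2))
# e.g. /tmp/trt_llm_engine/dp_rank1/config_6.json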