Commit 6511f49: fix multinode DP

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 16, 2024
1 parent ecb4c77 commit 6511f49
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions nemo/export/tensorrt_llm.py
@@ -307,7 +307,6 @@ def build(
         mp_group = parallel_state.get_model_parallel_group()
         mp_size = tp_size*pp_size
         mp_rank = tp_size*pp_rank + tp_rank
-
         if dp_size > 1:
             self.model_dir = os.path.join(self.model_dir, f"dp_rank{dp_rank}")

@@ -316,9 +315,9 @@ def build(
             # So we manipulate TRTLLM to emulate a TP->PP single node setup
             tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
             device_ids = [
-                ((i+mp_rank-torch.cuda.current_device())+mp_size)%mp_size
+                ((i+torch.cuda.current_device()-mp_rank)+mp_size)%mp_size
                 for i in range(mp_size)]

             mapping = tensorrt_llm.Mapping(
                 world_size = mp_size,
                 rank = mp_rank,
@@ -327,7 +326,7 @@ def build(
                 pp_size = pp_size)

             LOGGER.info(
-                f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()} -> {mp_rank}:
+                f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()}, mp_rank {mp_rank}:
                     tp_rank {parallel_state.get_tensor_model_parallel_rank()} -> {mapping.tp_rank},
                     pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
             )
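
The substantive fix above is in the device_ids arithmetic. MpiComm.split(dp_rank, mp_rank) carves the world communicator into one sub-communicator per data-parallel replica, ordered by model-parallel rank, and device_ids then tells TRT-LLM which local CUDA device each of those mp_size ranks occupies. The invariant is that this process's own slot must map back to its own GPU: device_ids[mp_rank] == torch.cuda.current_device(). A minimal self-contained check of the corrected formula, with illustrative values not taken from the commit:

    MP_SIZE = 4  # hypothetical model-parallel width

    def device_ids(mp_rank: int, current_device: int) -> list[int]:
        # Corrected formula from the commit: shift the identity mapping by
        # (current_device - mp_rank) so that slot mp_rank lands on this
        # process's own GPU.
        return [((i + current_device - mp_rank) + MP_SIZE) % MP_SIZE
                for i in range(MP_SIZE)]

    # Suppose the process with mp_rank 1 runs on local CUDA device 2, as
    # can happen once data-parallel replicas share or span nodes.
    ids = device_ids(mp_rank=1, current_device=2)
    assert ids == [1, 2, 3, 0] and ids[1] == 2  # invariant holds

    # The old formula, (i + mp_rank - current_device) % MP_SIZE, shifts the
    # other way: here it would put device 0 in slot 1. It coincides with
    # the fix in the single-node case where mp_rank == current_device,
    # which is why the bug only surfaced with multinode data parallelism.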
@@ -354,13 +353,14 @@ def build(
             model_dir=self.model_dir,
         )
         torch.distributed.barrier()
-        print(f"engine saved to {self.model_dir}")

-        if torch.cuda.current_device() == 0:
-            cfg_path = Path(os.path.join(self.model_dir, 'config.json'))
-            if not cfg_path.exists():
-                with open(cfg_path, "w", encoding="utf-8") as f:
-                    json.dump(engine.config.to_dict(), f, indent=4)
+        myrank = torch.distributed.get_rank()
+        cfg_path = Path(os.path.join(self.model_dir, f'config_{myrank}.json'))
+        print(f"engine saved to {self.model_dir}")
+        print(self.model_dir, f'config_{myrank}.json')
+        if not cfg_path.exists():
+            with open(cfg_path, "w", encoding="utf-8") as f:
+                json.dump(engine.config.to_dict(), f, indent=4)

         print_mem("post build_and_save_engine")
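
The other half of the change replaces the single rank-0 config.json (which was gated on torch.cuda.current_device() == 0, a local device index rather than a global rank) with one config file per global rank, so writer and reader must agree on the filename. A self-contained sketch of that round trip, with hypothetical paths and a stand-in payload in place of engine.config.to_dict():

    import json
    import os
    from pathlib import Path

    model_dir = "/tmp/trt_engines/dp_rank0"  # hypothetical per-replica engine dir
    myrank = 5                               # stand-in for torch.distributed.get_rank()

    # Writer side (build): every rank writes its own config, skipped if present.
    os.makedirs(model_dir, exist_ok=True)
    cfg_path = Path(os.path.join(model_dir, f"config_{myrank}.json"))
    if not cfg_path.exists():
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump({"rank": myrank}, f, indent=4)

    # Reader side (load_refit): the same rank resolves the same file.
    config_path = Path(model_dir) / f"config_{myrank}.json"
    assert json.loads(config_path.read_text())["rank"] == myrank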

2 changes: 1 addition & 1 deletion nemo/export/trt_llm/tensorrt_llm_run.py
@@ -313,7 +313,7 @@ def load_refit(engine_dir, device_ids):
     It also supports running the TRT LLM model on multi-GPU.
     """

-    config_path = Path(engine_dir) / "config.json"
+    config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json"
     json_config = GptJsonConfig.parse_file(config_path)
     model_config = json_config.model_config