
Commit

multinode support
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 9, 2024
1 parent fb108a1 commit 0ea57dd
Showing 2 changed files with 38 additions and 45 deletions.
77 changes: 34 additions & 43 deletions nemo/export/tensorrt_llm.py
@@ -248,10 +248,11 @@ def build(
max_input_tokens: int = 4096,
max_output_len: int = 1024,
max_batch_size: int = 4,
gpus_per_node: int = 8,
use_refit: bool = True,
reshard_model: bool = False,
):
origdev = torch.cuda.current_device()

from megatron.core import parallel_state
assert tensorrt_llm.mpi_rank() == torch.distributed.get_rank()

@@ -271,51 +272,51 @@ def build(
dp_rank = torch.distributed.get_rank() // tp_size
pp_rank = 0
pp_size = 1
mp_group = parallel_state.get_tensor_model_parallel_group()
else:
self.reshard_model = False
mp_group = parallel_state.get_model_parallel_group()
mp_size = tp_size*pp_size
mp_rank = tp_size*pp_rank + tp_rank

model_parallel_size = tp_size*pp_size
model_parallel_rank = tp_size*pp_rank + tp_rank
tensorrt_llm.bindings.MpiComm.split(dp_rank, model_parallel_rank)
if dp_size > 1:
self.model_dir = os.path.join(self.model_dir, f"dp_rank{dp_rank}")

# Get the model parallel group using global ids from NeMo
tp_groups = [[j*tp_size+i for i in range(tp_size)] for j in range(pp_size*dp_size)]
mp_group = []
for idx in range(pp_size):
mp_group+=tp_groups[dp_rank + idx*dp_size]
# device_ids = [i % gpus_per_node for i in mp_group]
device_ids = mp_group
# TRTLLM asserts that rank equals the device num however this
# is not true for the megatron core mapping TP->DP->PP.
# So we manipulate TRTLLM to emulate a TP->PP single node setup
tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
device_ids = [
(i+torch.cuda.current_device()-mp_rank) % mp_size
for i in range(mp_size)]
assert device_ids[mp_rank] == torch.cuda.current_device()

mapping = tensorrt_llm.Mapping(
world_size = tp_size*pp_size,
rank = tp_size*pp_rank + tp_rank,
gpus_per_node = gpus_per_node,
world_size = mp_size,
rank = mp_rank,
gpus_per_node = mp_size,
tp_size = tp_size,
pp_size = pp_size
)

if dp_size > 1:
self.model_dir = os.path.join(self.model_dir, f"dp_rank{dp_rank}")
pp_size = pp_size)

LOGGER.info(
f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()} -> {model_parallel_rank}:
f'''TRT-LLM rank mapping: Rank {torch.distributed.get_rank()} -> {mp_rank}:
tp_rank {parallel_state.get_tensor_model_parallel_rank()} -> {mapping.tp_rank},
pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
)
print(f"{torch.distributed.get_rank()} color {dp_rank} rank {model_parallel_rank} nemo_mp_group {mp_group} {device_ids} ")
# assert torch.cuda.current_device() == device_ids[model_parallel_rank]
mp_group_ranks = torch.distributed.distributed_c10d.get_process_group_ranks(mp_group)
print(f"{torch.distributed.get_rank()} color {dp_rank} mp_rank {mp_rank} mp_group_ranks {mp_group_ranks} device_ids {device_ids}")
print(f"trtllm mpi : {tensorrt_llm.bindings.MpiComm.getRank()} {tensorrt_llm.bindings.MpiComm.getSize()}")

model_config, weights = nemo_llm_model_to_model_config(
nemo_model=nemo_model,
tokenizer=self.tokenizer,
nemo_model_config=nemo_model_config,
reshard_model=self.reshard_model,
mapping=mapping,
trt_model_type=trt_model_type,
)
trt_model_type=trt_model_type)

print_mem("pre build_and_save_engine")
build_and_save_engine(
engine = build_and_save_engine(
max_input_len=max_input_len,
max_output_len=max_output_len,
max_input_tokens=max_input_tokens,
@@ -327,29 +328,19 @@ def build(
trt_model_type=trt_model_type
)
torch.distributed.barrier()
print(f"engine saved to {self.model_dir}")
if torch.cuda.current_device() == 0:
with open(os.path.join(self.model_dir, 'config.json'),
"w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)

print_mem("post build_and_save_engine")

self.model_runner, self.session_params = load_refit(
engine_dir=self.model_dir,
device_ids=device_ids,
)
device_ids=device_ids)

# sampling_config = tensorrt_llm.runtime.SamplingConfig(
# end_id=self.tokenizer.eos_id,
# pad_id=self.tokenizer.eos_id, #TODO
# )

# data = [529,17833,29918,333,29918,29900,29958,3924,13,13,29966,17833,29918,333,29918,29896,29958,2659,13,22110,8906,341,18219,25992,29973,13,29966,17833,29918,333,29918,29896,29958,7900,22137,13]
# inputdata = [torch.IntTensor(data)]
# resp = self.model_runner.generate(inputdata, sampling_config=sampling_config)
# resp = resp[0][0].tolist()
# resp = tokenizer.ids_to_text(resp)
# print(f"@@@@@@@@ {torch.cuda.current_device()} {orig_dev} {resp}")

# if torch.cuda.current_device() == 0:
# import pdb
# pdb.set_trace()
# torch.distributed.barrier()
print(f"device: {origdev} {torch.cuda.current_device()}")


def refit(
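The heart of the change in this file is the bookkeeping that lets TRT-LLM keep its single-node assumptions per data-parallel replica: the global MPI communicator is split by dp_rank, and device_ids is rotated so that the entry at mp_rank is this process's CUDA device, which is what the assert in the hunk checks. A minimal sketch of that rotation; the helper name and the standalone framing are illustrative, not part of the commit:

    # Minimal sketch of the device-id rotation used in the build() hunk above;
    # the helper name is illustrative only.
    import torch

    def remapped_device_ids(mp_rank: int, mp_size: int) -> list[int]:
        cur = torch.cuda.current_device()
        # Rotate device indices so that position mp_rank holds this process's
        # CUDA device, matching the assert in the diff.
        ids = [(i + cur - mp_rank) % mp_size for i in range(mp_size)]
        assert ids[mp_rank] == cur  # holds as long as cur < mp_size
        return ids

    # Worked case (two 8-GPU nodes, TP=8, PP=2, so mp_size=16): global rank 10
    # has tp_rank=2, pp_rank=1, hence mp_rank=10, and sits on local device 2.
    # The rotation yields [8, 9, ..., 15, 0, 1, ..., 7], whose entry 10 is 2.

Setting gpus_per_node = mp_size in the Mapping keeps TRT-LLM's rank-to-device arithmetic consistent with this rotated list; that is the "emulate a TP->PP single node setup" described in the in-code comment.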
6 changes: 4 additions & 2 deletions nemo/export/trt_llm/tensorrt_llm_run.py
@@ -343,11 +343,13 @@ def load_refit(engine_dir, device_ids):

tp_size = json_config.tensor_parallelism
pp_size = json_config.pipeline_parallelism
world_config = WorldConfig.mpi(tensor_parallelism=tp_size,
mp_size = tp_size*pp_size
world_config = WorldConfig.mpi(gpus_per_node=mp_size,
tensor_parallelism=tp_size,
pipeline_parallelism=pp_size,
device_ids=device_ids)

assert tensorrt_llm.bindings.MpiComm.getRank() == world_config.rank
assert torch.cuda.current_device() == world_config.device
engine_filename = json_config.engine_filename(world_config)
serialize_path = Path(engine_dir) / engine_filename

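On the load side the same convention is applied: load_refit presents one model-parallel group to TRT-LLM as if it were a single node, so the MPI rank-to-device mapping agrees with the device_ids computed at build time. A condensed sketch of that hunk; the function wrapper and the import location are assumptions (the keyword arguments mirror the diff):

    # Condensed sketch of the load_refit change; WorldConfig is assumed to be
    # the tensorrt_llm.bindings class already used elsewhere in this file.
    from pathlib import Path
    from tensorrt_llm.bindings import WorldConfig

    def resolve_engine(json_config, engine_dir, device_ids):
        tp_size = json_config.tensor_parallelism
        pp_size = json_config.pipeline_parallelism
        mp_size = tp_size * pp_size
        # Present the whole model-parallel group as one "node" to TRT-LLM.
        world_config = WorldConfig.mpi(gpus_per_node=mp_size,
                                       tensor_parallelism=tp_size,
                                       pipeline_parallelism=pp_size,
                                       device_ids=device_ids)
        engine_filename = json_config.engine_filename(world_config)
        return Path(engine_dir) / engine_filename, world_config

The two asserts added after the WorldConfig call then verify that the split MPI communicator's rank and the current CUDA device line up with what the runtime expects.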
