From 1d19459bf762e3c675cef0c74b6b4254882163b8 Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Wed, 2 Oct 2024 00:53:04 +0000
Subject: [PATCH] tiny cleanup

Signed-off-by: Terry Kong
---
 nemo/export/trt_llm/tensorrt_llm_run.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index 3b1dcff72938..1772c071a745 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -478,13 +478,15 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
         device_ids.append(device_ids.pop(0))
     engine_index = model_parallel_rank
     mpi_rank = mpi_comm().Get_rank()
-    # TODO: copied from worldConfig.h (getDevice())
+    # Copied from worldConfig.h (getDevice())
     mpi_device = mpi_rank % gpus_per_node
+    # TODO: Consider re-enabling
+    # assert torch.cuda.current_device() == mpi_device
+
     # TODO: check if API exists (copied from gptJsonConfig.cpp)
     # https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478
     engine_filename = f"rank{engine_index}.engine"
     serialize_path = Path(engine_dir) / engine_filename
-    # $#$#$assert torch.cuda.current_device() == mpi_device

     with open(serialize_path, "rb") as f:
         engine_data = bytearray(f.read())
@@ -494,9 +496,8 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
     engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)
     decoder = ModelRunner.from_engine(
         engine=engine,
-        # rank=world_config.rank,
         # We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process
-        # So we will set it to the current
+        # So we will set it to the current device
         rank=torch.cuda.current_device(),
         _disable_torch_cuda_device_set=True,
     )
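
Note (not part of the patch): a minimal sketch of the node-local device mapping the restored comment refers to, where the device index is the global MPI rank modulo the number of GPUs per node. The mpi4py usage and the gpus_per_node=8 value are assumptions for illustration; the patched code obtains the rank via tensorrt_llm's mpi_comm() helper instead.

    # Illustration only -- not part of the patch. Assumes mpi4py is installed
    # and that ranks are launched one per GPU in node-local order.
    from mpi4py import MPI
    import torch


    def node_local_device(gpus_per_node: int) -> int:
        """Mirrors worldConfig.h getDevice(): global MPI rank modulo GPUs per node."""
        mpi_rank = MPI.COMM_WORLD.Get_rank()
        return mpi_rank % gpus_per_node


    if __name__ == "__main__":
        device = node_local_device(gpus_per_node=8)  # 8 is an assumed node size
        # The assert the patch leaves disabled would check this invariant; it
        # only holds if the launcher already pinned torch's current device for
        # each rank, which is why the patch keeps it commented out.
        if torch.cuda.is_available() and torch.cuda.current_device() != device:
            print(f"current_device {torch.cuda.current_device()} != mpi device {device}")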