From e2ca226bebb4d69d9463e143c9859deac6e95b65 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 15 May 2024 18:39:27 -0700
Subject: [PATCH] refix llama

Signed-off-by: root
---
 nemo/export/tensorrt_llm.py                   | 12 ++-----
 nemo/export/trt_llm/nemo/convert.py           | 36 +++++++++----------
 nemo/export/trt_llm/nemo/nemo_ckpt_convert.py | 22 +++++-------
 nemo/export/trt_llm/tensorrt_llm_build.py     |  6 +---
 nemo/export/trt_llm/tensorrt_llm_run.py       | 15 ++++++--
 5 files changed, 43 insertions(+), 48 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 31b6b9c25a2a..1a32c5cbcc2c 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -285,11 +285,7 @@ def build(
         # TRTLLM asserts that rank equals the device num however this
         # is not true for the megatron core mapping TP->DP->PP.
         # So we manipulate TRTLLM to emulate a TP->PP single node setup
-        tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
-        device_ids = [
-            ((i+torch.cuda.current_device()-mp_rank)+mp_size)%mp_size
-            for i in range(mp_size)]
-
+        tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
         mapping = tensorrt_llm.Mapping(
             world_size = mp_size,
             rank = mp_rank,
@@ -303,7 +299,7 @@ def build(
             pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
         )
         mp_group_ranks = torch.distributed.distributed_c10d.get_process_group_ranks(mp_group)
-        print(f"{torch.distributed.get_rank()} color {dp_rank} mp_rank {mp_rank} mp_group_ranks {mp_group_ranks} device_ids {device_ids}")
+        print(f"{torch.distributed.get_rank()} color {dp_rank} mp_rank {mp_rank} mp_group_ranks {mp_group_ranks}")
         print(f"trtllm mpi : {tensorrt_llm.bindings.MpiComm.getRank()} {tensorrt_llm.bindings.MpiComm.getSize()}")
 
         model_config, weights = nemo_llm_model_to_model_config(
@@ -338,9 +334,7 @@ def build(
 
         print_mem("post build_and_save_engine")
 
-        self.model_runner, self.session_params = load_refit(
-            engine_dir=self.model_dir,
-            device_ids=device_ids)
+        self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir)
 
         print(f"device: {origdev} {torch.cuda.current_device()}")
 
diff --git a/nemo/export/trt_llm/nemo/convert.py b/nemo/export/trt_llm/nemo/convert.py
index 7034bdc30e39..30e6508fba43 100644
--- a/nemo/export/trt_llm/nemo/convert.py
+++ b/nemo/export/trt_llm/nemo/convert.py
@@ -394,7 +394,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
 
 #Similar to split_save_weight but done on GPU for performance
 @torch.no_grad()
-def save_weight_torch(key, val, config, weight_type):
+def save_weight_torch(key, val, config):
     num_layers = config["num_layers"]
     storage_type = config["storage_type"]
     split_gated_activation = config["split_gated_activation"]
@@ -433,32 +433,32 @@ def save(key, tensor, add_prefix=True):
     if "self_attention" in key:
         key = key.replace("self_attention", "attention")
 
-    if "attention.linear_qkv.layer_norm_weight" in key:
-        key = key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight")
-    if "mlp.linear_fc1.layer_norm_weight" in key:
-        key = key.replace("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight")
-    if weight_type == 'layernorm_weight':
+    if ('layer_norm_weight' in key
+        or 'layernorm.weight' in key
+        or "final_layernorm.weight" in key
+        or "ln_f.weight" in key
+    ):
+        if "attention.linear_qkv.layer_norm_weight" in key:
+            key = key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight")
+        elif "mlp.linear_fc1.layer_norm_weight" in key:
+            key = key.replace("mlp.linear_fc1.layer_norm_weight", "post_layernorm.weight")
+        elif "pre_mlp_layernorm.weight" in key:
+            key = key.replace("pre_mlp_layernorm.weight", "post_layernorm.weight")
+
         if config.get("apply_layernorm_1p", False):
             val = val.float() + 1.0
         save(key, val)
     elif (
-        "input_layernorm.bias" in key
-        or "pre_mlp_layernorm.bias" in key
-        or "attention.dense.bias" in key
-        or "attention.linear_proj.bias" in key
-        or "post_attention_layernorm.bias" in key
-        or "mlp.dense_4h_to_h.bias" in key
-        or "mlp.linear_fc2.bias" in key
+        "input_layernorm.bias" in key
+        or "pre_mlp_layernorm.bias" in key
         or "ln_f.bias" in key
         or "vocab_embedding" in key
     ):
-        if "mlp.linear_fc2.bias" in key:
-            key = key.replace("mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias")
-        elif "attention.linear_proj.bias" in key:
-            key = key.replace("attention.linear_proj.bias", "attention.dense.bias")
-        save(key, val)
+        if "pre_mlp_layernorm.bias" in key:
+            key = key.replace("pre_mlp_layernorm.bias", "post_layernorm.bias")
+        save(key, val)
     elif (
         "attention.dense.weight" in key
         or "mlp.dense_4h_to_h.weight" in key
diff --git a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
index a43642c82898..9bbe95193a35 100644
--- a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
+++ b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -439,12 +439,12 @@ def convert_nemo_model(
         layer_params = {
             k: v for k, v in layer_params.items() if k.startswith("layers.")
         }
+
        for key, val in layer_params.items():
            starmap_args.append({
                "key": key,
                "val": val,
                "config": export_config,
-               "weight_type" : 'layernorm_weight' if 'layernorm.weight' in key else None
            })
 
    def broadcast_item(item, group, src_rank):
@@ -454,7 +454,7 @@ def broadcast_item(item, group, src_rank):
 
    #broadcast a tensor across PP group and save it
    def save_pp_weight(
-           src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=True, weight_type=None):
+           src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=True):
 
        have_tensor = False
        if torch.distributed.get_rank() == pp_src_idx:
@@ -484,12 +484,10 @@ def save_pp_weight(
        temp_config = dict(export_config)
        temp_config['transpose_weights'] = transpose_weights
-       temp_config['weight_type'] = weight_type
        starmap_args.append({
            "key": dst_key,
            "val": tensor,
            "config": temp_config,
-           "weight_type": weight_type
        })
 
    # ----------------Convert Final Layernorm----------------
    if pp_is_last or reshard_model:
@@ -497,7 +495,6 @@
        save_pp_weight(
            get_layer_name("final_layernorm.weight", transformer_layer_prefix),
            "ln_f.weight",
            pp_last_rank,
-           weight_type='layernorm_weight'
        )
        save_pp_weight(
            get_layer_name("final_layernorm.bias", transformer_layer_prefix),
            "ln_f.bias",
            pp_last_rank,
@@ -548,16 +545,13 @@ def remove_vocab_padding(tensor):
 
    toc = time.time()
    print(f" weight save took {toc-tic}")
 
-   renamed_weight_dict = {}
-   if trt_model_type == 'GPTForCausalLM':
-       for key, val in weights_dict.items():
-           if 'layernorm' in key:
-               new_key = key.replace("pre_mlp_layernorm", "post_layernorm")
-           else:
-               new_key = key
-           renamed_weight_dict[new_key] = val
+   print(f"{torch.cuda.current_device()} {pp_is_first} {pp_is_last}")
+   if torch.cuda.current_device() == 0:
+       import pdb
+       pdb.set_trace()
+   torch.distributed.barrier()
 
-   return renamed_weight_dict
+   return weights_dict
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
index fe1f6dde6b24..d9d0db6b6f44 100644
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -367,10 +367,6 @@ def build_and_save_engine(
     use_refit=True,
     trt_model_type='LLaMAForCausalLM'
 ):
-    '''Minimum implementation of TRTLLM 0.9's unified builder api'''
-    logger.set_level('info')
-
-
     try:
         model_cls = getattr(tensorrt_llm.models, trt_model_type)
     except:
@@ -400,7 +396,7 @@ def build_and_save_engine(
     model = model_cls.from_config(model_config)
     model = optimize_model(
         model,
-        use_parallel_embedding=model_config.use_parallel_embedding,
+        use_parallel_embedding=True,
         share_embedding_table=model_config.share_embedding_table,
     )
     preprocess_weights(model_weights, model_config)
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index b0f233e8871b..4ed5bb472598 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -331,7 +331,7 @@ def create_gpt_session(
         session_params.world_config,
         engine_data)
 
-def load_refit(engine_dir, device_ids):
+def load_refit(engine_dir):
     """Loaded the compiled LLM model and run it.
 
     It also supports running the TRT LLM model on multi-GPU.
@@ -344,7 +344,16 @@
     tp_size = json_config.tensor_parallelism
     pp_size = json_config.pipeline_parallelism
     mp_size = tp_size*pp_size
-    world_config = WorldConfig.mpi(gpus_per_node=999, #Unused so just choose a big number to avoid asserts
+
+    # TRTLLM assumes rank < gpus_per_node but this is not true for multinode setups
+    # So hack around this using an arbitrarily big gpus_per_node to avoid asserts
+    gpus_per_node = 9999
+    mp_rank = tensorrt_llm.bindings.MpiComm.getRank()
+    device_ids = [
+        (i+torch.cuda.current_device()-mp_rank) for i in range(mp_size)]
+    print(f"{torch.cuda.current_device()} device_ids {device_ids}")
+
+    world_config = WorldConfig.mpi(gpus_per_node=gpus_per_node,
                                    tensor_parallelism=tp_size,
                                    pipeline_parallelism=pp_size,
                                    device_ids=device_ids)
@@ -363,6 +372,8 @@
     session_config = GptSessionConfig(max_batch_size=max_batch_size,
                                       max_beam_width=max_beam_width,
                                       max_sequence_length=max_seq_len)
+    session_config.gen_micro_batch_size = max_batch_size
+    session_config.ctx_micro_batch_size = max_batch_size
     session_config.kv_cache_config = KvCacheConfig(
         max_tokens=max_seq_len*max_batch_size,
         max_attention_window=max_seq_len
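
The convert.py hunk above replaces the weight_type dispatch with purely name-based matching: a layer-norm weight is detected by substring, renamed from its megatron-core name to the TRT-LLM name, and shifted by +1.0 when the checkpoint was trained with layernorm1p. Because the pre_mlp_layernorm -> post_layernorm rename now happens here, the GPT-specific renamed_weight_dict pass removed from nemo_ckpt_convert.py is no longer needed. A minimal standalone sketch of the renaming, using only key names that appear in the diff (rename_layernorm and LAYERNORM_RENAMES are illustrative helpers, not part of the patch):

    # Sketch of the layer-norm renaming performed inside save_weight_torch().
    # The mapping mirrors the replacements added by the patch; the helper
    # itself is hypothetical.
    import torch

    LAYERNORM_RENAMES = {
        "attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
        "mlp.linear_fc1.layer_norm_weight": "post_layernorm.weight",
        "pre_mlp_layernorm.weight": "post_layernorm.weight",
    }

    def rename_layernorm(key: str, val: torch.Tensor, apply_layernorm_1p: bool = False):
        """Map a megatron-core layer-norm key/value to its TRT-LLM counterpart."""
        key = key.replace("self_attention", "attention")
        for old, new in LAYERNORM_RENAMES.items():
            if old in key:
                key = key.replace(old, new)
                break
        if apply_layernorm_1p:
            # layernorm1p checkpoints store (gamma - 1); add the 1 back.
            val = val.float() + 1.0
        return key, val

    # Example: "layers.0.mlp.linear_fc1.layer_norm_weight" maps to
    # "layers.0.post_layernorm.weight".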
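
build_and_save_engine() resolves the TRT-LLM model class from its string name and, with this patch, forces use_parallel_embedding=True instead of reading the flag from the model config, so the embedding table is always split across tensor-parallel ranks. Below is a sketch of that preparation step under the 0.9-era unified builder API; the import path for optimize_model and preprocess_weights is an assumption (the diff only shows the calls), and prepare_model is an illustrative wrapper:

    # Sketch of the model-preparation flow in build_and_save_engine().
    # Assumption: optimize_model and preprocess_weights come from
    # tensorrt_llm.models.modeling_utils, as in the TRT-LLM 0.9 tree.
    import tensorrt_llm
    from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights

    def prepare_model(model_config, model_weights, trt_model_type="LLaMAForCausalLM"):
        # Resolve e.g. tensorrt_llm.models.LLaMAForCausalLM from its name.
        try:
            model_cls = getattr(tensorrt_llm.models, trt_model_type)
        except AttributeError:
            raise AttributeError(f"Unknown TRT-LLM model type: {trt_model_type}")

        model = model_cls.from_config(model_config)
        model = optimize_model(
            model,
            use_parallel_embedding=True,
            share_embedding_table=model_config.share_embedding_table,
        )
        # Rearranges the weight dict into the layout the engine build expects.
        preprocess_weights(model_weights, model_config)
        return model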
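
On the runtime side, load_refit() no longer takes device_ids from the caller: each rank derives the device list of its model-parallel group from its own CUDA device and MPI rank, so device_ids[mp_rank] is always the local device, and gpus_per_node is set to an arbitrarily large value solely to defeat TRT-LLM's rank < gpus_per_node assertion on multinode runs. A sketch of that mapping, assuming WorldConfig is importable from tensorrt_llm.bindings (the diff itself only references the MpiComm binding explicitly); make_world_config is an illustrative helper:

    # Sketch of the multinode device mapping built inside load_refit().
    # The WorldConfig.mpi() call and the device_ids arithmetic mirror the
    # lines added by the patch.
    import torch
    import tensorrt_llm
    from tensorrt_llm.bindings import WorldConfig

    def make_world_config(tp_size: int, pp_size: int) -> WorldConfig:
        mp_size = tp_size * pp_size

        # rank < gpus_per_node does not hold when the model-parallel group
        # spans nodes; an oversized value sidesteps the assert and is
        # otherwise unused here.
        gpus_per_node = 9999

        # Offset the group's device indices so that device_ids[mp_rank]
        # is this rank's own CUDA device.
        mp_rank = tensorrt_llm.bindings.MpiComm.getRank()
        device_ids = [i + torch.cuda.current_device() - mp_rank for i in range(mp_size)]

        return WorldConfig.mpi(
            gpus_per_node=gpus_per_node,
            tensor_parallelism=tp_size,
            pipeline_parallelism=pp_size,
            device_ids=device_ids,
        )

The variant removed from TensorRTLLM.build() wrapped the same expression modulo mp_size; the version kept here assumes the group's devices are numbered contiguously so the offsets stay in range.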
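
The last hunk pins both the context and generation micro batch sizes to the session batch size and bounds the KV cache by total tokens and by the attention window. A sketch of that configuration, again assuming GptSessionConfig and KvCacheConfig are the tensorrt_llm.bindings classes used elsewhere in this file; make_session_config is an illustrative helper:

    # Sketch of the GptSession configuration assembled at the end of load_refit().
    from tensorrt_llm.bindings import GptSessionConfig, KvCacheConfig

    def make_session_config(max_batch_size: int, max_beam_width: int, max_seq_len: int):
        session_config = GptSessionConfig(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_sequence_length=max_seq_len,
        )
        # Run the context and generation phases at the full batch size rather
        # than letting the session pick smaller micro batches.
        session_config.gen_micro_batch_size = max_batch_size
        session_config.ctx_micro_batch_size = max_batch_size
        # Budget the KV cache for one full-length sequence per batch slot.
        session_config.kv_cache_config = KvCacheConfig(
            max_tokens=max_seq_len * max_batch_size,
            max_attention_window=max_seq_len,
        )
        return session_config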