Skip to content

Commit

Permalink
llama refix
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed May 16, 2024
1 parent bed73ec commit 541532d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 50 deletions.
12 changes: 3 additions & 9 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,7 @@ def build(
# TRTLLM asserts that rank equals the device num however this
# is not true for the megatron core mapping TP->DP->PP.
# So we manipulate TRTLLM to emulate a TP->PP single node setup
tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
device_ids = [
((i+torch.cuda.current_device()-mp_rank)+mp_size)%mp_size
for i in range(mp_size)]

tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
mapping = tensorrt_llm.Mapping(
world_size = mp_size,
rank = mp_rank,
Expand All @@ -303,7 +299,7 @@ def build(
pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
)
mp_group_ranks = torch.distributed.distributed_c10d.get_process_group_ranks(mp_group)
print(f"{torch.distributed.get_rank()} color {dp_rank} mp_rank {mp_rank} mp_group_ranks {mp_group_ranks} device_ids {device_ids}")
print(f"{torch.distributed.get_rank()} color {dp_rank} mp_rank {mp_rank} mp_group_ranks {mp_group_ranks}")
print(f"trtllm mpi : {tensorrt_llm.bindings.MpiComm.getRank()} {tensorrt_llm.bindings.MpiComm.getSize()}")

model_config, weights = nemo_llm_model_to_model_config(
Expand Down Expand Up @@ -338,9 +334,7 @@ def build(

print_mem("post build_and_save_engine")

self.model_runner, self.session_params = load_refit(
engine_dir=self.model_dir,
device_ids=device_ids)
self.model_runner, self.session_params = load_refit(engine_dir=self.model_dir)

print(f"device: {origdev} {torch.cuda.current_device()}")

Expand Down
36 changes: 18 additions & 18 deletions nemo/export/trt_llm/nemo/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t

#Similar to split_save_weight but done on GPU for performance
@torch.no_grad()
def save_weight_torch(key, val, config, weight_type):
def save_weight_torch(key, val, config):
num_layers = config["num_layers"]
storage_type = config["storage_type"]
split_gated_activation = config["split_gated_activation"]
Expand Down Expand Up @@ -433,32 +433,32 @@ def save(key, tensor, add_prefix=True):

if "self_attention" in key:
key = key.replace("self_attention", "attention")
if "attention.linear_qkv.layer_norm_weight" in key:
key = key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight")
if "mlp.linear_fc1.layer_norm_weight" in key:
key = key.replace("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight")

if weight_type == 'layernorm_weight':
if ('layer_norm_weight' in key
or 'layernorm.weight' in key
or "final_layernorm.weight" in key
or "ln_f.weight" in key
):
if "attention.linear_qkv.layer_norm_weight" in key:
key = key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight")
elif "mlp.linear_fc1.layer_norm_weight" in key:
key = key.replace("mlp.linear_fc1.layer_norm_weight", "post_layernorm.weight")
elif "pre_mlp_layernorm.weight" in key:
key = key.replace("pre_mlp_layernorm.weight", "post_layernorm.weight")

if config.get("apply_layernorm_1p", False):
val = val.float() + 1.0
save(key, val)
elif (
"input_layernorm.bias" in key
or "pre_mlp_layernorm.bias" in key
or "attention.dense.bias" in key
or "attention.linear_proj.bias" in key
or "post_attention_layernorm.bias" in key
or "mlp.dense_4h_to_h.bias" in key
or "mlp.linear_fc2.bias" in key
"input_layernorm.bias" in key
or "pre_mlp_layernorm.bias" in key
or "ln_f.bias" in key
or "vocab_embedding" in key
):
if "mlp.linear_fc2.bias" in key:
key = key.replace("mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias")
elif "attention.linear_proj.bias" in key:
key = key.replace("attention.linear_proj.bias", "attention.dense.bias")
save(key, val)
if "pre_mlp_layernorm.bias" in key:
key = key.replace("pre_mlp_layernorm.bias", "post_layernorm.bias")

save(key, val)
elif (
"attention.dense.weight" in key
or "mlp.dense_4h_to_h.weight" in key
Expand Down
20 changes: 4 additions & 16 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,12 @@ def convert_nemo_model(
layer_params = {
k: v for k, v in layer_params.items() if k.startswith("layers.")
}

for key, val in layer_params.items():
starmap_args.append({
"key": key,
"val": val,
"config": export_config,
"weight_type" : 'layernorm_weight' if 'layernorm.weight' in key else None
})

def broadcast_item(item, group, src_rank):
Expand All @@ -454,7 +454,7 @@ def broadcast_item(item, group, src_rank):

#broadcast a tensor across PP group and save it
def save_pp_weight(
src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=True, weight_type=None):
src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=True):

have_tensor = False
if torch.distributed.get_rank() == pp_src_idx:
Expand Down Expand Up @@ -484,20 +484,17 @@ def save_pp_weight(

temp_config = dict(export_config)
temp_config['transpose_weights'] = transpose_weights
temp_config['weight_type'] = weight_type
starmap_args.append({
"key": dst_key,
"val": tensor,
"config": temp_config,
"weight_type": weight_type
})
# ----------------Convert Final Layernorm----------------
if pp_is_last or reshard_model:
save_pp_weight(
get_layer_name("final_layernorm.weight", transformer_layer_prefix),
"ln_f.weight",
pp_last_rank,
weight_type='layernorm_weight'
)
save_pp_weight(
get_layer_name("final_layernorm.bias", transformer_layer_prefix),
Expand Down Expand Up @@ -547,17 +544,8 @@ def remove_vocab_padding(tensor):
save_weight_torch(**starmap_arg)
toc = time.time()
print(f" weight save took {toc-tic}")

renamed_weight_dict = {}
if trt_model_type == 'GPTForCausalLM':
for key, val in weights_dict.items():
if 'layernorm' in key:
new_key = key.replace("pre_mlp_layernorm", "post_layernorm")
else:
new_key = key
renamed_weight_dict[new_key] = val

return renamed_weight_dict
print(f"{torch.cuda.current_device()} {pp_is_first} {pp_is_last}")
return weights_dict



Expand Down
6 changes: 1 addition & 5 deletions nemo/export/trt_llm/tensorrt_llm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,6 @@ def build_and_save_engine(
use_refit=True,
trt_model_type='LLaMAForCausalLM'
):
'''Minimum implementation of TRTLLM 0.9's unified builder api'''
logger.set_level('info')


try:
model_cls = getattr(tensorrt_llm.models, trt_model_type)
except:
Expand Down Expand Up @@ -400,7 +396,7 @@ def build_and_save_engine(
model = model_cls.from_config(model_config)
model = optimize_model(
model,
use_parallel_embedding=model_config.use_parallel_embedding,
use_parallel_embedding=True,
share_embedding_table=model_config.share_embedding_table,
)
preprocess_weights(model_weights, model_config)
Expand Down
15 changes: 13 additions & 2 deletions nemo/export/trt_llm/tensorrt_llm_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def create_gpt_session(
session_params.world_config,
engine_data)

def load_refit(engine_dir, device_ids):
def load_refit(engine_dir):
"""Loaded the compiled LLM model and run it.
It also supports running the TRT LLM model on multi-GPU.
Expand All @@ -344,7 +344,16 @@ def load_refit(engine_dir, device_ids):
tp_size = json_config.tensor_parallelism
pp_size = json_config.pipeline_parallelism
mp_size = tp_size*pp_size
world_config = WorldConfig.mpi(gpus_per_node=999, #Unused so just choose a big number to avoid asserts

# TRTLLM assumes rank < gpus_per_node but this is not true for multinode setups
# So hack around this using an arbitrarily big gpus_per_node to avoid asserts
gpus_per_node = 9999
mp_rank = tensorrt_llm.bindings.MpiComm.getRank()
device_ids = [
(i+torch.cuda.current_device()-mp_rank) for i in range(mp_size)]
print(f"{torch.cuda.current_device()} device_ids {device_ids}")

world_config = WorldConfig.mpi(gpus_per_node=gpus_per_node,
tensor_parallelism=tp_size,
pipeline_parallelism=pp_size,
device_ids=device_ids)
Expand All @@ -363,6 +372,8 @@ def load_refit(engine_dir, device_ids):
session_config = GptSessionConfig(max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_sequence_length=max_seq_len)
session_config.gen_micro_batch_size = max_batch_size
session_config.ctx_micro_batch_size = max_batch_size
session_config.kv_cache_config = KvCacheConfig(
max_tokens=max_seq_len*max_batch_size,
max_attention_window=max_seq_len
Expand Down

0 comments on commit 541532d

Please sign in to comment.