
Commit

cleanup resharding
Signed-off-by: root <worker@nvidia.com>
root committed May 8, 2024
1 parent df7494d commit fb108a1
Showing 2 changed files with 47 additions and 55 deletions.
7 changes: 4 additions & 3 deletions nemo/export/tensorrt_llm.py
@@ -283,7 +283,8 @@ def build(
mp_group = []
for idx in range(pp_size):
mp_group+=tp_groups[dp_rank + idx*dp_size]
device_ids = [i % gpus_per_node for i in mp_group]
# device_ids = [i % gpus_per_node for i in mp_group]
device_ids = mp_group

mapping = tensorrt_llm.Mapping(
world_size = tp_size*pp_size,
@@ -302,8 +303,8 @@ def build(
pp_rank {parallel_state.get_pipeline_model_parallel_rank()} -> {mapping.pp_rank}'''
)
print(f"{torch.distributed.get_rank()} color {dp_rank} rank {model_parallel_rank} nemo_mp_group {mp_group} {device_ids} ")
assert torch.cuda.current_device() == device_ids[model_parallel_rank]

# assert torch.cuda.current_device() == device_ids[model_parallel_rank]
model_config, weights = nemo_llm_model_to_model_config(
nemo_model=nemo_model,
tokenizer=self.tokenizer,
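A minimal stand-alone sketch of the device_ids change above, assuming (hypothetically) that tp_groups is a list of tensor-parallel groups of global ranks, ordered with the data-parallel index varying fastest and the pipeline stage slowest; the sizes are made up for illustration:

tp_size, pp_size, dp_size, gpus_per_node = 2, 2, 2, 4
tp_groups = [[0, 1], [2, 3], [4, 5], [6, 7]]        # one TP group per (dp, pp) pair
dp_rank = 0

mp_group = []
for idx in range(pp_size):
    mp_group += tp_groups[dp_rank + idx * dp_size]  # -> [0, 1, 4, 5]

local_ids = [i % gpus_per_node for i in mp_group]   # node-local ids: [0, 1, 0, 1]
device_ids = mp_group                               # global ranks:   [0, 1, 4, 5]

The two forms only coincide while mp_group stays on a single node; once it spans two 4-GPU nodes, the modulo collapses ranks 4 and 5 back onto 0 and 1.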
95 changes: 43 additions & 52 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -340,6 +340,8 @@ def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, res
pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
pp_group = parallel_state.get_pipeline_model_parallel_group()
pp_is_last = parallel_state.is_pipeline_last_stage()
pp_is_first = parallel_state.is_pipeline_first_stage()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if not vp_size: vp_size = 1

@@ -392,21 +394,19 @@ def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, res
if vp_size > 1: # consolidate params across model chunks
for idx, model_chunk in enumerate(nemo_model):
for key, val in model_chunk.state_dict().items():
if '_extra_state' in key:
continue
elif 'decoder.layers' in key:
key2 = rename_layer_num(key, get_layer_num(key) + idx*pp_size*layers_per_chunk)
tl_params[key2] = val
else:
model_level_params[key] = val
if torch.is_tensor(val):
if 'decoder.layers' in key:
key2 = rename_layer_num(key, get_layer_num(key) + idx*pp_size*layers_per_chunk)
tl_params[key2] = val
else:
model_level_params[key] = val
else:
for key, val in nemo_model.state_dict().items():
if '_extra_state' in key:
continue
elif 'decoder.layers' in key:
tl_params[key] = val
else:
model_level_params[key] = val
if torch.is_tensor(val):
if 'decoder.layers' in key:
tl_params[key] = val
else:
model_level_params[key] = val
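# Worked example of the chunk offset above (hypothetical sizes): with pp_size = 2 and
# layers_per_chunk = 4, a chunk-local key like 'decoder.layers.0.mlp...' coming from
# model chunk idx = 1 is shifted by idx * pp_size * layers_per_chunk = 1 * 2 * 4 = 8 and
# stored in tl_params as 'decoder.layers.8.mlp...'; keys from chunk idx = 0 keep their
# local numbers.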

if vp_size > 1 or reshard_model:
# gather layers across pp ranks
@@ -436,28 +436,35 @@ def convert_nemo_model(nemo_model, nemo_model_config, tokenizer_vocab_size, res
k: v for k, v in layer_params.items() if k.startswith("layers.")
}
for key, val in layer_params.items():
starmap_args.append(
{
starmap_args.append({
"key": key,
"val": val,
"config": export_config,
}
)
})

def broadcast_item(item, group, src_rank):
item = [item]
torch.distributed.broadcast_object_list(item, src_rank, group=group)
return item[0]
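# Usage sketch (assumes torch.distributed is already initialized and pp_group /
# pp_src_idx are the objects defined earlier): every rank in the group calls, e.g.,
#   have_tensor = broadcast_item(have_tensor, pp_group, pp_src_idx)
# torch.distributed.broadcast_object_list pickles the one-element list on pp_src_idx and
# overwrites it in place on the other ranks, so every PP rank returns the value that
# pp_src_idx held.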

# broadcast a tensor across the PP group and save it
def broadcast_save_weight(
def save_pp_weight(
src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=False):

if (not reshard_model) or (reshard_model and torch.distributed.get_rank() == pp_src_idx):
if torch.is_tensor(src_key_or_tensor):
tensor = src_key_or_tensor
have_tensor = False
if torch.distributed.get_rank() == pp_src_idx:
if isinstance(src_key_or_tensor, str):
tensor = model_level_params.get(src_key_or_tensor, None)
have_tensor = torch.is_tensor(tensor)
else:
tensor = model_level_params[src_key_or_tensor]
assert torch.is_tensor(src_key_or_tensor)
tensor = src_key_or_tensor
have_tensor = True

if reshard_model:
have_tensor = broadcast_item(have_tensor, pp_group, pp_src_idx)
if not have_tensor:
return

if reshard_model:
if torch.distributed.get_rank() == pp_src_idx:
@@ -472,28 +479,20 @@ def broadcast_save_weight(

temp_config = dict(export_config)
temp_config['transpose_weights'] = transpose_weights
starmap_args.append(
{
starmap_args.append({
"key": dst_key,
"val": tensor,
"config": temp_config,
}
)

})
# ----------------Convert Final Layernorm----------------
if torch.distributed.get_rank() == pp_last_rank or reshard_model:
broadcast_save_weight(
if pp_is_last or reshard_model:
save_pp_weight(
get_layer_name("final_layernorm.weight", transformer_layer_prefix),
"ln_f.weight",
pp_last_rank,
transpose_weights=True
)

has_final_layer_bias = get_layer_name("final_layernorm.bias", transformer_layer_prefix) in model_level_params
if reshard_model:
has_final_layer_bias = broadcast_item(has_final_layer_bias, pp_group, pp_last_rank)
if has_final_layer_bias:
broadcast_save_weight(
save_pp_weight(
get_layer_name("final_layernorm.bias", transformer_layer_prefix),
"ln_f.bias",
pp_last_rank,
@@ -515,39 +514,32 @@ def remove_vocab_padding(tensor):
torch.distributed.all_reduce(gathered_tensor, group=tp_group)
return gathered_tensor[:tokenizer_vocab_size]
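# Sketch of the idea (the earlier, hidden part of remove_vocab_padding is assumed to
# zero-fill a buffer sized for the TP-padded vocab and write this rank's embedding shard
# at its own row offset): the all_reduce above then sums those buffers across the TP
# group, reconstructing the full padded table on every rank, and the final slice drops
# the rows that only existed to make the vocab divisible by tp_size.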

if torch.distributed.get_rank() == pp_first_rank:
world_embed = model_level_params[get_layer_name("word_embedding", prefix)]
if tp_size > 1:
if pp_is_first or reshard_model:
world_embed = model_level_params.get(get_layer_name("word_embedding", prefix), None)
if tp_size > 1 and pp_is_first:
world_embed = remove_vocab_padding(world_embed)
else:
world_embed = None

if torch.distributed.get_rank() == pp_first_rank or reshard_model:
broadcast_save_weight(
save_pp_weight(
world_embed,
"transformer.vocab_embedding.weight",
pp_first_rank,
transpose_weights=False,
)

if torch.distributed.get_rank() == pp_last_rank:
lm_head = model_level_params[get_layer_name("output_layer", prefix)]
if tp_size > 1:
if pp_is_last or reshard_model:
lm_head = model_level_params.get(get_layer_name("output_layer", prefix), None)
if tp_size > 1 and pp_is_last:
lm_head = remove_vocab_padding(lm_head)

vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
tokenizer_vocab_size, tp_rank, tp_size)
lm_head = lm_head[vocab_start_index:vocab_end_index]
else:
lm_head = None
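# Worked example of the TP re-sharding above (hypothetical numbers; Megatron's
# VocabUtility requires the vocab size to divide evenly by tp_size): with
# tokenizer_vocab_size = 32000 and tp_size = 2, vocab_range_from_global_vocab_size
# returns (0, 16000) on tp_rank 0 and (16000, 32000) on tp_rank 1, so each TP rank keeps
# only its contiguous half of the unpadded lm_head rows.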

if torch.distributed.get_rank() == pp_last_rank or reshard_model:
broadcast_save_weight(
save_pp_weight(
lm_head,
"lm_head.weight",
pp_last_rank,
transpose_weights=False,
)

tic = time.time()
for starmap_arg in tqdm(starmap_args, desc="saving weights"):
save_weight_torch(**starmap_arg)
@@ -556,7 +548,6 @@ def remove_vocab_padding(tensor):
return weights_dict



def create_out_dir(args):
out_dir = Path(args.out_dir)
if not out_dir.exists():
