fix laternorm1p and cleanup
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
jiemingz committed May 10, 2024
1 parent 9f1a00f commit fc1bbf0
Showing 4 changed files with 69 additions and 69 deletions.
14 changes: 7 additions & 7 deletions nemo/export/tensorrt_llm.py
@@ -288,9 +288,8 @@ def build(
# So we manipulate TRTLLM to emulate a TP->PP single node setup
tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank)
device_ids = [
(i+torch.cuda.current_device()-mp_rank) % mp_size
(i+torch.cuda.current_device()-mp_rank)
for i in range(mp_size)]
assert device_ids[mp_rank] == torch.cuda.current_device()

mapping = tensorrt_llm.Mapping(
world_size = mp_size,
@@ -330,20 +329,21 @@ def build(
)
torch.distributed.barrier()
print(f"engine saved to {self.model_dir}")

if torch.cuda.current_device() == 0:
with open(os.path.join(self.model_dir, 'config.json'),
"w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)
cfg_path = Path(os.path.join(self.model_dir, 'config.json'))
if not cfg_path.exists():
with open(cfg_path, "w", encoding="utf-8") as f:
json.dump(engine.config.to_dict(), f, indent=4)

print_mem("post build_and_save_engine")

self.model_runner, self.session_params = load_refit(
engine_dir=self.model_dir,
device_ids=device_ids)

print(f"device: {origdev} {torch.cuda.current_device()}")


def refit(
self,
nemo_model,
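As a side note on the config.json hunk above: the new code adds an existence check so an already-written engine config is not rewritten. A minimal sketch of that guarded-write pattern follows; the helper name write_engine_config_once and the plain-dict argument are illustrative assumptions, not part of the commit.

    import json
    import os
    from pathlib import Path

    def write_engine_config_once(model_dir: str, config_dict: dict) -> None:
        # Hypothetical helper: create config.json only if it is not already
        # present, mirroring the cfg_path.exists() guard added in the diff.
        cfg_path = Path(os.path.join(model_dir, "config.json"))
        if not cfg_path.exists():
            with open(cfg_path, "w", encoding="utf-8") as f:
                json.dump(config_dict, f, indent=4)

    # Example: only the first call writes the file; later calls are no-ops.
    # write_engine_config_once("/tmp/trt_engine", {"builder_config": {}})

The exists() check in this sketch is a convenience guard against rewriting the file, not a cross-process lock.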
101 changes: 47 additions & 54 deletions nemo/export/trt_llm/nemo/convert.py
@@ -38,7 +38,7 @@ def gpu_map_location(storage, loc):
raise ValueError(f"Not handled {loc}")


def save_val(val, dir, key, tp_num=None):
def tensor(val, dir, key, tp_num=None):
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
# Transpose linear layer weights to the correct shape.
if len(val.shape) >= 2:
@@ -49,7 +49,7 @@ def save_val(val, dir, key, tp_num=None):

def save_split(split_vals, dir, key, i, split_factor):
for j, val in enumerate(split_vals):
save_val(val, dir, key, i * split_factor + j)
tensor(val, dir, key, i * split_factor + j)


def save_expert_split(split_vals, dir, key, i, split_factor):
@@ -165,7 +165,7 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o

if tp_rank == 0:
for save_key in saved_keys_once:
save_val(vals[save_key], dir, f"{base_key}.{save_key}")
tensor(vals[save_key], dir, f"{base_key}.{save_key}")


# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head
@@ -217,7 +217,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
elif "attention.linear_proj.bias" in key:
key = key.replace("attention.linear_proj.bias", "attention.dense.bias")
if tp_rank == 0:
save_val(vals[0], saved_dir, key)
tensor(vals[0], saved_dir, key)

elif (
"attention.dense.weight" in key
@@ -394,8 +394,28 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t

#Similar to split_save_weight but done on GPU for performance
@torch.no_grad()
def save_weight_torch(key, val, config):
def save(key, tensor):
def save_weight_torch(key, val, config, weight_type):
num_layers = config["num_layers"]
storage_type = config["storage_type"]
split_gated_activation = config["split_gated_activation"]
num_attention_heads = config["num_attention_heads"]
tp_size = config["tp_size"]
tp_rank = config["tp_rank"]
num_kv_heads = config["num_kv_heads"]
move_to_cpu = config["move_to_cpu"]
save_dict = config["save_dict"]

def save(key, tensor, add_prefix=True):
assert torch.is_tensor(tensor)
if add_prefix:
key = f"transformer.{key}"

if len(tensor.shape) >= 2:
tensor = tensor.reshape(tensor.shape[0], -1)
tensor = torch.transpose(tensor, 0 , 1)
tensor = tensor.detach().contiguous()
tensor = tensor.to(storage_type)

if move_to_cpu:
if key not in save_dict:
cpu_copy = torch.empty(
@@ -408,15 +428,8 @@ def save(key, tensor):
else:
save_dict[key] = tensor.cuda()

def save_tranpose(val, key, shared=False):
assert torch.is_tensor(val)
key = f"transformer.{key}"

if len(val.shape) >= 2:
val = val.reshape(val.shape[0], -1)
val = torch.transpose(val, 0 , 1)
val = val.detach().contiguous()
save(key, val)
if config.get("transpose_weights", False) and val.ndim == 2:
val = val.T

if "self_attention" in key:
key = key.replace("self_attention", "attention")
@@ -425,47 +438,26 @@ def save_tranpose(val, key, shared=False):
if "mlp.linear_fc1.layer_norm_weight" in key:
key = key.replace("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight")

num_layers = config["num_layers"]
storage_type = config["storage_type"]
split_gated_activation = config["split_gated_activation"]
num_attention_heads = config["num_attention_heads"]
tp_size = config["tp_size"]
tp_rank = config["tp_rank"]
num_kv_heads = config["num_kv_heads"]
move_to_cpu = config["move_to_cpu"]
save_dict = config["save_dict"]

if config.get("transpose_weights", False) and val.ndim == 2:
val = val.T
if "layernorm.weight" in key and config.get("apply_layernorm_1p", False):
val = val + 1.0
gpu_val = val.to(storage_type)

if (
"input_layernorm.weight" in key
or "input_layernorm.bias" in key
or "pre_mlp_layernorm.weight" in key
if weight_type == 'layernorm_weight':
if config.get("apply_layernorm_1p", False):
val = val.float() + 1.0
save(key, val)
elif (
"input_layernorm.bias" in key
or "pre_mlp_layernorm.bias" in key
or "attention.dense.bias" in key
or "attention.linear_proj.bias" in key
or "post_attention_layernorm.weight" in key
or "post_attention_layernorm.bias" in key
or "post_self_attn_layernorm.weight" in key
or "mlp.dense_4h_to_h.bias" in key
or "mlp.linear_fc2.bias" in key
or "ln_f.weight" in key
or "ln_f.bias" in key
or "vocab_embedding" in key
):
if "post_self_attn_layernorm.weight" in key:
key = key.replace("post_self_attn_layernorm.weight", "post_attention_layernorm.weight")
elif "mlp.linear_fc2.bias" in key:
if "mlp.linear_fc2.bias" in key:
key = key.replace("mlp.linear_fc2.bias", "mlp.dense_4h_to_h.bias")
elif "attention.linear_proj.bias" in key:
key = key.replace("attention.linear_proj.bias", "attention.dense.bias")
elif "post_attention_layernorm.weight" in key:
key = key.replace("post_attention_layernorm.weight", "post_layernorm.weight")

save_tranpose(gpu_val, key, shared=True)
save(key, val)

elif (
"attention.dense.weight" in key
@@ -479,7 +471,7 @@ def save_tranpose(val, key, shared=False):
key = key.replace("attention.linear_proj.weight", "attention.dense.weight")
elif "mlp.linear_fc2.weight" in key:
key = key.replace("mlp.linear_fc2.weight", "mlp.proj.weight")
save_tranpose(gpu_val, key)
save(key, val)

elif (
"mlp.dense_h_to_4h.weight" in key
@@ -488,15 +480,15 @@ def save_tranpose(val, key, shared=False):
or "mlp.linear_fc1.bias" in key
):
if split_gated_activation:
val, gate = torch.chunk(gpu_val, 2, axis=-1)
val, gate = torch.chunk(val, 2, axis=-1)

if "mlp.linear_fc1" in key:
key = key.replace("mlp.linear_fc1", "mlp.fc")
save_tranpose(val, key)
save(key, val)

if split_gated_activation:
key = key.replace("mlp.fc", "mlp.gate")
save_tranpose(gate, key)
save(key, gate)

elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key:
if "attention.linear_qkv.weight" in key:
@@ -506,19 +498,20 @@ def save_tranpose(val, key, shared=False):
size_per_head = hidden_dim // num_attention_heads
q_num = num_attention_heads // num_kv_heads

gpu_val = gpu_val.reshape(hidden_dim, num_kv_heads // tp_size, q_num + 2, size_per_head)
val = val.reshape(hidden_dim, num_kv_heads // tp_size, q_num + 2, size_per_head)

# Split the QKV to separate variables.
#[qqqqkkvv] - > [qqqq,kk,vv]
qkv = torch.split(gpu_val, [q_num, 1, 1], dim=2)
qkv = torch.split(val, [q_num, 1, 1], dim=2)
split_vals = torch.concatenate([
qkv[0].reshape(hidden_dim, -1),
qkv[1].reshape(hidden_dim, -1),
qkv[2].reshape(hidden_dim, -1)
], dim=1)
save_tranpose(split_vals, key)
elif "vocab_embedding" in key or "lm_head.weight" in key:
save(key, gpu_val)
save(key, split_vals)

elif "lm_head.weight" in key:
save(key, val, add_prefix=False)
else:
raise RuntimeError(f"{key} not handled by NeMo->TRTLLM converter!")

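For readers following the converter changes above, two pieces of the new save path can be sketched in isolation: the layernorm1p correction (NeMo's "layernorm1p" stores layer-norm weights zero-centered, so 1.0 is added back on export, as in the apply_layernorm_1p branch of the diff) and the fused-QKV reshuffle (the interleaved [qqqq kk vv] layout regrouped into contiguous q, k and v blocks). The sketch below is illustrative only; the standalone function names and signatures are assumptions, not the commit's API.

    import torch

    def export_layernorm_weight(val, apply_layernorm_1p):
        # Zero-centered "layernorm1p" weights are stored as (gamma - 1); add 1.0
        # back so the exported tensor behaves like a standard LayerNorm weight.
        if apply_layernorm_1p:
            val = val.float() + 1.0
        return val

    def split_fused_qkv(val, num_attention_heads, num_kv_heads, tp_size):
        # Assumes val is already laid out as (hidden_dim, fused_qkv_width), as the
        # diff suggests. Reshape into (hidden, kv_heads/tp, q_num + 2, head_dim),
        # split the group axis into q/k/v, then re-concatenate column-wise.
        hidden_dim = val.shape[0]
        size_per_head = hidden_dim // num_attention_heads
        q_num = num_attention_heads // num_kv_heads
        val = val.reshape(hidden_dim, num_kv_heads // tp_size, q_num + 2, size_per_head)
        q, k, v = torch.split(val, [q_num, 1, 1], dim=2)
        return torch.concatenate(
            [q.reshape(hidden_dim, -1), k.reshape(hidden_dim, -1), v.reshape(hidden_dim, -1)],
            dim=1,
        )

    # Example: zero-centered weights become ones; an (8 x 16) fused QKV tensor keeps
    # its shape but is reordered into contiguous q, k and v blocks.
    assert torch.equal(export_layernorm_weight(torch.zeros(8), True), torch.ones(8))
    fused = torch.randn(8, 16)  # hidden=8, 4 attention heads, 2 KV heads, tp_size=1
    assert split_fused_qkv(fused, num_attention_heads=4, num_kv_heads=2, tp_size=1).shape == (8, 16)
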
20 changes: 13 additions & 7 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -384,7 +384,8 @@ def convert_nemo_model(
"storage_type": storage_type,
"move_to_cpu": cpu,
"save_dict": weights_dict,
"tp_rank": tp_rank
"tp_rank": tp_rank,
"weight_type" : None,
}

tl_params = {}
@@ -447,6 +448,7 @@ def convert_nemo_model(
"key": key,
"val": val,
"config": export_config,
"weight_type" : 'layernorm_weight' if 'layernorm.weight' in key else None
})

def broadcast_item(item, group, src_rank):
@@ -456,7 +458,7 @@ def broadcast_item(item, group, src_rank):

#broadcast a tensor across PP group and save it
def save_pp_weight(
src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=False):
src_key_or_tensor, dst_key, pp_src_idx, transpose_weights=False, weight_type=None):

have_tensor = False
if torch.distributed.get_rank() == pp_src_idx:
@@ -486,24 +488,27 @@ def save_pp_weight(

temp_config = dict(export_config)
temp_config['transpose_weights'] = transpose_weights
temp_config['weight_type'] = weight_type
starmap_args.append({
"key": dst_key,
"val": tensor,
"config": temp_config,
"weight_type": weight_type
})
# ----------------Convert Final Layernorm----------------
if pp_is_last or reshard_model:
save_pp_weight(
get_layer_name("final_layernorm.weight", transformer_layer_prefix),
"ln_f.weight",
pp_last_rank,
transpose_weights=True
transpose_weights=True,
weight_type='layernorm_weight'
)
save_pp_weight(
get_layer_name("final_layernorm.bias", transformer_layer_prefix),
"ln_f.bias",
pp_last_rank,
transpose_weights=True
transpose_weights=True,
)

# ----------------Convert Embeddings----------------
@@ -525,9 +530,9 @@ def remove_vocab_padding(tensor):
world_embed = remove_vocab_padding(world_embed)
save_pp_weight(
world_embed,
"transformer.vocab_embedding.weight",
"vocab_embedding.weight",
pp_first_rank,
transpose_weights=False,
transpose_weights=True,
)

if pp_is_last or reshard_model:
@@ -542,7 +547,7 @@ def remove_vocab_padding(tensor):
lm_head,
"lm_head.weight",
pp_last_rank,
transpose_weights=False,
transpose_weights=True,
)

tic = time.time()
@@ -559,6 +564,7 @@ def remove_vocab_padding(tensor):
else:
new_key = key
renamed_weight_dict[new_key] = val

return renamed_weight_dict


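The thread running through this file's changes is that each weight now carries a weight_type tag so the converter knows when to apply the layernorm1p correction. A rough sketch of the tagging and the per-call config copy; the helper names classify_weight and make_call_config are hypothetical, only the dict keys come from the diff.

    def classify_weight(key):
        # Tag layer-norm weights so save_weight_torch can apply the +1.0
        # layernorm1p correction; everything else is left untyped.
        return 'layernorm_weight' if 'layernorm.weight' in key else None

    def make_call_config(export_config, transpose_weights, weight_type=None):
        # Shallow-copy the shared export_config and record per-weight options,
        # as save_pp_weight does in the diff.
        cfg = dict(export_config)
        cfg['transpose_weights'] = transpose_weights
        cfg['weight_type'] = weight_type
        return cfg

    # Example: the final layernorm is tagged, an embedding table is not.
    assert classify_weight('final_layernorm.weight') == 'layernorm_weight'
    assert classify_weight('word_embeddings.weight') is None
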
3 changes: 2 additions & 1 deletion nemo/export/trt_llm/tensorrt_llm_run.py
@@ -344,11 +344,12 @@ def load_refit(engine_dir, device_ids):
tp_size = json_config.tensor_parallelism
pp_size = json_config.pipeline_parallelism
mp_size = tp_size*pp_size
world_config = WorldConfig.mpi(gpus_per_node=mp_size,
world_config = WorldConfig.mpi(gpus_per_node=999, #Unused so just choose a big number to avoid asserts
tensor_parallelism=tp_size,
pipeline_parallelism=pp_size,
device_ids=device_ids)


assert torch.cuda.current_device() == world_config.device
engine_filename = json_config.engine_filename(world_config)
serialize_path = Path(engine_dir) / engine_filename
