Added CPU async and GPU QKV manipulation
Sahil Jain committed Jan 23, 2024
1 parent 4bcf1cf commit f7aea2e
Showing 2 changed files with 28 additions and 16 deletions.
41 changes: 25 additions & 16 deletions nemo/export/trt_llm/nemo/convert.py
@@ -35,12 +35,18 @@ def gpu_map_location(storage, loc):
        raise ValueError(f"Not handled {loc}")


-def save_val(val, save_dict, key, tp_num=None):
+def save_val(val, save_dict, key, tp_num=None, use_gpu=False, storage_type=None):
    suffix = "bin" if tp_num is None else f"{tp_num}.bin"
    # AMMO modification, save to in-memory dict instead of dir.
    # Transpose linear layer weights to the correct shape.
    if len(val.shape) >= 2:
-        val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
+        if use_gpu:
+            # val = np.ascontiguousarray(torch_to_numpy(torch.transpose(val.reshape(val.shape[0], -1), 0, 1).cpu().to(storage_type)))
+            val = torch_to_numpy(torch.transpose(val.reshape(val.shape[0], -1), 0, 1).contiguous().cpu().to(storage_type))
+            print(f'gpu try {val.data.contiguous}')
+        else:
+            val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
+            print(val.data.contiguous)

    if type(save_dict) is dict:
        save_dict[f"model.{key}.{suffix}"] = val
@@ -49,9 +55,9 @@ def save_val(val, save_dict, key, tp_num=None):
        weights_dict[f"model.{key}.{suffix}"] = val


-def save_split(split_vals, dir, key, i, split_factor):
+def save_split(split_vals, dir, key, i, split_factor, use_gpu=False, storage_type=None):
    for j, val in enumerate(split_vals):
-        save_val(val, dir, key, i * split_factor + j)
+        save_val(val, dir, key, i * split_factor + j, use_gpu=use_gpu, storage_type=storage_type)


def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
@@ -188,8 +194,10 @@ def split_and_save_weight(
        vals = [val + 1.0 for val in vals]

    if torch.is_tensor(vals[0]):
-        vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals]
-
+        if not ("attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key):
+            vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals]
+        else:
+            vals = [val.to(torch.cuda.current_device()) for val in vals]
    if (
        "input_layernorm.weight" in key
        or "input_layernorm.bias" in key
@@ -317,23 +325,23 @@ def split_and_save_weight(

        # We first concat all sub weights per tp rank together.
        len_vals = len(vals)
-        val = np.concatenate(vals, axis=1)
+        val = torch.cat(vals, dim=1)

-        val = val.reshape(hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head)
+        val = val.reshape((hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head))

        # Split the QKV to separate variables.
-        qkv = np.split(val, [q_num, q_num + 1], axis=2)
+        qkv = torch.split(val, [q_num, 1, 1], dim=2)

-        q_split = np.split(qkv[0], split_factor, axis=1)
-        k_split = np.split(qkv[1], split_factor, axis=1)
-        v_split = np.split(qkv[2], split_factor, axis=1)
+        q_split = torch.split(qkv[0], qkv[0].size(1) // split_factor, dim=1)
+        k_split = torch.split(qkv[1], qkv[1].size(1) // split_factor, dim=1)
+        v_split = torch.split(qkv[2], qkv[2].size(1) // split_factor, dim=1)

        # Concatenate Q, K, and V together
-        split_vals = [np.concatenate([q_split[i].reshape(hidden_dim, -1), k_split[i].reshape(hidden_dim, -1), v_split[i].reshape(hidden_dim, -1)], axis=1) for i in range(split_factor)]

+        split_vals = [torch.cat([q_split[i].reshape((hidden_dim, -1)), k_split[i].reshape((hidden_dim, -1)), v_split[i].reshape((hidden_dim, -1))], dim=1) for i in range(split_factor)]
        if "attention.linear_qkv.weight" in key:
            key = key.replace("attention.linear_qkv.weight", "attention.query_key_value.weight")
-        save_split(split_vals, saved_dir, key, tp_rank, split_factor)
+        save_split(split_vals, saved_dir, key, tp_rank, split_factor, use_gpu=True, storage_type=storage_type)
        if save_int8:
            base_key = key.replace(".weight", "")
            vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode)
@@ -354,7 +362,8 @@ def split_and_save_weight(
    ):
        pass
    else:
-        print(f"[WARNING] {key} not handled by converter")
+        pass
+        # print(f"[WARNING] {key} not handled by converter")

    # Ammo modification
    global weights_dict
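For readers skimming the diff above: the QKV block now stays in torch and runs on the GPU instead of going through np.split/np.concatenate. Below is a minimal, self-contained sketch of that grouped-QKV split with toy dimensions; the shapes, the variable names, and the final synchronize are illustrative assumptions, not code from this repository.

import torch

# Toy sizes, for illustration only (not taken from a real NeMo checkpoint).
hidden_dim, num_kv_heads, q_per_kv, size_per_head = 8, 2, 3, 4
split_factor = 2  # number of tensor-parallel shards to produce
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fused QKV weight viewed as [hidden, kv_head, q_per_kv Q slots + 1 K + 1 V, head_dim],
# i.e. the shape the converter reshapes `val` into before splitting.
val = torch.randn(hidden_dim, num_kv_heads, q_per_kv + 2, size_per_head, device=device)

# Same torch ops the new path uses in place of np.split / np.concatenate.
q, k, v = torch.split(val, [q_per_kv, 1, 1], dim=2)
q_split = torch.split(q, q.size(1) // split_factor, dim=1)
k_split = torch.split(k, k.size(1) // split_factor, dim=1)
v_split = torch.split(v, v.size(1) // split_factor, dim=1)

# Re-fuse Q/K/V per shard and copy the result back to host memory.
shards = [
    torch.cat(
        [
            q_split[i].reshape(hidden_dim, -1),
            k_split[i].reshape(hidden_dim, -1),
            v_split[i].reshape(hidden_dim, -1),
        ],
        dim=1,
    ).to("cpu", non_blocking=True)
    for i in range(split_factor)
]
if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure the async device-to-host copies have finished
print([tuple(s.shape) for s in shards])  # each shard: (hidden_dim, (q_per_kv + 2) * (num_kv_heads // split_factor) * size_per_head)

The apparent intent of the change is to keep the reshape/split/concatenate work on the device and only download the finished per-rank shards, rather than transposing and splitting large host-side numpy copies.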
3 changes: 3 additions & 0 deletions nemo/export/trt_llm/nemo/nemo_ckpt_convert.py
@@ -608,6 +608,9 @@ def _handle_weights(src_key: str, dst_key: str, pp_src_idx: int, tensor_dim: int
                "share_weights" : False,
            }
        )
+    for starmap_arg in starmap_args:
+        starmap_arg["vals"] = [v.to("cpu", non_blocking=True) for v in starmap_arg["vals"]]
+
    starmap_args = tqdm(starmap_args, desc="saving weights")
    for starmap_arg in starmap_args:
        split_and_save_weight(**starmap_arg)
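The nemo_ckpt_convert.py change queues a non-blocking device-to-host copy for every weight before the serial save loop runs, which appears to be the "CPU async" part of the commit title. A small sketch of that pattern follows, with hypothetical names; note that non_blocking=True only truly overlaps the copy with host-side work when the destination is pinned memory, and a synchronization point is needed before the copied data is read.

import torch

def stage_to_cpu(tensors):
    """Hypothetical helper mirroring the pattern above: queue device-to-host
    copies without blocking the Python thread. With a pageable destination the
    copy may be effectively synchronous; pinned memory is needed for real overlap."""
    return [t.to("cpu", non_blocking=True) for t in tensors]

# Usage sketch: queue all copies first, synchronize once before the data is used.
if torch.cuda.is_available():
    gpu_weights = [torch.randn(1024, 1024, device="cuda") for _ in range(4)]
    cpu_weights = stage_to_cpu(gpu_weights)  # copies are enqueued on the CUDA stream
    torch.cuda.synchronize()                 # wait for them before serializing/reading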
