diff --git a/nemo/export/trt_llm/nemo/convert.py b/nemo/export/trt_llm/nemo/convert.py index 9b40ff59e29d..5558150bb899 100644 --- a/nemo/export/trt_llm/nemo/convert.py +++ b/nemo/export/trt_llm/nemo/convert.py @@ -35,12 +35,18 @@ def gpu_map_location(storage, loc): raise ValueError(f"Not handled {loc}") -def save_val(val, save_dict, key, tp_num=None): +def save_val(val, save_dict, key, tp_num=None, use_gpu=False, storage_type=None): suffix = "bin" if tp_num is None else f"{tp_num}.bin" # AMMO modification, save to in-memory dict instead of dir. # Transpose linear layer weights to the correct shape. if len(val.shape) >= 2: - val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) + if use_gpu: + # val = np.ascontiguousarray(torch_to_numpy(torch.transpose(val.reshape(val.shape[0], -1), 0, 1).cpu().to(storage_type))) + val = torch_to_numpy(torch.transpose(val.reshape(val.shape[0], -1), 0, 1).contiguous().cpu().to(storage_type)) + print(f'gpu try {val.data.contiguous}') + else: + val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) + print(val.data.contiguous) if type(save_dict) is dict: save_dict[f"model.{key}.{suffix}"] = val @@ -49,9 +55,9 @@ def save_val(val, save_dict, key, tp_num=None): weights_dict[f"model.{key}.{suffix}"] = val -def save_split(split_vals, dir, key, i, split_factor): +def save_split(split_vals, dir, key, i, split_factor, use_gpu=False, storage_type=None): for j, val in enumerate(split_vals): - save_val(val, dir, key, i * split_factor + j) + save_val(val, dir, key, i * split_factor + j, use_gpu=use_gpu, storage_type=storage_type) def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): @@ -188,8 +194,10 @@ def split_and_save_weight( vals = [val + 1.0 for val in vals] if torch.is_tensor(vals[0]): - vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals] - + if not ("attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key): + vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals] + else: + vals = [val.to(torch.cuda.current_device()) for val in vals] if ( "input_layernorm.weight" in key or "input_layernorm.bias" in key @@ -317,23 +325,23 @@ def split_and_save_weight( # We first concat all sub weights per tp rank together. len_vals = len(vals) - val = np.concatenate(vals, axis=1) + val = torch.cat(vals, dim=1) - val = val.reshape(hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head) + val = val.reshape((hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head)) # Split the QKV to separate variables. - qkv = np.split(val, [q_num, q_num + 1], axis=2) + qkv = torch.split(val, [q_num, 1, 1], dim=2) - q_split = np.split(qkv[0], split_factor, axis=1) - k_split = np.split(qkv[1], split_factor, axis=1) - v_split = np.split(qkv[2], split_factor, axis=1) + q_split = torch.split(qkv[0], qkv[0].size(1) // split_factor, dim=1) + k_split = torch.split(qkv[1], qkv[1].size(1) // split_factor, dim=1) + v_split = torch.split(qkv[2], qkv[2].size(1) // split_factor, dim=1) # Concatenate Q, K, and V together - split_vals = [np.concatenate([q_split[i].reshape(hidden_dim, -1), k_split[i].reshape(hidden_dim, -1), v_split[i].reshape(hidden_dim, -1)], axis=1) for i in range(split_factor)] - + split_vals = [torch.cat([q_split[i].reshape((hidden_dim, -1)), k_split[i].reshape((hidden_dim, -1)), v_split[i].reshape((hidden_dim, -1))], dim=1) for i in range(split_factor)] + if "attention.linear_qkv.weight" in key: key = key.replace("attention.linear_qkv.weight", "attention.query_key_value.weight") - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + save_split(split_vals, saved_dir, key, tp_rank, split_factor, use_gpu=True, storage_type=storage_type) if save_int8: base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode) @@ -354,7 +362,8 @@ def split_and_save_weight( ): pass else: - print(f"[WARNING] {key} not handled by converter") + pass + # print(f"[WARNING] {key} not handled by converter") # Ammo modification global weights_dict diff --git a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py index 743f014bc3bc..eecf5697e8b2 100644 --- a/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py +++ b/nemo/export/trt_llm/nemo/nemo_ckpt_convert.py @@ -608,6 +608,9 @@ def _handle_weights(src_key: str, dst_key: str, pp_src_idx: int, tensor_dim: int "share_weights" : False, } ) + for starmap_arg in starmap_args: + starmap_arg["vals"] = [v.to("cpu", non_blocking=True) for v in starmap_arg["vals"]] + starmap_args = tqdm(starmap_args, desc="saving weights") for starmap_arg in starmap_args: split_and_save_weight(**starmap_arg)