From f132876c41915483d07e6424addc9b5e4857ec4c Mon Sep 17 00:00:00 2001 From: billishyahao Date: Sat, 10 Aug 2024 23:45:27 +0800 Subject: [PATCH] [LLaMa] Adding support converting checkpoint from mds to hf (#432) * add support converting checkpoint from hf to mds * Fix PP issue * update --- .../finetune_hf_llama/ds_config.json | 8 +- .../finetune_hf_llama/finetune_llama.sh | 23 +- megatron/global_vars.py | 1 + tools/hf2megads_weight_converter.py | 306 +++++++++++++++--- 4 files changed, 280 insertions(+), 58 deletions(-) diff --git a/examples_deepspeed/finetune_hf_llama/ds_config.json b/examples_deepspeed/finetune_hf_llama/ds_config.json index 9c0b332473..85f439ce47 100755 --- a/examples_deepspeed/finetune_hf_llama/ds_config.json +++ b/examples_deepspeed/finetune_hf_llama/ds_config.json @@ -1,11 +1,5 @@ { "train_batch_size" : 256, "train_micro_batch_size_per_gpu": 16, - "steps_per_print": 100, - "zero_optimization": { - "stage": 0 - }, - "bf16": { - "enabled": true - } + "steps_per_print": 1 } diff --git a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh index c48ea11b93..ab8bfdf419 100644 --- a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh +++ b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh @@ -1,8 +1,8 @@ DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json -DATASET_PATH=./alpaca_data.json +DATASET_PATH=./examples_deepspeed/finetune_hf_llama/alpaca_data.json # dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json -HF_LLAMA_PATH=/data/llama-7b/ +HF_LLAMA_PATH=/data/llama-2-7b-hf/ # weights link: https://huggingface.co/huggyllama/llama-7b MICRO_BATCH_SIZE=16 @@ -44,11 +44,20 @@ cat < $DS_CONFIG EOT -covert_args="deepspeed tools/hf2megads_weight_converter.py \ +covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \ --hf-ckpt-num-shards 2 \ ---origin-hf-ckpt-dir $HF_LLAMA_PATH \ +--hf-ckpt-dir $HF_LLAMA_PATH \ +--load-mode auto \ --save $MEGA_DS_LLAMA_PATH" +covert_mds2hf_args="deepspeed tools/hf2megads_weight_converter.py \ +--hf-ckpt-num-shards 2 \ +--hf-ckpt-dir $HF_LLAMA_PATH \ +--load-mode auto \ +--to-hf-ckpt \ +--load $MEGA_DS_LLAMA_PATH \ +--save $HF_LLAMA_PATH'-hf-out' " + finetune_args="deepspeed finetune_llama.py \ --load $MEGA_DS_LLAMA_PATH" @@ -98,8 +107,10 @@ comm_args="--tensor-model-parallel-size $TP \ --no-gradient-accumulation-fusion \ --repeated-dataloader" -if [ "$1" = "convert" ]; then - task_args="$covert_args" +if [ "$1" = "convert_hf2mds" ]; then + task_args="$covert_hf2mds_args" +elif [ "$1" = "convert_mds2hf" ]; then + task_args="$covert_mds2hf_args" else task_args="$finetune_args" fi diff --git a/megatron/global_vars.py b/megatron/global_vars.py index 3f9d6fd66b..ccd0a4e21e 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -175,6 +175,7 @@ def _set_wandb_writer(args): 'project or experiment name provided, ' 'therefore WANDB logs will be written ' 'according to random generated project or experiment name.', flush=True) + return try: import wandb diff --git a/tools/hf2megads_weight_converter.py b/tools/hf2megads_weight_converter.py index bfbde1fd05..12468963c5 100755 --- a/tools/hf2megads_weight_converter.py +++ b/tools/hf2megads_weight_converter.py @@ -3,9 +3,11 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch.distributed from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import print_rank_0, get_tokenizer, get_args from megatron.core import mpu +from megatron.core import tensor_parallel from megatron.core.utils import divide from megatron.model import GPTModelPipe, Float16Module from megatron.utils import unwrap_model @@ -13,20 +15,30 @@ from megatron.arguments import core_transformer_config_from_args from megatron.initialize import initialize_megatron from megatron.optimizer import get_megatron_optimizer -from megatron.checkpointing import save_checkpoint +from megatron.checkpointing import save_checkpoint, load_checkpoint from megatron.training import get_optimizer_param_scheduler from deepspeed.runtime.utils import see_memory_usage import deepspeed +import copy +from pathlib import Path + def add_extra_args(parser): """Text generation arguments.""" group = parser.add_argument_group(title='hf2mega') - group.add_argument("--hf-ckpt-num-shards", type=int, help='num of llama ckpt.') - group.add_argument("--origin-hf-ckpt-dir", + group.add_argument("--hf-ckpt-dir", type=str, default="", - help="the original path of the llama-hf ckpt") + help="the llama-hf ckpt") + group.add_argument("--hf-ckpt-num-shards", type=int, default=-1, help='num of llama ckpt.') + group.add_argument("--load-mode", type=str, + default=None, + choices=['torchbin', 'safetensor', 'auto'], + help="load ckpt format: pytorch.bin or model.safetensor or auto.") + group.add_argument("--to-hf-ckpt", action="store_true", + help="by default convert from hf to megads" + "if set, convert reversely from megads to hf ckpt.") return parser @@ -55,6 +67,49 @@ def load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards): return loaded +def load_and_print_hf_weight_from_safetensor(hf_ckpt_dir, hf_ckpt_num_of_shards): + from safetensors import safe_open + # Optimization point: We can selectively load specific 'shared' data to reduce CPU memory usage. + hf_model = {} + print_rank_0( + f"----------------------------hf weight list----------------------------") + + for wid in range(1, hf_ckpt_num_of_shards + 1): + if hf_ckpt_num_of_shards == 1: + ckpt_path = f"{hf_ckpt_dir}/model.safetensors" + else: + ckpt_path = f"{hf_ckpt_dir}/model-{wid:05d}-of-{hf_ckpt_num_of_shards:05d}.safetensors" + + with safe_open(ckpt_path, framework="pt", device="cpu") as f: + for k in f.keys(): + print_rank_0(f"name: {k}, shape: {f.get_tensor(k).shape}") + assert k not in hf_model + hf_model[k] = f.get_tensor(k).clone() + + return hf_model + + +def load_and_print_hf_weight_auto(hf_ckpt_dir, no_init=True): + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.modeling_utils import no_init_weights + + if no_init: + hf_config = AutoConfig.from_pretrained(hf_ckpt_dir, trust_remote_code=True) + with no_init_weights(): + hf_model = AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True, torch_dtype=torch.bfloat16) + else: + hf_model = {} + hf_auto_model = AutoModelForCausalLM.from_pretrained(hf_ckpt_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) + print_rank_0( + f"----------------------------hf weight list----------------------------") + + for name, param in hf_auto_model.named_parameters(): + hf_model[name] = param.clone() + print_rank_0(name) + + return hf_model + + def print_distinct_weights(model): print_rank_0( f"----------------------------mega-ds weight list----------------------------") @@ -70,16 +125,19 @@ def print_distinct_weights(model): class refactor: - def __init__(self, model, loaded, args, config): + def __init__(self, ds_model, hf_model, args, config): tokenizer = get_tokenizer() # align layer number - self.model = model - self.loaded = loaded + self.ds_model = ds_model + self.hf_model = hf_model + self.hf_dict = {} # for handling pp case when converting mds => hf self.config = config self.offset_num = 2 self.mega_emb_wnum = 1 self.mega_norm_wnum = args.num_layers + 2 + self.num_attention_heads = args.num_attention_heads + self.num_key_value_heads = args.num_key_value_heads self.mega_lm_head_wnum = self.mega_norm_wnum + 1 self.token_vocab = tokenizer.vocab_size self.padded_vocab_size = args.padded_vocab_size @@ -95,7 +153,7 @@ def _embedding_refactor(self, pname, p): hf_name = "lm_head.weight" elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight": hf_name = "model.embed_tokens.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] assert hf_w.shape[0] == self.token_vocab per_partition_vocab_size, start_index, end_index = compute_partition_range( self.padded_vocab_size, self.tp_rank, self.tp_size) @@ -112,24 +170,28 @@ def _embedding_refactor(self, pname, p): ) return new_w + + + def _direct_refactor(self, pname, p, hf_layer=None, subname=None): if pname == f"{self.mega_norm_wnum}.weight": hf_name = "model.norm.weight" elif subname in ["input_layernorm.weight", "post_attention_layernorm.weight"]: hf_name = f"model.layers.{hf_layer}.{subname}" - new_w = hf_w = self.loaded[hf_name] + new_w = hf_w = self.hf_model[hf_name] self.record_mapping_info( f"mega-ds:{pname,p.data.shape}<--hf{hf_name,} {hf_w.shape}") return new_w + def _qkv_refactor(self, pname, p, hf_layer): hf_wq_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight" hf_wk_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight" hf_wv_name = f"model.layers.{hf_layer}.self_attn.v_proj.weight" - wq = self.loaded[hf_wq_name] - wk = self.loaded[hf_wk_name] - wv = self.loaded[hf_wv_name] + wq = self.hf_model[hf_wq_name] + wk = self.hf_model[hf_wk_name] + wv = self.hf_model[hf_wv_name] hidden_size = wq.shape[0] per_partition_size, start_index, end_index = compute_partition_range( @@ -159,8 +221,8 @@ def _qkv_refactor(self, pname, p, hf_layer): def _mlphto4h_dense_refactor(self, pname, p, hf_layer): hf_w_gate_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" hf_w_up_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" - w_gate = self.loaded[hf_w_gate_name] - w_up = self.loaded[hf_w_up_name] + w_gate = self.hf_model[hf_w_gate_name] + w_up = self.hf_model[hf_w_up_name] hidden_size = w_gate.shape[0] per_partition_size, start_index, end_index = compute_partition_range( @@ -184,7 +246,7 @@ def _attn_dense_refactor(self, pname, p, hf_layer, subname): else: hf_name = f"model.layers.{hf_layer}.mlp.down_proj.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] hidden_size = hf_w.shape[1] per_partition_size, start_index, end_index = compute_partition_range( hidden_size, self.tp_rank, self.tp_size) @@ -200,7 +262,7 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname): hf_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" else: hf_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" - hf_w = self.loaded[hf_name] + hf_w = self.hf_model[hf_name] hidden_size = hf_w.shape[0] per_partition_size, start_index, end_index = compute_partition_range( hidden_size, self.tp_rank, self.tp_size) @@ -212,10 +274,11 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname): ) return new_w - def refactor(self): + def transform_from_hf_to_megds(self): assert self.is_refactored == False new_w = None - for pname, p in self.model.named_parameters(): + for pname, p in self.ds_model.named_parameters(): + if pname in [ f"{self.mega_emb_wnum}.word_embeddings.weight", f"{self.mega_lm_head_wnum}.lm_head.weight" @@ -253,6 +316,123 @@ def refactor(self): new_w = None self.is_refactored = True + + def _embedding_refactor_to_hf(self, pname, ds_w): + if pname == f"{self.mega_lm_head_wnum}.lm_head.weight": + hf_w = self.hf_model.lm_head.weight + hf_w_name = "lm_head.weight" + elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight": + hf_w = self.hf_model.model.embed_tokens.weight + hf_w_name = "model.embed_tokens.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank[:hf_w.shape[0], :]) + + def _direct_refactor_to_hf(self, pname, ds_w, hf_layer=None, subname=None): + if pname in [f"{self.mega_norm_wnum}.weight"]: + hf_w = self.hf_model.model.norm.weight + hf_w_name = "model.norm.weight" + elif subname in ["input_layernorm.weight"]: + hf_w = self.hf_model.model.layers[hf_layer].input_layernorm.weight + hf_w_name = f"model.layers.{hf_layer}.input_layernorm.weight" + elif subname in ["post_attention_layernorm.weight"]: + hf_w = self.hf_model.model.layers[hf_layer].post_attention_layernorm.weight + hf_w_name = f"model.layers.{hf_layer}.post_attention_layernorm.weight" + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w) + + def _attn_dense_refactor_to_hf(self, pname, ds_w, hf_layer, subname): + if subname == "self_attention.dense.weight": + hf_w = self.hf_model.model.layers[hf_layer].self_attn.o_proj.weight + hf_w_name = f"model.layers.{hf_layer}.self_attn.o_proj.weight" + elif subname == "mlp.dense_4h_to_h.weight": + hf_w = self.hf_model.model.layers[hf_layer].mlp.down_proj.weight + hf_w_name = f"model.layers.{hf_layer}.mlp.down_proj.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_last_dim(ds_w) + + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank) + + def _mlphto4h_dense_refactor_to_hf(self, pname, ds_w, hf_layer): + hf_g_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight" + hf_u_name = f"model.layers.{hf_layer}.mlp.up_proj.weight" + + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + ds_w_shape = ds_w_all_rank.shape + ds_w_all_rank = ds_w_all_rank.reshape(self.tp_size, 2, -1, ds_w_shape[-1]) + self.hf_dict[hf_g_name] = copy.deepcopy(ds_w_all_rank[:, 0, :, :].reshape(-1, ds_w_shape[-1])) + self.hf_dict[hf_u_name] = copy.deepcopy(ds_w_all_rank[:, 1, :, :].reshape(-1, ds_w_shape[-1])) + + + def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer): + with torch.no_grad(): + ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) + + hf_q = self.hf_model.model.layers[hf_layer].self_attn.q_proj.weight + hf_k = self.hf_model.model.layers[hf_layer].self_attn.k_proj.weight + hf_v = self.hf_model.model.layers[hf_layer].self_attn.v_proj.weight + hf_q_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight" + hf_k_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight" + hf_v_name = f"model.layers.{hf_layer}.self_attn.v_proj.weight" + oldshape = hf_q.shape + hidden_size = oldshape[-1] + hidden_size_per_attention_head = divide(hidden_size, + self.config.num_attention_heads) + num_attention_heads_per_partition = divide(self.config.num_attention_heads, + self.tp_size) + newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size) + ds_w_out = ds_w_all_rank.reshape(*newshape) + self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1])) + self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1])) + self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1])) + + + def transform_from_megads_to_hf(self): + use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False + + for pname, p in self.ds_model.named_parameters(): + if pname in [ + f"{self.mega_emb_wnum}.word_embeddings.weight", + f"{self.mega_lm_head_wnum}.lm_head.weight", + ]: + self._embedding_refactor_to_hf(pname, p) + elif pname in [ + f"{self.mega_norm_wnum}.weight", + ]: + self._direct_refactor_to_hf(pname, p) + else: + mobj = self.decoder_pat.match(pname) + layer_num = int(mobj.group(1)) + subname = mobj.group(2) + hf_layer = layer_num - self.offset_num + if subname in ["self_attention.query_key_value.weight"]: + if not use_gqa: + self._qkv_refactor_to_hf(pname, p, hf_layer) + else: + #TODO(billishyahao): Not impl yet ... + assert False + elif subname in ["mlp.dense_h_to_4h.weight"]: + self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer) + elif subname in [ + "self_attention.dense.weight", + "mlp.dense_4h_to_h.weight" + ]: + self._attn_dense_refactor_to_hf(pname, p, hf_layer, subname) + elif subname in [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + ]: + self._direct_refactor_to_hf(pname, p, hf_layer, subname) + else: + print(f"Unrecognized weight type: {pname}") + raise ValueError(f"Unrecognized weight type: {pname}") + self.is_refactored = True + def record_mapping_info(self, record_msg): self.refactor_weight_list.append(record_msg) @@ -272,7 +452,18 @@ def inorder_show_record(self): torch.distributed.barrier() -def convert_hf_to_mega_ds(): +def load_hf_weights(args, no_init): + if args.load_mode == 'torchbin': + assert no_init == False, "only work with init" + return load_and_print_hf_weight(args.hf_ckpt_dir, args.hf_ckpt_num_shards) + elif args.load_mode == 'safetensor': + assert no_init == False, "only work with init" + return load_and_print_hf_weight_from_safetensor(args.hf_ckpt_dir, args.hf_ckpt_num_shards) + elif args.load_mode == 'auto': + return load_and_print_hf_weight_auto(args.hf_ckpt_dir, no_init) + + +def convert_ckpt(): """Build the model.""" args = get_args() print_rank_0(f'building model ...') @@ -286,49 +477,74 @@ def convert_hf_to_mega_ds(): enabled=args.zero_stage == 3, mpu=mpu): if args.deepspeed and not args.no_pipeline_parallel: - model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True) + ds_model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True) else: raise NotImplementedError("Not implemented") see_memory_usage(f"After Building Model", force=True) if torch.distributed.get_rank() < 2: - print(f"{torch.distributed.get_rank()} {model}") - - # load and initialize HF weight dict - # print hf weights list & mega-ds weights list - hf_ckpt_dir = args.origin_hf_ckpt_dir - hf_ckpt_num_of_shards = args.hf_ckpt_num_shards - loaded = load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards) - print_distinct_weights(model) - - # refactor weight from hf to mega-ds - - cur_refactor = refactor(model, loaded, args, config) - cur_refactor.refactor() - cur_refactor.inorder_show_record() + print(f"{torch.distributed.get_rank()} {ds_model}") - del loaded + # 'torchbin', 'safetensor', 'auto' + hf_model = load_hf_weights(args, no_init=args.to_hf_ckpt) - unwrapped_model = unwrap_model([model], (torchDDP, LocalDDP, Float16Module)) - optimizer = get_megatron_optimizer(unwrapped_model) - opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + # print_distinct_weights(hf_model) #init model and save print_rank_0(f"before deepspeed init") ds_engine, _, _, _ = deepspeed.initialize( - model=model, - optimizer=optimizer, + model=ds_model, + optimizer=None, args=args, - lr_scheduler=opt_param_scheduler, + lr_scheduler=None, mpu=mpu if args.no_pipeline_parallel else None) print_rank_0(f"after deepspeed init") - print_rank_0(f"mega-ds checkpoint will be saved in {args.save}") - save_checkpoint(0, [ds_engine], optimizer, opt_param_scheduler) - print_rank_0(f"save checkpoint completed") + if args.to_hf_ckpt: + load_checkpoint([ds_engine], None, None, load_only_weights=True) + print_rank_0(f"completed to load deepspeed actual checkpoint") + + # refactor weight from hf to mega-ds and vice versa + + cur_refactor = refactor(ds_model, hf_model, args, config) + if args.to_hf_ckpt: + cur_refactor.transform_from_megads_to_hf() + else: + cur_refactor.transform_from_hf_to_megds() + # cur_refactor.inorder_show_record() + + if args.to_hf_ckpt: + save_path = args.save + if not os.path.exists(save_path): + Path(save_path).mkdir(parents=True, exist_ok=True) + ckpt_per_pp_path = os.path.join(save_path, f"model_pp{mpu.get_pipeline_model_parallel_rank()}.pt") + torch.save(cur_refactor.hf_dict, ckpt_per_pp_path) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(f"hf checkpoint will be saved in {save_path}/release ") + if mpu.is_pipeline_last_stage(): + ## doing checkpoint merging and saving... + # hf_model.tie_weights() + + all_wei = {} + for pprank in range(mpu.get_pipeline_model_parallel_world_size()): + ckpt_per_pp_path = os.path.join(save_path, f"model_pp{pprank}.pt") + partial_wei = torch.load(ckpt_per_pp_path) + all_wei = all_wei | partial_wei + + hf_model.load_state_dict(all_wei) + + # mega-ds checkpoint will be saved in args.save + hf_model.save_pretrained(os.path.join(save_path, "release"), safe_serialization=True) + else: + print_rank_0(f"mega-ds checkpoint will be saved in {args.save}") + save_checkpoint(0, [ds_engine], None, None) + + print_rank_0(f"save checkpoint completed") if __name__ == "__main__": initialize_megatron(extra_args_provider=add_extra_args) - convert_hf_to_mega_ds() + convert_ckpt()