diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index 8b027168a21d..ed5c5db0ab5f 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -130,8 +130,8 @@ def _wrap_for_auto(self, model, train_dataloader):
             self.optimizer = dist.shard_optimizer(
                 self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps
             )
-        else:
-            self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)
+        # else:
+        #     self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)
 
         if self.args.to_static:
             unified_strategy = dist.Strategy()
diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py
index dc1481f1f471..5ad4a4e24022 100644
--- a/paddlenlp/trainer/utils/ckpt_converter.py
+++ b/paddlenlp/trainer/utils/ckpt_converter.py
@@ -492,6 +492,35 @@ def gen_metadata_and_prepare_source_state_dict(self):
         else:
             return self.gen_metadata_for_tp_sharded_tensor()
 
+    def rename_state_dict_expert(self, state_dict, file_name):
+        (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
+        # print(file_name)
+        # print(f"tp_rank: {tp_rank} sharding_rank: {sharding_rank}")
+        pattern = r"(experts\.)(\d+)"
+        expert = set()
+        for state_name in state_dict.keys():
+            res = re.search(pattern, state_name)
+            if res:
+                expert.add(int(res.group(2)))
+        expert_num = len(expert)
+        expert_name_old2new = {}
+        for state_name in state_dict.keys():
+            res = re.search(pattern, state_name)
+            if res:
+                new_expert_id = int(res.group(2)) + tp_rank * expert_num
+                new_expert_str = f"{res.group(1)}{new_expert_id}"
+                new_param_name = re.sub(pattern, new_expert_str, state_name)
+                expert_name_old2new[state_name] = new_param_name
+
+        renamed_state_dict = {}
+        for state_name in state_dict.keys():
+            if state_name in expert_name_old2new:
+                renamed_state_dict[expert_name_old2new[state_name]] = state_dict[state_name]
+            else:
+                renamed_state_dict[state_name] = state_dict[state_name]
+
+        return renamed_state_dict
+
     def load_state_dict_and_rename(self):
         """
         Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state
@@ -736,10 +765,13 @@ def load_state_dict_and_rename(self):
                 assert model_state_file_name is not None
                 model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name]
                 renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
-                self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos)
+                renamed_state_dict = self.rename_state_dict_expert(renamed_state_dict, file_name)
+                self.get_sharded_tensor_infos(file_name, renamed_state_dict, cur_rank_sharded_tensor_infos)
                 self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
             else:
+                state_dict = self.rename_state_dict_expert(state_dict, file_name)
                 self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos)
+                self.cur_rank_loaded_state_dict[file_name] = state_dict
         else:
             for file, state_dict in self.cur_rank_loaded_state_dict.items():
                 # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name,
@@ -873,24 +905,40 @@ def rename_using_parameter_to_structured_name_mapping(self, state_dict, parameter_to_structured_name):
         renamed_state_dict = {}
 
         def rename(old_name, parameter_to_structured_name):
-            for i in range(1, len(old_name) + 1):
-                param_name = old_name[:i]  # param_name
-                suffix = old_name[i:]  # suffix
-                if param_name in parameter_to_structured_name:
-                    structure_name = parameter_to_structured_name[param_name]
-                    if "moment1" in suffix:
-                        return structure_name + ".moment1"
-                    elif "moment2" in suffix:
-                        return structure_name + ".moment2"
-                    elif "beta1_pow_acc" in suffix:
-                        return structure_name + ".beta1_pow_acc"
-                    elif "beta2_pow_acc" in suffix:
-                        return structure_name + ".beta2_pow_acc"
+            # for i in range(1, len(old_name) + 1):
+            #     param_name = old_name[:i]  # param_name
+            #     suffix = old_name[i:]  # suffix
+            #     if param_name in parameter_to_structured_name:
+            #         structure_name = parameter_to_structured_name[param_name]
+            #         if "moment1" in suffix:
+            #             return structure_name + ".moment1"
+            #         elif "moment2" in suffix:
+            #             return structure_name + ".moment2"
+            #         elif "beta1_pow_acc" in suffix:
+            #             return structure_name + ".beta1_pow_acc"
+            #         elif "beta2_pow_acc" in suffix:
+            #             return structure_name + ".beta2_pow_acc"
+            #         else:
+            #             return structure_name + ".master_weight"
+            # return None
+            for k, v in parameter_to_structured_name.items():
+                if k in old_name:
+                    if "moment1" in old_name:
+                        return v + ".moment1"
+                    elif "moment2" in old_name:
+                        return v + ".moment2"
+                    elif "beta1_pow_acc" in old_name:
+                        return v + ".beta1_pow_acc"
+                    elif "beta2_pow_acc" in old_name:
+                        return v + ".beta2_pow_acc"
                     else:
-                        return structure_name + ".master_weight"
+                        return v + ".master_weight"
             return None
 
         for key, value in state_dict.items():
+            if not value._is_initialized():
+                # print(f"param is skipped and not add into renamed_state_dict: {key}")
+                continue
             if key in parameter_to_structured_name.values():
                 new_name = key
             else:
@@ -903,7 +951,7 @@ def rename(old_name, parameter_to_structured_name):
     def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         name_mapping = {}
         suffix_bucket = {}
-        assert len(optimizer_state_dict) % len(model_state_keys) == 0
+        # assert len(optimizer_state_dict) % len(model_state_keys) == 0
         for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
             suffix_bucket[suffix] = []
         for opt_name, opt_value in optimizer_state_dict.items():
@@ -921,9 +969,23 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         for suffix, old_names in suffix_bucket.items():
             if len(old_names) == 0:
                 continue
-            assert len(old_names) == len(model_state_keys)
-            for i in range(len(old_names)):
-                name_mapping[old_names[i]] = model_state_keys[i] + suffix
+            # assert len(old_names) == len(model_state_keys)
+            if suffix != ".master_weight":
+                for i in range(len(old_names)):
+                    name_mapping[old_names[i]] = model_state_keys[i] + suffix
+            else:
+                for i in range(len(old_names)):
+                    param = old_names[i][:-14]
+                    index = -1
+                    for idx, opt_name in enumerate(suffix_bucket[".moment1"]):
+                        if param == opt_name[:-24]:
+                            index = idx
+                            break
+                    if index >= 0:
+                        name_mapping[old_names[i]] = model_state_keys[index] + suffix
+                    else:
+                        print(suffix_bucket[".moment1"])
+                        raise RuntimeError(f"Can't find {param} for suffix_bucket in optimizer state dict.")
 
         renamed_state_dict = {}
         for k, v in optimizer_state_dict.items():
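
Review notes (not part of the patch). The new `rename_state_dict_expert` remaps each tensor-parallel shard's local expert indices to global ones: a shard holding `expert_num` experts named `experts.0` through `experts.{expert_num-1}` has them renamed to `experts.{local_id + tp_rank * expert_num}`. Below is a minimal standalone sketch of that remapping; `remap_expert_ids` and the parameter names are hypothetical, and the real method derives `tp_rank` from the checkpoint file name via `get_distribution_rank_from_file_name`:

```python
import re

def remap_expert_ids(state_dict, tp_rank):
    """Hypothetical helper mirroring rename_state_dict_expert's renaming logic."""
    pattern = r"(experts\.)(\d+)"
    # Number of distinct local expert ids held by this tensor-parallel shard.
    local_ids = {int(m.group(2)) for k in state_dict if (m := re.search(pattern, k))}
    expert_num = len(local_ids)
    renamed = {}
    for name, value in state_dict.items():
        m = re.search(pattern, name)
        if m:
            # Global expert id = local id + tp_rank * experts-per-rank.
            global_id = int(m.group(2)) + tp_rank * expert_num
            name = re.sub(pattern, f"{m.group(1)}{global_id}", name)
        renamed[name] = value
    return renamed

# With 2 experts per rank, rank 1's local experts 0 and 1 become global 2 and 3.
shard = {"layers.0.moe.experts.0.w1": "t0", "layers.0.moe.experts.1.w1": "t1"}
assert list(remap_expert_ids(shard, tp_rank=1)) == [
    "layers.0.moe.experts.2.w1",
    "layers.0.moe.experts.3.w1",
]
```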
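The rewritten inner `rename` in `rename_using_parameter_to_structured_name_mapping` drops the prefix-splitting scan in favor of a substring match: any optimizer-state name that contains a known parameter name is mapped to that parameter's structured name plus whichever optimizer suffix appears in the full name (and the outer loop now skips tensors that are not initialized on the current rank). A toy reimplementation with an invented mapping, assuming Paddle-style optimizer-state names:

```python
def rename(old_name, parameter_to_structured_name):
    for param_name, structured in parameter_to_structured_name.items():
        if param_name in old_name:  # substring match instead of prefix split
            for suffix in ("moment1", "moment2", "beta1_pow_acc", "beta2_pow_acc"):
                if suffix in old_name:
                    return structured + "." + suffix
            return structured + ".master_weight"  # no optimizer suffix found
    return None  # no known parameter name occurs in old_name

# Invented mapping and state names for illustration.
mapping = {"linear_0.w_0": "llama.layers.0.mlp.gate.weight"}
assert rename("linear_0.w_0_fp32_master_0_moment1_0", mapping) == "llama.layers.0.mlp.gate.weight.moment1"
assert rename("linear_0.w_0_fp32_master_0", mapping) == "llama.layers.0.mlp.gate.weight.master_weight"
assert rename("create_parameter_0.w_0_moment1_0", mapping) is None
```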
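In `rename_using_optimizer_state_order`, the `.master_weight` bucket can no longer be matched to `model_state_keys` by position alone (hence the relaxed asserts), so each master-weight name is matched back to its `.moment1` counterpart by stripping fixed-length name suffixes. The magic slice lengths 14 and 24 appear to correspond to assumed Paddle-style suffixes `_fp32_master_0` and `_fp32_master_0_moment1_0`. A small sketch with invented names:

```python
# Assumed name formats (lengths verified below):
#   master weight: "<param>_fp32_master_0"            len("_fp32_master_0") == 14
#   moment1:       "<param>_fp32_master_0_moment1_0"  len("_fp32_master_0_moment1_0") == 24
moment1_names = [
    "linear_1.w_0_fp32_master_0_moment1_0",
    "linear_0.w_0_fp32_master_0_moment1_0",
]
master_names = ["linear_0.w_0_fp32_master_0", "linear_1.w_0_fp32_master_0"]
model_state_keys = ["llama.layers.1.weight", "llama.layers.0.weight"]  # aligned with moment1_names

name_mapping = {}
for old in master_names:
    param = old[:-14]  # strip "_fp32_master_0" to recover the raw parameter name
    # Locate the same parameter in the moment1 bucket; model_state_keys is
    # assumed to follow the moment1 order, so that index selects the right key.
    index = next(i for i, m1 in enumerate(moment1_names) if m1[:-24] == param)
    name_mapping[old] = model_state_keys[index] + ".master_weight"

assert name_mapping["linear_1.w_0_fp32_master_0"] == "llama.layers.1.weight.master_weight"
```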