moe auto_parallel load ckpt from dyhand hack fix #9457

Open
wants to merge 1 commit into base: develop
4 changes: 2 additions & 2 deletions paddlenlp/trainer/auto_trainer.py
@@ -130,8 +130,8 @@ def _wrap_for_auto(self, model, train_dataloader):
self.optimizer = dist.shard_optimizer(
self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps
)
else:
self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)
# else:
# self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)

if self.args.to_static:
unified_strategy = dist.Strategy()
100 changes: 81 additions & 19 deletions paddlenlp/trainer/utils/ckpt_converter.py
@@ -492,6 +492,35 @@
else:
return self.gen_metadata_for_tp_sharded_tensor()

def rename_state_dict_expert(self, state_dict, file_name):
(tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)

# print(file_name)
# print(f"tp_rank: {tp_rank} sharding_rank: {sharding_rank}")
pattern = r"(experts\.)(\d+)"
expert = set()
for state_name in state_dict.keys():
res = re.search(pattern, state_name)
if res:
expert.add(int(res.group(2)))
expert_num = len(expert)
expert_name_old2new = {}
for state_name in state_dict.keys():
res = re.search(pattern, state_name)
if res:
new_expert_id = int(res.group(2)) + tp_rank * expert_num
new_expert_str = f"{res.group(1)}{new_expert_id}"
new_param_name = re.sub(pattern, new_expert_str, state_name)
expert_name_old2new[state_name] = new_param_name


renamed_state_dict = {}
for state_name in state_dict.keys():
if state_name in expert_name_old2new:
renamed_state_dict[expert_name_old2new[state_name]] = state_dict[state_name]

else:
renamed_state_dict[state_name] = state_dict[state_name]


return renamed_state_dict


def load_state_dict_and_rename(self):
"""
Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state
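
To make the new rename_state_dict_expert helper easier to follow, here is a minimal, self-contained sketch of the same regex-based renumbering. The parameter names, placeholder values, and tp_rank below are made up for illustration; only the re logic mirrors the diff above. With two local experts and tp_rank = 1, experts.0 and experts.1 are relabelled as the global experts.2 and experts.3:

    import re

    # Hypothetical per-rank model state dict; values are placeholders for tensors.
    state_dict = {
        "layers.0.mlp.experts.0.w1": "tensor-a",
        "layers.0.mlp.experts.1.w1": "tensor-b",
    }
    tp_rank = 1  # pretend this file came from tensor-parallel rank 1
    pattern = r"(experts\.)(\d+)"

    # 1) Count the distinct local expert ids, as the helper does.
    expert_ids = set()
    for name in state_dict:
        res = re.search(pattern, name)
        if res:
            expert_ids.add(int(res.group(2)))
    expert_num = len(expert_ids)

    # 2) Shift every local expert id by tp_rank * expert_num to get its global id.
    renamed = {}
    for name, value in state_dict.items():
        res = re.search(pattern, name)
        if res:
            name = re.sub(pattern, f"{res.group(1)}{int(res.group(2)) + tp_rank * expert_num}", name)
        renamed[name] = value

    print(renamed)
    # {'layers.0.mlp.experts.2.w1': 'tensor-a', 'layers.0.mlp.experts.3.w1': 'tensor-b'}

The renamed state dict is then what gets handed to get_sharded_tensor_infos in the hunk below.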
@@ -736,10 +765,13 @@
assert model_state_file_name is not None
model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name]
renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos)
renamed_state_dict = self.rename_state_dict_expert(renamed_state_dict, file_name)
self.get_sharded_tensor_infos(file_name, renamed_state_dict, cur_rank_sharded_tensor_infos)

self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
else:
state_dict = self.rename_state_dict_expert(state_dict, file_name)

self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos)
self.cur_rank_loaded_state_dict[file_name] = state_dict

else:
for file, state_dict in self.cur_rank_loaded_state_dict.items():
# The rule for renaming is to change the master_weights name in the optimizer state to the model weight name,
@@ -873,24 +905,40 @@
renamed_state_dict = {}

def rename(old_name, parameter_to_structured_name):
for i in range(1, len(old_name) + 1):
param_name = old_name[:i] # param_name
suffix = old_name[i:] # suffix
if param_name in parameter_to_structured_name:
structure_name = parameter_to_structured_name[param_name]
if "moment1" in suffix:
return structure_name + ".moment1"
elif "moment2" in suffix:
return structure_name + ".moment2"
elif "beta1_pow_acc" in suffix:
return structure_name + ".beta1_pow_acc"
elif "beta2_pow_acc" in suffix:
return structure_name + ".beta2_pow_acc"
# for i in range(1, len(old_name) + 1):
# param_name = old_name[:i] # param_name
# suffix = old_name[i:] # suffix
# if param_name in parameter_to_structured_name:
# structure_name = parameter_to_structured_name[param_name]
# if "moment1" in suffix:
# return structure_name + ".moment1"
# elif "moment2" in suffix:
# return structure_name + ".moment2"
# elif "beta1_pow_acc" in suffix:
# return structure_name + ".beta1_pow_acc"
# elif "beta2_pow_acc" in suffix:
# return structure_name + ".beta2_pow_acc"
# else:
# return structure_name + ".master_weight"
# return None
for k, v in parameter_to_structured_name.items():
if k in old_name:
if "moment1" in old_name:
return v + ".moment1"
elif "moment2" in old_name:
return v + ".moment2"
elif "beta1_pow_acc" in old_name:
return v + ".beta1_pow_acc"
elif "beta2_pow_acc" in old_name:
return v + ".beta2_pow_acc"

else:
return structure_name + ".master_weight"
return v + ".master_weight"

return None

for key, value in state_dict.items():
if not value._is_initialized():

# print(f"param is skipped and not add into renamed_state_dict: {key}")
continue

if key in parameter_to_structured_name.values():
new_name = key
else:
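
The rename helper above now matches a parameter key as a substring of the optimizer-state name instead of slicing candidate prefixes, then appends whichever optimizer-state suffix the old name contains. A condensed, runnable illustration of that matching rule, using hypothetical names (only the control flow follows the diff):

    # Hypothetical mapping from low-level parameter names to structured model names.
    parameter_to_structured_name = {"linear_0.w_0": "llama.layers.0.mlp.weight"}

    def rename(old_name, parameter_to_structured_name):
        # Substring match on the parameter key, then pick the optimizer-state suffix.
        for k, v in parameter_to_structured_name.items():
            if k in old_name:
                for suffix in ("moment1", "moment2", "beta1_pow_acc", "beta2_pow_acc"):
                    if suffix in old_name:
                        return f"{v}.{suffix}"
                return v + ".master_weight"
        return None

    print(rename("linear_0.w_0_fp32_master_0_moment1_0", parameter_to_structured_name))
    # llama.layers.0.mlp.weight.moment1
    print(rename("linear_0.w_0_fp32_master_0", parameter_to_structured_name))
    # llama.layers.0.mlp.weight.master_weight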
@@ -903,7 +951,7 @@
def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
name_mapping = {}
suffix_bucket = {}
assert len(optimizer_state_dict) % len(model_state_keys) == 0
# assert len(optimizer_state_dict) % len(model_state_keys) == 0
for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
suffix_bucket[suffix] = []
for opt_name, opt_value in optimizer_state_dict.items():
@@ -921,9 +969,23 @@
for suffix, old_names in suffix_bucket.items():
if len(old_names) == 0:
continue
assert len(old_names) == len(model_state_keys)
for i in range(len(old_names)):
name_mapping[old_names[i]] = model_state_keys[i] + suffix
# assert len(old_names) == len(model_state_keys)
if suffix != ".master_weight":
for i in range(len(old_names)):
name_mapping[old_names[i]] = model_state_keys[i] + suffix

else:
for i in range(len(old_names)):
param = old_names[i][:-14]
index = -1
for idx, opt_name in enumerate(suffix_bucket[".moment1"]):
if param == opt_name[:-24]:
index = idx
break
if index >= 0:
name_mapping[old_names[i]] = model_state_keys[index] + suffix

else:
print(suffix_bucket[".moment1"])
raise RuntimeError(f"Can't find {param} for suffix_bucket in optimizer state dict.")


renamed_state_dict = {}
for k, v in optimizer_state_dict.items():
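
In the master_weight branch added in the last hunk, the index of a master weight is recovered by comparing its name, minus the trailing 14 characters, against each .moment1 name minus its trailing 24 characters. Assuming the usual Paddle optimizer-state naming (the concrete names below are hypothetical), the slicing lines up like this:

    # Hypothetical optimizer-state names for one parameter.
    master_weight_name = "linear_0.w_0_fp32_master_0"           # ends with the 14-char "_fp32_master_0"
    moment1_name = "linear_0.w_0_fp32_master_0_moment1_0"       # ends with a 24-char suffix

    # Stripping both suffixes leaves the same base parameter name, so the
    # master_weight entry reuses the index of the matching moment1 entry.
    assert master_weight_name[:-14] == moment1_name[:-24] == "linear_0.w_0"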