Skip to content

Commit

Permalink
add check if pos embed
Browse files Browse the repository at this point in the history
Signed-off-by: jiemingz <jiemingz@nvidia.com>
  • Loading branch information
jiemingz committed Apr 9, 2024
1 parent 35e400f commit 72f7fd2
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,8 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if self.mcore_gpt:
fp32_params.append(modules[0].shared_embedding_or_output_weight())
fp32_params.append(modules[0].embedding.position_embeddings.weight)
if modules[0].embedding.add_position_embedding:
fp32_params.append(modules[0].embedding.position_embeddings.weight)
else:
fp32_params.append(modules[0].word_embeddings_weight())
fp32_params.append(modules[0].position_embeddings_weight())
Expand Down

0 comments on commit 72f7fd2

Please sign in to comment.