From 72f7fd2a53cf631b9a63faf99f835feaf5360296 Mon Sep 17 00:00:00 2001 From: jiemingz Date: Mon, 8 Apr 2024 18:47:03 -0700 Subject: [PATCH] add check if pos embed Signed-off-by: jiemingz --- .../nlp/models/language_modeling/megatron_base_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index baa6e30af81d..854c5ee02e31 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -775,7 +775,8 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any: if parallel_state.is_pipeline_first_stage(ignore_virtual=True): if self.mcore_gpt: fp32_params.append(modules[0].shared_embedding_or_output_weight()) - fp32_params.append(modules[0].embedding.position_embeddings.weight) + if modules[0].embedding.add_position_embedding: + fp32_params.append(modules[0].embedding.position_embeddings.weight) else: fp32_params.append(modules[0].word_embeddings_weight()) fp32_params.append(modules[0].position_embeddings_weight())