diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 6b158a33b226..c7ec23ee210b 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -181,7 +181,8 @@ class GPTConfig(TransformerConfig, io.IOMixin):
 
     def configure_model(self, tokenizer) -> "MCoreGPTModel":
         vp_size = self.virtual_pipeline_model_parallel_size
-        if vp_size:
+        use_asymmetric_pipeline = getattr(self, 'standalone_embedding_stage', False) or getattr(self, 'standalone_loss_stage', False)
+        if vp_size and not use_asymmetric_pipeline:
             p_size = self.pipeline_model_parallel_size
             assert (
                 self.num_layers // p_size
diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py
index 31c83713b6e7..5d8de2aa568e 100644
--- a/nemo/collections/llm/recipes/llama31_405b.py
+++ b/nemo/collections/llm/recipes/llama31_405b.py
@@ -64,12 +64,14 @@ def model() -> run.Config[pl.LightningModule]:
 
 def trainer(
     tensor_parallelism: int = 8,
-    pipeline_parallelism: int = 9,
+    pipeline_parallelism: int = 8,
     pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
     virtual_pipeline_parallelism: Optional[int] = 2,
-    context_parallelism: int = 4,
+    context_parallelism: int = 2,
     sequence_parallelism: bool = True,
-    num_nodes: int = 72,
+    standalone_embedding_stage: bool = True,
+    standalone_loss_stage: bool = True,
+    num_nodes: int = 128,
     num_gpus_per_node: int = 8,
     max_steps: int = 1168251,
     callbacks: Optional[list[run.Config[Callback]]] = None,
@@ -113,6 +115,8 @@ def trainer(
         virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
         context_parallel_size=context_parallelism,
         sequence_parallel=sequence_parallelism,
+        standalone_embedding_stage=standalone_embedding_stage,
+        standalone_loss_stage=standalone_loss_stage,
         gradient_as_bucket_view=True,
         ckpt_async_save=True,
         ckpt_parallel_load=True,
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index d2a21e50e486..b673edabf14d 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -181,9 +181,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
             if vp_size == 1:
                 vp_size = None
             else:
-                assert (
-                    self.cfg.num_layers // self.cfg.pipeline_model_parallel_size
-                ) % vp_size == 0, 'Make sure the number of model chunks is the same across all pipeline stages.'
+                if not (self.cfg.get('standalone_embedding_stage', False) and self.cfg.get('standalone_loss_stage', False)):
+                    assert (
+                        self.cfg.num_layers // self.cfg.pipeline_model_parallel_size
+                    ) % vp_size == 0, 'Make sure the number of model chunks is the same across all pipeline stages.'
 
         initialize_model_parallel_for_nemo(
             world_size=init_world_size,
@@ -536,6 +537,9 @@ def build_transformer_config(self) -> TransformerConfig:
 
         tp_only_amax_red = self.cfg.get('tp_only_amax_red', False)
 
+        standalone_embedding_stage = self.cfg.get('standalone_embedding_stage', False)
+        standalone_loss_stage = self.cfg.get('standalone_loss_stage', False)
+
         # any configs that are not in the nemo model config will be added here
         config_mapping = {
             'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
@@ -560,6 +564,8 @@ def build_transformer_config(self) -> TransformerConfig:
             'rotary_interleaved': rotary_interleaved,
             'deallocate_pipeline_outputs': True,
             'tp_only_amax_red': tp_only_amax_red,
+            'standalone_embedding_stage': standalone_embedding_stage,
+            'standalone_loss_stage': standalone_loss_stage,
         }
 
         # populate the transformer config dict
@@ -998,9 +1004,10 @@ def _validate_and_override_config(self):
             if vp_size == 1:
                 self.cfg['virtual_pipeline_model_parallel_size'] = None
             else:
-                assert (
-                    self.cfg.num_layers // self.cfg.pipeline_model_parallel_size
-                ) % vp_size == 0, 'Make sure the number of model chunks is the same across all pipeline stages.'
+                if not (self.cfg.get('standalone_embedding_stage', False) and self.cfg.get('standalone_loss_stage', False)):
+                    assert (
+                        self.cfg.num_layers // self.cfg.pipeline_model_parallel_size
+                    ) % vp_size == 0, 'Make sure the number of model chunks is the same across all pipeline stages.'
 
         if self.cfg.get('ub_tp_comm_overlap', False):
             if not self.cfg.get('sequence_parallel', False):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 8f541e5703e6..41feba672f6d 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -2118,11 +2118,12 @@ def build_transformer_config(self) -> TransformerConfig:
 
         For attributes in TransformerConfig that are not in the nemo model config, we add custom logic.
         """
-        if self.cfg.num_layers % self.cfg.get('pipeline_model_parallel_size', 1) != 0:
-            raise ValueError(
-                f"num_layers ({self.cfg.num_layers}) should be divisible by "
-                f"pipeline_model_parallel_size ({self.cfg.get('pipeline_model_parallel_size', 1)})"
-            )
+        if not (self.cfg.get('standalone_embedding_stage', False) and self.cfg.get('standalone_loss_stage', False)):
+            if self.cfg.num_layers % self.cfg.get('pipeline_model_parallel_size', 1) != 0:
+                raise ValueError(
+                    f"num_layers ({self.cfg.num_layers}) should be divisible by "
+                    f"pipeline_model_parallel_size ({self.cfg.get('pipeline_model_parallel_size', 1)})"
+                )
 
         normalization = self.cfg.get('normalization', 'layernorm').lower()
         layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' or self.cfg.get(
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index c62a90313b45..9582e649bf94 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -99,6 +99,8 @@ class ParallelismConfig:
     pipeline_dtype: torch.dtype
    encoder_tensor_model_parallel_size: int = 0
     encoder_pipeline_model_parallel_size: int = 0
+    standalone_embedding_stage: bool = False
+    standalone_loss_stage: bool = False
 
 
 class MegatronStrategy(DDPStrategy, io.IOMixin):
@@ -125,6 +127,8 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
             Defaults to 1.
         moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
+        standalone_embedding_stage (bool): If True, the input embedding layer is placed on its own pipeline stage. Defaults to False.
+        standalone_loss_stage (bool): If True, the output layer and loss computation are placed on their own pipeline stage. Defaults to False.
         data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
         parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
         cluster_environment: Cluster environment for distributed training. Defaults to None.
@@ -185,6 +189,8 @@ def __init__(
         moe_extended_tp: bool = False,
         encoder_tensor_model_parallel_size: Optional[int] = 0,
         encoder_pipeline_model_parallel_size: Optional[int] = 0,
+        standalone_embedding_stage: bool = False,
+        standalone_loss_stage: bool = False,
         data_sampler: Optional["DataSampler"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment=None,  # TODO: Add type-hint
@@ -235,6 +241,8 @@ def __init__(
         self.sequence_parallel = sequence_parallel
         self.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
         self.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
+        self.standalone_embedding_stage = standalone_embedding_stage
+        self.standalone_loss_stage = standalone_loss_stage
         self.lazy_init = lazy_init
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
@@ -862,6 +870,8 @@ def parallelism(self) -> ParallelismConfig:
             moe_extended_tp=self.moe_extended_tp,
             encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
             encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
+            standalone_embedding_stage=self.standalone_embedding_stage,
+            standalone_loss_stage=self.standalone_loss_stage,
             pipeline_dtype=self.pipeline_dtype,
         )