diff --git a/Dockerfile.ci b/Dockerfile.ci
index 414cd5473672..05cbb851acbf 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -38,7 +38,6 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_T
 # Install NeMo requirements
 ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.15.0
-
 ARG MCORE_TAG=8307fcda5fff57ab0e77131b09bf37da997ec1f2
 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 788e9bd059f6..d2a21e50e486 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -534,6 +534,8 @@ def build_transformer_config(self) -> TransformerConfig:
         recompute_method = self.cfg.get('activations_checkpoint_method', None)
         recompute_num_layers = self.cfg.get('activations_checkpoint_num_layers', None)
 
+        tp_only_amax_red = self.cfg.get('tp_only_amax_red', False)
+
         # any configs that are not in the nemo model config will be added here
         config_mapping = {
             'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
@@ -557,6 +559,7 @@ def build_transformer_config(self) -> TransformerConfig:
             'fp8': None,
             'rotary_interleaved': rotary_interleaved,
             'deallocate_pipeline_outputs': True,
+            'tp_only_amax_red': tp_only_amax_red,
         }
 
         # populate the transformer config dict
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 5872f9d21014..e3d426ec9275 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -558,6 +558,9 @@ def setup_mcore_distributed_parallel(self):
                 # using bucket_cap_mb to configure bucket_size here
                 bucket_size=self.cfg.optim.get('ddp_bucket_size', None),
                 average_in_collective=self.cfg.optim.get('average_in_collective', True),
+                overlap_param_gather=self.cfg.optim.get('overlap_param_gather', False),
+                align_param_gather=self.cfg.optim.get('align_param_gather', False),
+                fp8_param_gather=self.cfg.get('fp8_params', False),
             )
             self.model = [
                 McoreDDP(
@@ -566,7 +569,8 @@ def setup_mcore_distributed_parallel(self):
                     model_chunk,
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
-                    disable_bucketing=(model_chunk_idx > 0),
+                    disable_bucketing=(model_chunk_idx > 0)
+                    or self.cfg.optim.get('overlap_param_gather_with_optimizer_step', False),
                 )
                 for (model_chunk_idx, model_chunk) in enumerate(self.model)
             ]
@@ -685,14 +689,11 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
                 no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
                 no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func
 
-                if self.cfg.optim.get("delay_grad_reduce", True):
+                if self.cfg.optim.get("align_grad_reduce", True):
                     grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
                     grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
-            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
-                param_sync_func = [
-                    lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x)
-                    for model_index in range(len(self.model))
-                ]
+            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("align_param_gather", False):
+                param_sync_func = [model_chunk.start_param_sync for model_chunk in self.model]
                 param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func
 
         # pipeline schedules will get these from self.model.config
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 2bfd4e5cd695..6393bb5581d6 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -615,8 +615,9 @@ def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictCo
             adam_beta2=optim_config['betas'][1],
             clip_grad=self.trainer.gradient_clip_val,
             use_distributed_optimizer=self.use_mcore_dist_optim,
-            overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
-            overlap_param_gather=self.cfg.optim.get('overlap_param_sync', False),
+            overlap_param_gather_with_optimizer_step=self.cfg.optim.get(
+                'overlap_param_gather_with_optimizer_step', False
+            ),
         )
         return megatron_optim_config
 
diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py
index 09b6a290d558..c0f3464a1e7c 100644
--- a/nemo/core/optim/mcore_optim.py
+++ b/nemo/core/optim/mcore_optim.py
@@ -107,9 +107,6 @@ def _set_param_groups(self, value):
 
     param_groups = property(_get_param_groups, _set_param_groups)
 
-    def finish_param_sync(self, model_index):
-        self.mcore_optimizer.finish_param_sync(model_index)
-
     def disable_pre_hook(self):
         self.mcore_optimizer.disable_pre_hook()
 
diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index f8476a440b0c..60f090d6318f 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -99,8 +99,7 @@ def extract_ddp_funcs(ddp_config, pipeline):
     if getattr(ddp_config, "overlap_grad_reduce", False):
         no_sync_func = [model_chunk.no_sync for model_chunk in pipeline]
         no_sync_func = no_sync_func[0] if len(pipeline) == 1 else no_sync_func
-        # TODO(@akoumparouli): why is True default here?
-        if getattr(ddp_config, "delay_grad_reduce", True):
+        if getattr(ddp_config, "align_grad_reduce", False):
             grad_sync_func = [model_chunk.start_grad_sync for model_chunk in pipeline]
             grad_sync_func = grad_sync_func[0] if len(pipeline) == 1 else grad_sync_func
 
diff --git a/nemo/lightning/pytorch/optim/megatron.py b/nemo/lightning/pytorch/optim/megatron.py
index 5252f7621859..9450c24002c0 100644
--- a/nemo/lightning/pytorch/optim/megatron.py
+++ b/nemo/lightning/pytorch/optim/megatron.py
@@ -125,13 +125,10 @@ def sharded_state_dict(
             )
 
             if getattr(model.ddp_config, "overlap_param_sync", False) and getattr(
-                model.ddp_config, "delay_param_gather", False
+                model.ddp_config, "align_param_gather", False
             ):
-                param_sync_func = [
-                    lambda x, model_index=model_index: mcore_opt.finish_param_sync(model_index, x)
-                    for model_index in range(len(pipeline))
-                ]
-                param_sync_func = param_sync_func[0] if len(pipeline) == 1 else param_sync_func
+                param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
+                param_sync_func = param_sync_func[0] if len(model) == 1 else param_sync_func
                 for module in model:
                     module.config.param_sync_func = param_sync_func
 
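Note on configuration: every new or renamed flag in this diff is read from the NeMo model config via `cfg.get(...)` or `cfg.optim.get(...)`. The sketch below is illustrative only and not part of the change; it collects those keys (names taken verbatim from the calls above) into an OmegaConf config with assumed example values, to show where a user would set them.

```python
# Illustrative sketch only (not part of the diff): where the keys read by the
# patched code would live in a NeMo-style model config. Key names come from
# the cfg.get() / cfg.optim.get() calls in this change; all values shown are
# assumptions/examples, not defaults mandated by the PR.
from omegaconf import OmegaConf

model_cfg = OmegaConf.create(
    {
        "tp_only_amax_red": False,  # consumed by build_transformer_config()
        "fp8_params": False,        # forwarded to McoreDDP as fp8_param_gather
        "optim": {
            "overlap_grad_sync": True,      # enables grad-reduce overlap
            "align_grad_reduce": True,      # replaces the old delay_grad_reduce
            "overlap_param_sync": True,     # enables param all-gather overlap
            "overlap_param_gather": True,   # passed to DistributedDataParallelConfig
            "align_param_gather": True,     # replaces the old delay_param_gather
            "overlap_param_gather_with_optimizer_step": False,
            "average_in_collective": True,
            "ddp_bucket_size": None,
        },
    }
)

# Lookups in the patched code follow this exact pattern:
tp_only_amax_red = model_cfg.get("tp_only_amax_red", False)
align_param_gather = model_cfg.optim.get("align_param_gather", False)
```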