Merge branch 'main' into auto_cudagraph
akoumpa authored Sep 11, 2024
2 parents f8d9bd8 + 2089c53 commit 6ad766d
Showing 7 changed files with 18 additions and 21 deletions.

Dockerfile.ci: 1 change (0 additions, 1 deletion)
@@ -38,7 +38,6 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_T
 # Install NeMo requirements
 ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.15.0
-
 ARG MCORE_TAG=8307fcda5fff57ab0e77131b09bf37da997ec1f2

 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c

@@ -534,6 +534,8 @@ def build_transformer_config(self) -> TransformerConfig:
         recompute_method = self.cfg.get('activations_checkpoint_method', None)
         recompute_num_layers = self.cfg.get('activations_checkpoint_num_layers', None)

+        tp_only_amax_red = self.cfg.get('tp_only_amax_red', False)
+
         # any configs that are not in the nemo model config will be added here
         config_mapping = {
             'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
@@ -557,6 +559,7 @@ def build_transformer_config(self) -> TransformerConfig:
             'fp8': None,
             'rotary_interleaved': rotary_interleaved,
             'deallocate_pipeline_outputs': True,
+            'tp_only_amax_red': tp_only_amax_red,
         }

         # populate the transformer config dict
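
The hunk above reads a new tp_only_amax_red flag from the model config and forwards it into the mapping used to build the Megatron-Core TransformerConfig. A minimal sketch of that read-with-default pattern, using a plain dict in place of the NeMo cfg object (the values below are illustrative, not NeMo defaults):

# Illustrative sketch only: pull an optional flag out of a config with a default,
# then forward it into the mapping handed to TransformerConfig.
cfg = {'tp_only_amax_red': True}  # hypothetical model config section

tp_only_amax_red = cfg.get('tp_only_amax_red', False)  # False when the key is absent

config_mapping = {
    'fp8': None,
    'deallocate_pipeline_outputs': True,
    'tp_only_amax_red': tp_only_amax_red,
}
print(config_mapping['tp_only_amax_red'])  # -> True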

@@ -558,6 +558,9 @@ def setup_mcore_distributed_parallel(self):
                 # using bucket_cap_mb to configure bucket_size here
                 bucket_size=self.cfg.optim.get('ddp_bucket_size', None),
                 average_in_collective=self.cfg.optim.get('average_in_collective', True),
+                overlap_param_gather=self.cfg.optim.get('overlap_param_gather', False),
+                align_param_gather=self.cfg.optim.get('align_param_gather', False),
+                fp8_param_gather=self.cfg.get('fp8_params', False),
             )
             self.model = [
                 McoreDDP(
@@ -566,7 +569,8 @@ def setup_mcore_distributed_parallel(self):
                     model_chunk,
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
-                    disable_bucketing=(model_chunk_idx > 0),
+                    disable_bucketing=(model_chunk_idx > 0)
+                    or self.cfg.optim.get('overlap_param_gather_with_optimizer_step', False),
                 )
                 for (model_chunk_idx, model_chunk) in enumerate(self.model)
             ]
@@ -685,14 +689,11 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
             no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
             no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func

-            if self.cfg.optim.get("delay_grad_reduce", True):
+            if self.cfg.optim.get("align_grad_reduce", True):
                 grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
                 grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
-            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
-                param_sync_func = [
-                    lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x)
-                    for model_index in range(len(self.model))
-                ]
+            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("align_param_gather", False):
+                param_sync_func = [model_chunk.start_param_sync for model_chunk in self.model]
                 param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func

         # pipeline schedules will get these from self.model.config
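
Both renamed branches above follow the same pattern: build one sync callable per model chunk, then unwrap the list when there is only one chunk so the pipeline schedule receives a bare callable. A self-contained sketch with a stand-in chunk class (only the method name mirrors the diff; this is not the McoreDDP API):

class FakeChunk:
    # stand-in for an McoreDDP-wrapped model chunk
    def start_grad_sync(self):
        print('async grad reduce launched')


model = [FakeChunk()]  # a single (virtual) pipeline chunk

grad_sync_func = [chunk.start_grad_sync for chunk in model]
# with one chunk, pass the callable itself rather than a one-element list
grad_sync_func = grad_sync_func[0] if len(model) == 1 else grad_sync_func

grad_sync_func()  # prints: async grad reduce launched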

nemo/core/classes/modelPT.py: 5 changes (3 additions, 2 deletions)
@@ -615,8 +615,9 @@ def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictCo
             adam_beta2=optim_config['betas'][1],
             clip_grad=self.trainer.gradient_clip_val,
             use_distributed_optimizer=self.use_mcore_dist_optim,
-            overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
-            overlap_param_gather=self.cfg.optim.get('overlap_param_sync', False),
+            overlap_param_gather_with_optimizer_step=self.cfg.optim.get(
+                'overlap_param_gather_with_optimizer_step', False
+            ),
         )
         return megatron_optim_config

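
The optimizer setup above now consults overlap_param_gather_with_optimizer_step from the optim section of the config. A hedged example of such a section built with OmegaConf; only the key names come from the diffs, and the values are placeholders rather than NeMo defaults:

from omegaconf import OmegaConf

# hypothetical optim section; key names taken from the .get(...) calls in the diffs
optim_cfg = OmegaConf.create(
    {
        'overlap_grad_sync': True,
        'overlap_param_sync': True,
        'align_grad_reduce': True,
        'align_param_gather': False,
        'overlap_param_gather_with_optimizer_step': False,
    }
)

# same read-with-default pattern as setup_megatron_optimization and fwd_bwd_step
print(optim_cfg.get('overlap_param_gather_with_optimizer_step', False))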

nemo/core/optim/mcore_optim.py: 3 changes (0 additions, 3 deletions)
@@ -107,9 +107,6 @@ def _set_param_groups(self, value):

     param_groups = property(_get_param_groups, _set_param_groups)

-    def finish_param_sync(self, model_index):
-        self.mcore_optimizer.finish_param_sync(model_index)
-
     def disable_pre_hook(self):
         self.mcore_optimizer.disable_pre_hook()


nemo/lightning/megatron_parallel.py: 3 changes (1 addition, 2 deletions)
@@ -99,8 +99,7 @@ def extract_ddp_funcs(ddp_config, pipeline):
     if getattr(ddp_config, "overlap_grad_reduce", False):
         no_sync_func = [model_chunk.no_sync for model_chunk in pipeline]
         no_sync_func = no_sync_func[0] if len(pipeline) == 1 else no_sync_func
-    # TODO(@akoumparouli): why is True default here?
-    if getattr(ddp_config, "delay_grad_reduce", True):
+    if getattr(ddp_config, "align_grad_reduce", False):
         grad_sync_func = [model_chunk.start_grad_sync for model_chunk in pipeline]
         grad_sync_func = grad_sync_func[0] if len(pipeline) == 1 else grad_sync_func


nemo/lightning/pytorch/optim/megatron.py: 9 changes (3 additions, 6 deletions)
@@ -125,13 +125,10 @@ def sharded_state_dict(
         )

         if getattr(model.ddp_config, "overlap_param_sync", False) and getattr(
-            model.ddp_config, "delay_param_gather", False
+            model.ddp_config, "align_param_gather", False
         ):
-            param_sync_func = [
-                lambda x, model_index=model_index: mcore_opt.finish_param_sync(model_index, x)
-                for model_index in range(len(pipeline))
-            ]
-            param_sync_func = param_sync_func[0] if len(pipeline) == 1 else param_sync_func
+            param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
+            param_sync_func = param_sync_func[0] if len(model) == 1 else param_sync_func
             for module in model:
                 module.config.param_sync_func = param_sync_func

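
After the rename to align_param_gather, the param-sync hook comes straight from each model chunk's start_param_sync and is written into every module's config. A minimal stand-alone sketch of that wiring with dummy classes (not the NeMo or Megatron-Core types):

class DummyConfig:
    param_sync_func = None


class DummyChunk:
    def __init__(self):
        self.config = DummyConfig()

    def start_param_sync(self, *args):
        print('async param all-gather launched')


model = [DummyChunk(), DummyChunk()]  # two virtual-pipeline chunks

param_sync_func = [chunk.start_param_sync for chunk in model]
param_sync_func = param_sync_func[0] if len(model) == 1 else param_sync_func
for module in model:
    module.config.param_sync_func = param_sync_func  # each module sees the per-chunk hooks

model[0].config.param_sync_func[0]()  # prints: async param all-gather launched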
