
Commit

Merge branch 'main' into fix_gated_linear_unit
athitten authored Jan 4, 2024
2 parents be00d21 + 86f1b7d commit a6d262f
Showing 3 changed files with 40 additions and 71 deletions.
File 1 of 3
@@ -231,6 +231,41 @@ def setup_transformer_engine_tp_groups(self):
tp_group = parallel_state.get_tensor_model_parallel_group()
child.set_tensor_parallel_group(tp_group)

+    def _wrap_model_for_O2(self):
+        """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
+            Args:
+                model: The model to wrap. Can be a list of modules or a single module.
+            Returns:
+                The wrapped model. Returns a list of wrapped modules or a single wrapped module.
+        """
+        is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False)
+
+        Float16Wrapper = MCoreFloat16Module if is_mcore_model else Float16Module
+
+        nemo_args = {'config': self.model_parallel_config, 'precision': self.cfg.precision}
+
+        if type(self).__name__ == 'MegatronGPTModel':
+            nemo_args['share_token_embeddings'] = self.cfg.get('share_embeddings_and_output_weights', True)
+
+        mcore_args = {
+            'config': self.transformer_config,
+        }
+
+        args = mcore_args if is_mcore_model else nemo_args
+
+        # Model wrapper to convert both model and inputs to half precision
+        if isinstance(self.model, list):
+            converted_model = []
+            for module in self.model:
+                args['module'] = module
+                converted_model.append(Float16Wrapper(**args))
+            self.model = converted_model
+        else:
+            args['module'] = self.model
+            self.model = Float16Wrapper(**args)
+
+        args.pop('module')

def get_model_module_list(self):
if isinstance(self.model, list):
return [
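Note: the hunk above adds a single _wrap_model_for_O2 to the shared base class, keyed on one is_mcore_model flag; the GPT- and BERT-specific copies are deleted in the other two files of this commit. A minimal sketch of the selection pattern, using hypothetical stand-in classes rather than NeMo's actual Float16Module / MCoreFloat16Module:

# Hypothetical stand-ins for NeMo's Float16Module and Megatron-Core's MCoreFloat16Module.
class DummyFloat16Module:
    def __init__(self, config, precision, module, **kwargs):
        self.module = module  # legacy (non-mcore) path also carries a precision setting

class DummyMCoreFloat16Module:
    def __init__(self, config, module):
        self.module = module  # mcore path only needs the transformer config

def wrap_for_o2(model, is_mcore_model, nemo_args, mcore_args):
    """Wrap a single module or a list of modules (e.g. virtual pipeline stages)."""
    wrapper_cls = DummyMCoreFloat16Module if is_mcore_model else DummyFloat16Module
    kwargs = dict(mcore_args if is_mcore_model else nemo_args)
    if isinstance(model, list):
        return [wrapper_cls(module=m, **kwargs) for m in model]
    return wrapper_cls(module=model, **kwargs)

Deriving is_mcore_model from either mcore_gpt or mcore_bert is what lets one implementation serve both model families.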
@@ -827,6 +862,7 @@ def is_data_parallel_rank_zero(self):

def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
"""Returns the total number of parameters across all model parallel groups."""
+is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False)
# log number of parameters
if isinstance(model, list):
num_parameters_on_device = sum(
@@ -839,7 +875,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
):
word_embeddings_weight = (
model[-1].module.shared_embedding_or_output_weight()
-if getattr(self, 'mcore_gpt', False)
+if is_mcore_model
else model[-1].word_embeddings_weight()
)
# subtract the embedding weights on the last virtual stage
@@ -854,7 +890,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
):
word_embeddings_weight = (
model.module.shared_embedding_or_output_weight()
-if getattr(self, 'mcore_gpt', False)
+if is_mcore_model
else model.word_embeddings_weight()
)
# subtract the embedding weights on the last stage
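Note on the two hunks above: with share_embeddings_and_output_weights enabled, the last pipeline stage holds a copy of the word-embedding weight for the tied output projection, so its element count is subtracted to keep the total summed across pipeline ranks from double counting it. A simplified, self-contained sketch of that bookkeeping (plain PyTorch with a hypothetical helper, not NeMo's API):

from typing import Optional

import torch
import torch.nn as nn

def params_on_this_pipeline_rank(
    module: nn.Module, is_last_stage: bool, tied_embedding_weight: Optional[torch.Tensor]
) -> int:
    """Count local parameters, discounting a tied embedding copy on the last stage."""
    total = sum(p.nelement() for p in module.parameters())
    if is_last_stage and tied_embedding_weight is not None:
        # The tied output projection reuses the input embedding weight,
        # which is already counted on the first pipeline stage.
        total -= tied_embedding_weight.nelement()
    return total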
File 2 of 3
@@ -136,40 +136,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps

-    def _wrap_model_for_O2(self):
-        """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
-            Args:
-                model: The model to wrap. Can be a list of modules or a single module.
-            Returns:
-                The wrapped model. Returns a list of wrapped modules or a single wrapped module.
-        """
-        Float16Wrapper = MCoreFloat16Module if self.mcore_bert else Float16Module
-
-        nemo_args = {
-            'config': self.model_parallel_config,
-            'precision': self.cfg.precision,
-        }
-        mcore_args = {
-            'config': self.transformer_config,
-        }
-
-        args = mcore_args if self.mcore_bert else nemo_args
-
-        # Model wrapper to convert both model and inputs to half precision
-        if isinstance(self.model, list):
-            converted_model = []
-            for module in self.model:
-                if not self.mcore_bert:
-                    args['module'] = module
-                converted_model.append(Float16Wrapper(**args))
-            self.model = converted_model
-        else:
-            if not self.mcore_bert:
-                args['module'] = self.model
-            self.model = Float16Wrapper(**args)
-
-        args.pop('module')

def model_provider_func(self, pre_process, post_process):
cfg = self.cfg
num_tokentypes = 2 if cfg.bert_binary_head else 0
@@ -990,7 +956,7 @@ def configure_optimizers(self):
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
stage_bucket = []
-layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers
+layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers
for layer in layers:
stage_bucket.extend(
p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
@@ -1002,7 +968,7 @@
for module in modules:
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
-layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers
+layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers
for layer in layers:
buckets.append(
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
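Note: the two one-line changes above retarget the bucket construction for Megatron-Core BERT, which exposes its transformer layers under module.encoder rather than the module.transformer attribute the old expression assumed. A condensed sketch of the bucketing logic (illustrative unwrap helper and module layout, not NeMo's exact API):

import torch.nn as nn

def unwrap_fp16(module: nn.Module) -> nn.Module:
    # Float16 wrappers keep the real model under .module
    return module.module if hasattr(module, 'module') else module

def build_overlap_grad_sync_buckets(modules, mcore_bert: bool):
    """One bucket per transformer layer, skipping params that opt out of overlapped grad sync."""
    buckets = []
    for module in modules:
        module = unwrap_fp16(module)
        layers = module.encoder.layers if mcore_bert else module.language_model.encoder.layers
        for layer in layers:
            buckets.append(
                [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
            )
    return buckets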
File 3 of 3
@@ -1554,36 +1554,3 @@ def build_transformer_config(self) -> TransformerConfig:
setattr(transformer_config, key, value)

return transformer_config

-    def _wrap_model_for_O2(self):
-        """ Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
-            Args:
-                model: The model to wrap. Can be a list of modules or a single module.
-            Returns:
-                The wrapped model. Returns a list of wrapped modules or a single wrapped module.
-        """
-        Float16Wrapper = MCoreFloat16Module if self.mcore_gpt else Float16Module
-
-        nemo_args = {
-            'config': self.model_parallel_config,
-            'precision': self.cfg.precision,
-            'share_token_embeddings': self.cfg.get('share_embeddings_and_output_weights', True),
-        }
-        mcore_args = {
-            'config': self.transformer_config,
-        }
-
-        args = mcore_args if self.mcore_gpt else nemo_args
-
-        # Model wrapper to convert both model and inputs to half precision
-        if isinstance(self.model, list):
-            converted_model = []
-            for module in self.model:
-                args['module'] = module
-                converted_model.append(Float16Wrapper(**args))
-            self.model = converted_model
-        else:
-            args['module'] = self.model
-            self.model = Float16Wrapper(**args)
-
-        args.pop('module')
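Note: with this deletion (and the matching BERT deletion above), both model classes pick up the single _wrap_model_for_O2 added to the base class in the first file; existing self._wrap_model_for_O2() calls resolve there through normal method lookup. A toy illustration with hypothetical class names:

class MegatronBaseLike:
    """Hypothetical stand-in for the shared base class."""

    def _wrap_model_for_O2(self):
        return f"wrapped by base for {type(self).__name__}"

class GPTModelLike(MegatronBaseLike):
    pass  # no longer defines its own _wrap_model_for_O2

class BertModelLike(MegatronBaseLike):
    pass  # same: inherits the base implementation

print(GPTModelLike()._wrap_model_for_O2())   # wrapped by base for GPTModelLike
print(BertModelLike()._wrap_model_for_O2())  # wrapped by base for BertModelLike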
