diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index c3046e4f01f4..fab1cc199bc3 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -349,6 +349,7 @@ def build_transformer_config(self) -> TransformerConfig:
         activation = self.cfg.get('activation', 'gelu')
         # TODO: need to check which activation functions are supported in mcore
         activation_func = activation_to_func(activation)
+        gated_linear_unit = activation.endswith('glu')
 
         normalization = self.cfg.get('normalization', 'LayerNorm')
 
@@ -396,7 +397,7 @@ def build_transformer_config(self) -> TransformerConfig:
             'apply_residual_connection_post_layernorm': False,  # we don't use this in NeMo
             'layernorm_zero_centered_gamma': False,
             'add_bias_linear': add_bias_linear,
-            'gated_linear_unit': False,
+            'gated_linear_unit': gated_linear_unit,
             'activation_func': activation_func,
             'normalization': normalization,
             'init_method': init_method,
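
For context, the new flag is inferred purely from the activation name: any activation whose
name ends in 'glu' turns on the gated linear unit path in the TransformerConfig, while
non-gated activations leave it off. A minimal standalone sketch of that check follows; the
activation names listed are illustrative assumptions, not the exhaustive set NeMo/mcore accepts.

    # Illustrative only: shows how gated_linear_unit is derived from the
    # activation string via endswith('glu'). The names below are assumed
    # examples, not the full list supported by NeMo.
    for activation in ['gelu', 'relu', 'swiglu', 'fast-swiglu', 'geglu', 'reglu']:
        gated_linear_unit = activation.endswith('glu')
        print(f"{activation:12s} -> gated_linear_unit={gated_linear_unit}")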