diff --git a/eole/config/models.py b/eole/config/models.py
index 359c6796..667c7916 100644
--- a/eole/config/models.py
+++ b/eole/config/models.py
@@ -1,10 +1,7 @@
 from typing import Dict, Union, Literal, Any, Annotated
 from pydantic import Field, field_validator, model_validator  # , TypeAdapter
-from eole import constants
-from eole.modules.transformer_mlp import (
-    ActivationFunction,
-)  # might be better defined elsewhere
+from eole.constants import PositionEncodingType, ActivationFunction
 from eole.config.config import Config
@@ -30,8 +27,8 @@ class EmbeddingsConfig(Config):
         description="Use a sin to mark relative words positions. "
         "Necessary for non-RNN style models.",
     )
-    position_encoding_type: constants.PositionEncodingType = Field(
-        default=constants.PositionEncodingType.SinusoidalInterleaved,
+    position_encoding_type: PositionEncodingType = Field(
+        default=PositionEncodingType.SinusoidalInterleaved,
         description="Type of positional encoding.",
     )
diff --git a/eole/constants.py b/eole/constants.py
index a91b7432..95e1f2ed 100644
--- a/eole/constants.py
+++ b/eole/constants.py
@@ -1,5 +1,8 @@
 """Define constant values used across the project."""
 from enum import Enum
+import torch
+from eole.modules.rmsnorm import RMSNorm
+import torch.nn.functional as F


 class DefaultTokens(object):
@@ -46,3 +49,23 @@ class ModelTask(str, Enum):
 class PositionEncodingType(str, Enum):
     SinusoidalInterleaved = "SinusoidalInterleaved"
     SinusoidalConcat = "SinusoidalConcat"
+
+
+class ActivationFunction(str, Enum):
+    relu = "relu"
+    gelu = "gelu"
+    silu = "silu"
+    gated_gelu = "gated-gelu"
+    gated_silu = "gated-silu"
+
+
+ACTIVATION_FUNCTIONS = {
+    ActivationFunction.relu: F.relu,
+    ActivationFunction.gelu: F.gelu,
+    ActivationFunction.silu: F.silu,
+    ActivationFunction.gated_gelu: F.gelu,
+    ActivationFunction.gated_silu: F.silu,
+}
+
+
+LayerNorm = {"standard": torch.nn.LayerNorm, "rms": RMSNorm}
diff --git a/eole/decoders/cnn_decoder.py b/eole/decoders/cnn_decoder.py
index 39cb09f3..aa186237 100644
--- a/eole/decoders/cnn_decoder.py
+++ b/eole/decoders/cnn_decoder.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn

-from eole.modules import ConvMultiStepAttention
+from eole.modules.conv_multi_step_attention import ConvMultiStepAttention
 from eole.utils.cnn_factory import shape_transform, GatedConv
 from eole.decoders.decoder import DecoderBase
diff --git a/eole/decoders/rnn_decoder.py b/eole/decoders/rnn_decoder.py
index de062b25..8564ac4e 100644
--- a/eole/decoders/rnn_decoder.py
+++ b/eole/decoders/rnn_decoder.py
@@ -3,7 +3,8 @@
 from eole.decoders.decoder import DecoderBase
 from eole.modules.stacked_rnn import StackedLSTM, StackedGRU
-from eole.modules import context_gate_factory, GlobalAttention
+from eole.modules.gate import context_gate_factory
+from eole.modules.global_attention import GlobalAttention


 class RNNDecoderBase(DecoderBase):
diff --git a/eole/decoders/transformer_base.py b/eole/decoders/transformer_base.py
index a09a4aa5..679135b5 100644
--- a/eole/decoders/transformer_base.py
+++ b/eole/decoders/transformer_base.py
@@ -6,10 +6,11 @@
 import torch
 import torch.nn as nn

 from eole.decoders.decoder import DecoderBase
-from eole.modules import MultiHeadedAttention, AverageAttention
+from eole.modules.multi_headed_attn import MultiHeadedAttention
+from eole.modules.average_attn import AverageAttention
 from eole.modules.transformer_mlp import MLP
 from eole.modules.moe import MoE
-from eole.modules.rmsnorm import RMSNorm
+from eole.constants import LayerNorm

 class TransformerDecoderLayerBase(nn.Module):
@@ -23,14 +24,7 @@ def __init__(
             model_config (eole.config.TransformerDecoderConfig): full decoder config
         """
         super(TransformerDecoderLayerBase, self).__init__()
-        if model_config.layer_norm == "standard":
-            layernorm = nn.LayerNorm
-        elif model_config.layer_norm == "rms":
-            layernorm = RMSNorm
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )
+
         self.parallel_residual = model_config.parallel_residual
         self.shared_layer_norm = model_config.shared_layer_norm
         self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
@@ -39,7 +33,7 @@ def __init__(
         self.sliding_window = model_config.sliding_window
         self.self_attn_type = model_config.self_attn_type

-        self.input_layernorm = layernorm(
+        self.input_layernorm = LayerNorm[model_config.layer_norm](
             model_config.hidden_size, eps=model_config.norm_eps
         )
         if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
@@ -55,11 +49,11 @@ def __init__(
                 aan_useffn=model_config.aan_useffn,
             )
         self.dropout = nn.Dropout(self.dropout_p)
-        self.post_attention_layernorm = layernorm(
+        self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
             model_config.hidden_size, eps=model_config.norm_eps
         )
         if model_config.parallel_residual and not model_config.shared_layer_norm:
-            self.residual_layernorm = layernorm(
+            self.residual_layernorm = LayerNorm[model_config.layer_norm](
                 model_config.hidden_size, eps=model_config.norm_eps
             )
         if model_config.num_experts > 0:
diff --git a/eole/decoders/transformer_decoder.py b/eole/decoders/transformer_decoder.py
index 536237a2..02a8a8f0 100644
--- a/eole/decoders/transformer_decoder.py
+++ b/eole/decoders/transformer_decoder.py
@@ -9,8 +9,9 @@
     TransformerDecoderLayerBase,
     TransformerDecoderBase,
 )
-from eole.modules import MultiHeadedAttention, AverageAttention
-from eole.modules.rmsnorm import RMSNorm
+from eole.modules.multi_headed_attn import MultiHeadedAttention
+from eole.modules.average_attn import AverageAttention
+from eole.constants import LayerNorm


 class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -35,15 +36,8 @@ def __init__(
             model_config,
             running_config=running_config,
         )
-        if model_config.layer_norm == "standard":
-            layernorm = nn.LayerNorm
-        elif model_config.layer_norm == "rms":
-            layernorm = RMSNorm
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )
-        self.precontext_layernorm = layernorm(
+
+        self.precontext_layernorm = LayerNorm[model_config.layer_norm](
             model_config.hidden_size, eps=model_config.norm_eps
         )
         self.context_attn = MultiHeadedAttention(
@@ -151,14 +145,6 @@ def __init__(
         super(TransformerDecoder, self).__init__(
             model_config, running_config=running_config
         )
-        if model_config.layer_norm == "standard":
-            layernorm = nn.LayerNorm
-        elif model_config.layer_norm == "rms":
-            layernorm = RMSNorm
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )

         self.transformer_layers = nn.ModuleList(
             [
@@ -170,7 +156,9 @@ def __init__(
             ]
         )
         # This is the Decoder out layer norm
-        self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
+        self.layer_norm = LayerNorm[model_config.layer_norm](
+            model_config.hidden_size, eps=model_config.norm_eps
+        )

     def forward(self, emb, **kwargs):
         """Decode, possibly stepwise."""
diff --git a/eole/decoders/transformer_lm_decoder.py b/eole/decoders/transformer_lm_decoder.py
index 7f12689f..63787f61 100644
--- a/eole/decoders/transformer_lm_decoder.py
+++ b/eole/decoders/transformer_lm_decoder.py
@@ -10,7 +10,7 @@
     TransformerDecoderLayerBase,
     TransformerDecoderBase,
 )
-from eole.modules.rmsnorm import RMSNorm
+from eole.constants import LayerNorm


 class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
@@ -86,14 +86,7 @@ def __init__(
         running_config=None,
     ):
         super(TransformerLMDecoder, self).__init__(model_config)
-        if model_config.layer_norm == "standard":
-            layernorm = nn.LayerNorm
-        elif model_config.layer_norm == "rms":
-            layernorm = RMSNorm
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )
+
         self.transformer_layers = nn.ModuleList(
             [
                 TransformerLMDecoderLayer(
@@ -104,7 +97,9 @@
             ]
         )
         # This is the Decoder out layer norm
-        self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
+        self.layer_norm = LayerNorm[model_config.layer_norm](
+            model_config.hidden_size, eps=model_config.norm_eps
+        )

     def forward(self, emb, **kwargs):
         """Decode, possibly stepwise."""
diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py
index 0c592e8c..e51a141b 100644
--- a/eole/encoders/transformer.py
+++ b/eole/encoders/transformer.py
@@ -5,10 +5,9 @@
 import torch.nn as nn

 from eole.encoders.encoder import EncoderBase
-from eole.modules import MultiHeadedAttention
+from eole.modules.multi_headed_attn import MultiHeadedAttention
 from eole.modules.transformer_mlp import MLP
-
-from eole.modules.rmsnorm import RMSNorm
+from eole.constants import LayerNorm


 class TransformerEncoderLayer(nn.Module):
@@ -26,18 +25,10 @@ def __init__(
         running_config=None,
     ):
         super(TransformerEncoderLayer, self).__init__()
-        if model_config.layer_norm == "standard":
-            layernorm = nn.LayerNorm
-        elif model_config.layer_norm == "rms":
-            layernorm = RMSNorm
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )
+
         self.parallel_residual = model_config.parallel_residual
         self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
-
-        self.input_layernorm = layernorm(
+        self.input_layernorm = LayerNorm[model_config.layer_norm](
             model_config.hidden_size, eps=model_config.norm_eps
         )
         self.self_attn = MultiHeadedAttention(
@@ -47,7 +38,7 @@ def __init__(
             attn_type="self",
         )
         self.dropout = nn.Dropout(self.dropout_p)
-        self.post_attention_layernorm = layernorm(
+        self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
             model_config.hidden_size, eps=model_config.norm_eps
         )
         self.mlp = MLP(
@@ -120,18 +111,9 @@ def __init__(
             ]
         )
         # This is the Encoder out layer norm
-        if model_config.layer_norm == "standard":
-            self.layer_norm = nn.LayerNorm(
-                model_config.hidden_size, eps=model_config.norm_eps
-            )
-        elif model_config.layer_norm == "rms":
-            self.layer_norm = RMSNorm(
-                model_config.hidden_size, eps=model_config.norm_eps
-            )
-        else:
-            raise ValueError(
-                f"{model_config.layer_norm} layer norm type is not supported"
-            )
+        self.layer_norm = LayerNorm[model_config.layer_norm](
+            model_config.hidden_size, eps=model_config.norm_eps
+        )

     @classmethod
     def from_config(cls, model_config, running_config=None):
diff --git a/eole/models/model.py b/eole/models/model.py
index 0838c5eb..8c347b21 100644
--- a/eole/models/model.py
+++ b/eole/models/model.py
@@ -20,7 +20,7 @@
 from eole.encoders import str2enc
 from eole.decoders import str2dec
 from eole.constants import DefaultTokens
-from eole.modules import Embeddings
+from eole.modules.embeddings import Embeddings
 from eole.models.model_saver import load_checkpoint
 from eole.modules.estimator import FeedForward
diff --git a/eole/modules/__init__.py b/eole/modules/__init__.py
deleted file mode 100644
index e19d6603..00000000
--- a/eole/modules/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-""" Attention and normalization modules """
-from eole.modules.gate import context_gate_factory, ContextGate
-from eole.modules.global_attention import GlobalAttention
-from eole.modules.conv_multi_step_attention import ConvMultiStepAttention
-from eole.modules.multi_headed_attn import MultiHeadedAttention
-from eole.modules.embeddings import Embeddings, PositionalEncoding
-from eole.modules.weight_norm import WeightNormConv2d
-from eole.modules.average_attn import AverageAttention
-from eole.modules.alibi_position_bias import AlibiPositionalBias
-from eole.modules.rmsnorm import RMSNorm
-
-
-__all__ = [
-    "context_gate_factory",
-    "ContextGate",
-    "GlobalAttention",
-    "ConvMultiStepAttention",
-    "MultiHeadedAttention",
-    "Embeddings",
-    "PositionalEncoding",
-    "AlibiPositionalBias",
-    "WeightNormConv2d",
-    "AverageAttention",
-    "RMSNorm",
-]
diff --git a/eole/modules/average_attn.py b/eole/modules/average_attn.py
index d20dc76a..fcdb39ba 100644
--- a/eole/modules/average_attn.py
+++ b/eole/modules/average_attn.py
@@ -6,7 +6,7 @@
 from torch import Tensor
 from typing import Optional
 from eole.modules.transformer_mlp import MLP
-from eole.modules.transformer_mlp import ActivationFunction
+from eole.constants import ActivationFunction


 def cumulative_average_mask(
diff --git a/eole/modules/transformer_mlp.py b/eole/modules/transformer_mlp.py
index 27239645..ecb6bfee 100644
--- a/eole/modules/transformer_mlp.py
+++ b/eole/modules/transformer_mlp.py
@@ -1,29 +1,11 @@
 """MLP network from "Attention is All You Need"."""
 import torch.nn as nn
-import torch.nn.functional as F
+
 from torch.utils.checkpoint import checkpoint
 from torch.nn.utils import skip_init
 from torch.distributed import all_reduce
-from enum import Enum
-
-
-class ActivationFunction(str, Enum):
-    relu = "relu"
-    gelu = "gelu"
-    silu = "silu"
-    gated_gelu = "gated-gelu"
-    gated_silu = "gated-silu"
-
-
-# for silu, see: https://arxiv.org/pdf/2002.05202.pdf
-ACTIVATION_FUNCTIONS = {
-    ActivationFunction.relu: F.relu,
-    ActivationFunction.gelu: F.gelu,
-    ActivationFunction.silu: F.silu,
-    ActivationFunction.gated_gelu: F.gelu,
-    ActivationFunction.gated_silu: F.silu,
-}
+from eole.constants import ACTIVATION_FUNCTIONS


 class MLP(nn.Module):
diff --git a/eole/tests/test_attention.py b/eole/tests/test_attention.py
index 33bc80ff..83b5843e 100644
--- a/eole/tests/test_attention.py
+++ b/eole/tests/test_attention.py
@@ -23,7 +23,7 @@ def test_masked_global_attention(self):
         enc_out = Variable(torch.randn(batch_size, src_len.max(), dim))
         enc_final_hs = Variable(torch.randn(batch_size, dim))

-        attn = eole.modules.GlobalAttention(dim)
+        attn = eole.modules.global_attention.GlobalAttention(dim)
         _, alignments = attn(enc_final_hs, enc_out, src_len=src_len)

         # TODO: fix for pytorch 0.3
diff --git a/eole/utils/cnn_factory.py b/eole/utils/cnn_factory.py
index f3352e4f..3275a907 100644
--- a/eole/utils/cnn_factory.py
+++ b/eole/utils/cnn_factory.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 import torch.nn.init as init

-import eole.modules
+from eole.modules.weight_norm import WeightNormConv2d

 SCALE_WEIGHT = 0.5**0.5
@@ -20,7 +20,7 @@ class GatedConv(nn.Module):
     def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
         super(GatedConv, self).__init__()
-        self.conv = eole.modules.WeightNormConv2d(
+        self.conv = WeightNormConv2d(
             input_size,
             2 * input_size,
             kernel_size=(width, 1),
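
A minimal usage sketch of the relocated constants (illustrative only; the hidden size, eps and norm type below are made-up stand-ins, while LayerNorm, ACTIVATION_FUNCTIONS and ActivationFunction are the names defined in eole/constants.py in the hunk above):

import torch
from eole.constants import ACTIVATION_FUNCTIONS, ActivationFunction, LayerNorm

# Stand-in values; the real modules read these from model_config.
layer_norm_type = "rms"  # "standard" -> torch.nn.LayerNorm, "rms" -> RMSNorm
hidden_size, norm_eps = 512, 1e-6

# Dict lookup replaces the removed if/elif/else blocks; an unsupported
# norm type now surfaces as a KeyError rather than the previous ValueError.
norm = LayerNorm[layer_norm_type](hidden_size, eps=norm_eps)

# Activations are resolved the same way, keyed by the ActivationFunction enum.
act = ACTIVATION_FUNCTIONS[ActivationFunction.silu]
out = act(norm(torch.randn(2, 4, hidden_size)))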