simplify LayerNorm access as a constant (#7)
* simplify LayerNorm access as a constant - remove __init__.py in modules
vince62s authored Jun 6, 2024
1 parent 6a0b111 commit 9a299df
Showing 14 changed files with 64 additions and 127 deletions.
9 changes: 3 additions & 6 deletions eole/config/models.py
@@ -1,10 +1,7 @@
from typing import Dict, Union, Literal, Any, Annotated
from pydantic import Field, field_validator, model_validator # , TypeAdapter

from eole import constants
from eole.modules.transformer_mlp import (
ActivationFunction,
) # might be better defined elsewhere
from eole.constants import PositionEncodingType, ActivationFunction
from eole.config.config import Config


@@ -30,8 +27,8 @@ class EmbeddingsConfig(Config):
description="Use a sin to mark relative words positions. "
"Necessary for non-RNN style models.",
)
position_encoding_type: constants.PositionEncodingType = Field(
default=constants.PositionEncodingType.SinusoidalInterleaved,
position_encoding_type: PositionEncodingType = Field(
default=PositionEncodingType.SinusoidalInterleaved,
description="Type of positional encoding.",
)

23 changes: 23 additions & 0 deletions eole/constants.py
@@ -1,5 +1,8 @@
"""Define constant values used across the project."""
from enum import Enum
import torch
from eole.modules.rmsnorm import RMSNorm
import torch.nn.functional as F


class DefaultTokens(object):
@@ -46,3 +49,23 @@ class ModelTask(str, Enum):
class PositionEncodingType(str, Enum):
SinusoidalInterleaved = "SinusoidalInterleaved"
SinusoidalConcat = "SinusoidalConcat"


class ActivationFunction(str, Enum):
relu = "relu"
gelu = "gelu"
silu = "silu"
gated_gelu = "gated-gelu"
gated_silu = "gated-silu"


ACTIVATION_FUNCTIONS = {
ActivationFunction.relu: F.relu,
ActivationFunction.gelu: F.gelu,
ActivationFunction.silu: F.silu,
ActivationFunction.gated_gelu: F.gelu,
ActivationFunction.gated_silu: F.silu,
}


LayerNorm = {"standard": torch.nn.LayerNorm, "rms": RMSNorm}
2 changes: 1 addition & 1 deletion eole/decoders/cnn_decoder.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from eole.modules import ConvMultiStepAttention
from eole.modules.conv_multi_step_attention import ConvMultiStepAttention
from eole.utils.cnn_factory import shape_transform, GatedConv
from eole.decoders.decoder import DecoderBase

3 changes: 2 additions & 1 deletion eole/decoders/rnn_decoder.py
@@ -3,7 +3,8 @@

from eole.decoders.decoder import DecoderBase
from eole.modules.stacked_rnn import StackedLSTM, StackedGRU
from eole.modules import context_gate_factory, GlobalAttention
from eole.modules.gate import context_gate_factory
from eole.modules.global_attention import GlobalAttention


class RNNDecoderBase(DecoderBase):
20 changes: 7 additions & 13 deletions eole/decoders/transformer_base.py
@@ -6,10 +6,11 @@
import torch
import torch.nn as nn
from eole.decoders.decoder import DecoderBase
from eole.modules import MultiHeadedAttention, AverageAttention
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention
from eole.modules.transformer_mlp import MLP
from eole.modules.moe import MoE
from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerDecoderLayerBase(nn.Module):
@@ -23,14 +24,7 @@ def __init__(
model_config (eole.config.TransformerDecoderConfig): full decoder config
"""
super(TransformerDecoderLayerBase, self).__init__()
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.parallel_residual = model_config.parallel_residual
self.shared_layer_norm = model_config.shared_layer_norm
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
@@ -39,7 +33,7 @@ def __init__(
self.sliding_window = model_config.sliding_window
self.self_attn_type = model_config.self_attn_type

self.input_layernorm = layernorm(
self.input_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
@@ -55,11 +49,11 @@ def __init__(
aan_useffn=model_config.aan_useffn,
)
self.dropout = nn.Dropout(self.dropout_p)
self.post_attention_layernorm = layernorm(
self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.parallel_residual and not model_config.shared_layer_norm:
self.residual_layernorm = layernorm(
self.residual_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.num_experts > 0:
28 changes: 8 additions & 20 deletions eole/decoders/transformer_decoder.py
@@ -9,8 +9,9 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.modules import MultiHeadedAttention, AverageAttention
from eole.modules.rmsnorm import RMSNorm
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention
from eole.constants import LayerNorm


class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -35,15 +36,8 @@ def __init__(
model_config,
running_config=running_config,
)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.precontext_layernorm = layernorm(

self.precontext_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.context_attn = MultiHeadedAttention(
@@ -151,14 +145,6 @@ def __init__(
super(TransformerDecoder, self).__init__(
model_config, running_config=running_config
)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.transformer_layers = nn.ModuleList(
[
@@ -170,7 +156,9 @@ def __init__(
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
15 changes: 5 additions & 10 deletions eole/decoders/transformer_lm_decoder.py
@@ -10,7 +10,7 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
@@ -86,14 +86,7 @@ def __init__(
running_config=None,
):
super(TransformerLMDecoder, self).__init__(model_config)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.transformer_layers = nn.ModuleList(
[
TransformerLMDecoderLayer(
@@ -104,7 +97,9 @@ def __init__(
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
34 changes: 8 additions & 26 deletions eole/encoders/transformer.py
@@ -5,10 +5,9 @@
import torch.nn as nn

from eole.encoders.encoder import EncoderBase
from eole.modules import MultiHeadedAttention
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.transformer_mlp import MLP

from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerEncoderLayer(nn.Module):
@@ -26,18 +25,10 @@ def __init__(
running_config=None,
):
super(TransformerEncoderLayer, self).__init__()
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.parallel_residual = model_config.parallel_residual
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]

self.input_layernorm = layernorm(
self.input_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.self_attn = MultiHeadedAttention(
@@ -47,7 +38,7 @@ def __init__(
attn_type="self",
)
self.dropout = nn.Dropout(self.dropout_p)
self.post_attention_layernorm = layernorm(
self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.mlp = MLP(
@@ -120,18 +111,9 @@ def __init__(
]
)
# This is the Encoder out layer norm
if model_config.layer_norm == "standard":
self.layer_norm = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
elif model_config.layer_norm == "rms":
self.layer_norm = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

@classmethod
def from_config(cls, model_config, running_config=None):
2 changes: 1 addition & 1 deletion eole/models/model.py
@@ -20,7 +20,7 @@
from eole.encoders import str2enc
from eole.decoders import str2dec
from eole.constants import DefaultTokens
from eole.modules import Embeddings
from eole.modules.embeddings import Embeddings

from eole.models.model_saver import load_checkpoint
from eole.modules.estimator import FeedForward
25 changes: 0 additions & 25 deletions eole/modules/__init__.py

This file was deleted.
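
With the package-level re-exports gone, every call site in the remaining hunks switches to explicit submodule imports. A minimal sketch of the change, with paths taken verbatim from the hunks in this diff:

# Before (relied on re-exports in eole/modules/__init__.py):
#   from eole.modules import MultiHeadedAttention, AverageAttention, Embeddings
# After (explicit submodule paths):
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention
from eole.modules.embeddings import Embeddings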

2 changes: 1 addition & 1 deletion eole/modules/average_attn.py
@@ -6,7 +6,7 @@
from torch import Tensor
from typing import Optional
from eole.modules.transformer_mlp import MLP
from eole.modules.transformer_mlp import ActivationFunction
from eole.constants import ActivationFunction


def cumulative_average_mask(
22 changes: 2 additions & 20 deletions eole/modules/transformer_mlp.py
@@ -1,29 +1,11 @@
"""MLP network from "Attention is All You Need"."""

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint
from torch.nn.utils import skip_init
from torch.distributed import all_reduce
from enum import Enum


class ActivationFunction(str, Enum):
relu = "relu"
gelu = "gelu"
silu = "silu"
gated_gelu = "gated-gelu"
gated_silu = "gated-silu"


# for silu, see: https://arxiv.org/pdf/2002.05202.pdf
ACTIVATION_FUNCTIONS = {
ActivationFunction.relu: F.relu,
ActivationFunction.gelu: F.gelu,
ActivationFunction.silu: F.silu,
ActivationFunction.gated_gelu: F.gelu,
ActivationFunction.gated_silu: F.silu,
}
from eole.constants import ACTIVATION_FUNCTIONS


class MLP(nn.Module):
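
Since transformer_mlp.py now only pulls ACTIVATION_FUNCTIONS from eole.constants, resolving an activation is a plain dictionary lookup keyed by the ActivationFunction enum. A minimal sketch, assuming the caller holds an activation name such as gated-silu; the config attribute that would normally carry this value is not shown in this diff, so the local variable below is illustrative only:

import torch
from eole.constants import ActivationFunction, ACTIVATION_FUNCTIONS

activation_name = ActivationFunction.gated_silu        # illustrative choice
activation_fn = ACTIVATION_FUNCTIONS[activation_name]  # maps to torch.nn.functional.silu
print(activation_fn(torch.tensor([-1.0, 0.0, 1.0])))   # SiLU applied elementwise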
2 changes: 1 addition & 1 deletion eole/tests/test_attention.py
@@ -23,7 +23,7 @@ def test_masked_global_attention(self):
enc_out = Variable(torch.randn(batch_size, src_len.max(), dim))
enc_final_hs = Variable(torch.randn(batch_size, dim))

attn = eole.modules.GlobalAttention(dim)
attn = eole.modules.global_attention.GlobalAttention(dim)

_, alignments = attn(enc_final_hs, enc_out, src_len=src_len)
# TODO: fix for pytorch 0.3
4 changes: 2 additions & 2 deletions eole/utils/cnn_factory.py
@@ -5,7 +5,7 @@
import torch.nn as nn
import torch.nn.init as init

import eole.modules
from eole.modules.weight_norm import WeightNormConv2d

SCALE_WEIGHT = 0.5**0.5

@@ -20,7 +20,7 @@ class GatedConv(nn.Module):

def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
super(GatedConv, self).__init__()
self.conv = eole.modules.WeightNormConv2d(
self.conv = WeightNormConv2d(
input_size,
2 * input_size,
kernel_size=(width, 1),
