diff --git a/model.py b/model.py
index 648f7d50e..942070909 100644
--- a/model.py
+++ b/model.py
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 
 # Variations
-from variations.softmax_variations import Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
+from variations.softmax_variations import softmax_dictionary, Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
 from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary
@@ -137,30 +137,8 @@ def __init__(self, config, fire_pos_enc=None):
         else:
             # Remove flash attention (only compatible with 'softmax')
             self.flash = False
-
-            if self.softmax_variant_attn == "softermax":
-                self.softmax_layer = Softermax(config)
-
-            if self.softmax_variant_attn == "constantmax":
-                self.softmax_layer = Constantmax(config)
-
-            if self.softmax_variant_attn == "constantmax_quan":
-                self.softmax_layer = Constantmax_quan(config)
-
-            if self.softmax_variant_attn == "strongermax":
-                self.softmax_layer = Strongermax(config)
-
-            if self.softmax_variant_attn == "polymax":
-                self.softmax_layer = Polymax(config)
-
-            if self.softmax_variant_attn == "sigsoftmax":
-                self.softmax_layer = SigSoftmax(config)
-
-            if self.softmax_variant_attn == "saturatingconsmax":
-                self.softmax_layer = SaturatingConSmax(config)
-
-            if self.softmax_variant_attn == "exppolymax":
-                self.softmax_layer = ExpPolymax(config)
+            # Set softmax_layer_attn to custom softmax alternative
+            self.softmax_layer_attn = softmax_dictionary[config.softmax_variant_attn](config)
 
         if self.window_size is not None:
             # TODO: look into supporting sliding window attn for flash attn
@@ -245,7 +223,7 @@ def forward(self, x):
 
             # softmax variation
            if self.softmax_variant_attn != 'softmax':
-                att = self.softmax_layer(att)
+                att = self.softmax_layer_attn(att)
             else:
                 att = F.softmax(att, dim=-1)
 
@@ -435,23 +413,7 @@ def __init__(self, config):
         # Select softmax variant for output layer
         self.softmax_variant_output = config.softmax_variant_output
         if self.softmax_variant_output != "softmax":
-            if self.softmax_variant_output == "softermax":
-                self.softmax_layer_output = Softermax(config)
-
-            if self.softmax_variant_output == "constantmax":
-                self.softmax_layer_output = Constantmax(config)
-
-            if self.softmax_variant_output == "constantmax_quan":
-                self.softmax_layer_output = Constantmax_quan(config)
-
-            if self.softmax_variant_output == "strongermax":
-                self.softmax_layer_output = Strongermax(config)
-
-            if self.softmax_variant_output == "polymax":
-                self.softmax_layer_output = Polymax(config)
-
-            if self.softmax_variant_output == "sigsoftmax":
-                self.softmax_layer_output = SigSoftmax(config)
+            self.softmax_layer_output = softmax_dictionary[config.softmax_variant_output](config)
 
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
diff --git a/train.py b/train.py
index cc083f581..f267f4bf0 100644
--- a/train.py
+++ b/train.py
@@ -133,7 +133,16 @@ def parse_args():
                                                          "exppolymax",
                                                          ])
     model_group.add_argument("--softmax_variant_output", type=str,
-                             default="softmax", choices=["constantmax_quan", "constantmax", "polymax", "strongermax", "softermax", "sigsoftmax", "softmax"])
+                             default="softmax", choices=["constantmax_quan",
+                                                         "constantmax",
+                                                         "polymax",
+                                                         "strongermax",
+                                                         "softermax",
+                                                         "sigsoftmax",
+                                                         "softmax",
+                                                         "saturatingconsmax",
+                                                         "exppolymax",
+                                                         ])
 
     ## Custom Softmax Variation Options
     model_group.add_argument("--constantmax_initial_beta", type=float, default=2.5)
diff --git a/variations/softmax_variations.py b/variations/softmax_variations.py
index efcddc25c..1759c53c1 100644
--- a/variations/softmax_variations.py
+++ b/variations/softmax_variations.py
@@ -238,3 +238,14 @@ def forward(self, inputs):
 
         return numerator / denominator
 
+# Note: we use the built in library for regular softmax
+softmax_dictionary = {
+    "constantmax": Constantmax,
+    "constantmax_quan": Constantmax_quan,
+    "exppolymax": ExpPolymax,
+    "polymax": Polymax,
+    "saturatingconsmax": SaturatingConSmax,
+    "sigsoftmax": SigSoftmax,
+    "softermax": Softermax,
+    "strongermax": Strongermax,
+}