Add dictionary method for Softmax variations
This helps reduce the size of model.py significantly.
gkielian committed Apr 19, 2024
1 parent c0e6f06 commit 84dbd02
Showing 3 changed files with 26 additions and 44 deletions.
model.py (48 changes: 5 additions & 43 deletions)
@@ -17,7 +17,7 @@
 from torch.nn import functional as F

 # Variations
-from variations.softmax_variations import Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
+from variations.softmax_variations import softmax_dictionary, Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
 from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary
@@ -137,30 +137,8 @@ def __init__(self, config, fire_pos_enc=None):
         else:
             # Remove flash attention (only compatible with 'softmax')
             self.flash = False

-        if self.softmax_variant_attn == "softermax":
-            self.softmax_layer = Softermax(config)
-
-        if self.softmax_variant_attn == "constantmax":
-            self.softmax_layer = Constantmax(config)
-
-        if self.softmax_variant_attn == "constantmax_quan":
-            self.softmax_layer = Constantmax_quan(config)
-
-        if self.softmax_variant_attn == "strongermax":
-            self.softmax_layer = Strongermax(config)
-
-        if self.softmax_variant_attn == "polymax":
-            self.softmax_layer = Polymax(config)
-
-        if self.softmax_variant_attn == "sigsoftmax":
-            self.softmax_layer = SigSoftmax(config)
-
-        if self.softmax_variant_attn == "saturatingconsmax":
-            self.softmax_layer = SaturatingConSmax(config)
-
-        if self.softmax_variant_attn == "exppolymax":
-            self.softmax_layer = ExpPolymax(config)
+        # Set softmax_layer_attn to custom softmax alternative
+        self.softmax_layer_attn = softmax_dictionary[config.softmax_variant_attn](config)

         if self.window_size is not None:
             # TODO: look into supporting sliding window attn for flash attn
@@ -245,7 +223,7 @@ def forward(self, x):

             # softmax variation
             if self.softmax_variant_attn != 'softmax':
-                att = self.softmax_layer(att)
+                att = self.softmax_layer_attn(att)
             else:
                 att = F.softmax(att, dim=-1)

@@ -435,23 +413,7 @@ def __init__(self, config):
         # Select softmax variant for output layer
         self.softmax_variant_output = config.softmax_variant_output
         if self.softmax_variant_output != "softmax":
-            if self.softmax_variant_output == "softermax":
-                self.softmax_layer_output = Softermax(config)
-
-            if self.softmax_variant_output == "constantmax":
-                self.softmax_layer_output = Constantmax(config)
-
-            if self.softmax_variant_output == "constantmax_quan":
-                self.softmax_layer_output = Constantmax_quan(config)
-
-            if self.softmax_variant_output == "strongermax":
-                self.softmax_layer_output = Strongermax(config)
-
-            if self.softmax_variant_output == "polymax":
-                self.softmax_layer_output = Polymax(config)
-
-            if self.softmax_variant_output == "sigsoftmax":
-                self.softmax_layer_output = SigSoftmax(config)
+            self.softmax_layer_output = softmax_dictionary[config.softmax_variant_output](config)

         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
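For readers skimming the diff: the net effect in model.py is that every per-variant if branch collapses into a single dictionary lookup keyed by the config string, with the built-in softmax kept as the default path. A minimal sketch of the resulting pattern, assuming the softmax_dictionary and config fields shown above (AttentionSoftmaxSelector is a hypothetical wrapper name for illustration, not a class in this repo):

import torch.nn as nn
import torch.nn.functional as F

from variations.softmax_variations import softmax_dictionary

class AttentionSoftmaxSelector(nn.Module):
    """Sketch of the selection logic now used in CausalSelfAttention."""
    def __init__(self, config):
        super().__init__()
        self.softmax_variant_attn = config.softmax_variant_attn
        if self.softmax_variant_attn != "softmax":
            # one dictionary lookup replaces the former chain of ifs
            self.softmax_layer_attn = softmax_dictionary[self.softmax_variant_attn](config)

    def forward(self, att):
        if self.softmax_variant_attn != "softmax":
            return self.softmax_layer_attn(att)
        # the built-in softmax remains the default path
        return F.softmax(att, dim=-1)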
train.py (11 changes: 10 additions & 1 deletion)
@@ -133,7 +133,16 @@ def parse_args():
                             "exppolymax",
                             ])
     model_group.add_argument("--softmax_variant_output", type=str,
-                             default="softmax", choices=["constantmax_quan", "constantmax", "polymax", "strongermax", "softermax", "sigsoftmax", "softmax"])
+                             default="softmax", choices=["constantmax_quan",
+                                                         "constantmax",
+                                                         "polymax",
+                                                         "strongermax",
+                                                         "softermax",
+                                                         "sigsoftmax",
+                                                         "softmax",
+                                                         "saturatingconsmax",
+                                                         "exppolymax",
+                                                         ])

     ## Custom Softmax Variation Options
     model_group.add_argument("--constantmax_initial_beta", type=float, default=2.5)
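A possible follow-up, not part of this commit: the argparse choices lists now have to mirror the keys of softmax_dictionary by hand, so deriving them from the dictionary would keep the two from drifting. A hypothetical sketch using the names from the diff above:

import argparse

from variations.softmax_variations import softmax_dictionary

# "softmax" stays a valid choice; it is handled by torch's built-in softmax
SOFTMAX_CHOICES = sorted(softmax_dictionary) + ["softmax"]

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("model")
model_group.add_argument("--softmax_variant_attn", type=str,
                         default="softmax", choices=SOFTMAX_CHOICES)
model_group.add_argument("--softmax_variant_output", type=str,
                         default="softmax", choices=SOFTMAX_CHOICES)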
variations/softmax_variations.py (11 changes: 11 additions & 0 deletions)
@@ -238,3 +238,14 @@ def forward(self, inputs):

         return numerator / denominator

+# Note: we use the built in library for regular softmax
+softmax_dictionary = {
+    "constantmax": Constantmax,
+    "constantmax_quan": Constantmax_quan,
+    "exppolymax": ExpPolymax,
+    "polymax": Polymax,
+    "saturatingconsmax": SaturatingConSmax,
+    "sigsoftmax": SigSoftmax,
+    "softermax": Softermax,
+    "strongermax": Strongermax,
+}
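With the dictionary in place, adding a variant no longer touches model.py: define the class in variations/softmax_variations.py and register it with one entry (train.py's choices lists would also need the new key). A hypothetical sketch; TemperatureSoftmax, its config field, and the "temperaturemax" key are illustrative only, not part of this commit:

import torch
import torch.nn as nn

class TemperatureSoftmax(nn.Module):
    """Illustrative variant: standard softmax with a fixed temperature."""
    def __init__(self, config):
        super().__init__()
        # hypothetical config field; real variants read their own hyperparameters
        self.temperature = getattr(config, "softmax_temperature", 1.0)

    def forward(self, inputs):
        return torch.softmax(inputs / self.temperature, dim=-1)

# register the new variant under its config-string key
softmax_dictionary["temperaturemax"] = TemperatureSoftmax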
