Merge pull request #155 from gkielian/add_softmax_variation_dictionary
klei22 authored Apr 19, 2024
2 parents 42a4a05 + dc25597 commit d2e1d7a
Showing 5 changed files with 58 additions and 49 deletions.
22 changes: 22 additions & 0 deletions explorations/strongermax_sweep.json
@@ -0,0 +1,22 @@
[
{
"max_iters": ["3500"],
"n_layer": ["6"],
"n_kv_group": ["6"],
"n_head": ["6"],
"n_embd": ["384"],
"block_size":["256"],
"use_post_ln": [false],
"device": ["cuda"],
"dtype": ["float16"],
"dataset": ["shakespeare_char"],
"use_rotary_embeddings": [false],
"use_abs_pos_embeddings": [true],
"compile": [true],
"softmax_variant_attn": ["strongermax"],
"strongermax_strength": ["1.5", "2", "2.719", "3", "4", "5"],
"strongermax_divisor": ["1.0", "10.0", "100.0", "1000.0"],
"strongermax_sum_to_1": [true, false]
}
]

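For reference, each key in this sweep file maps to a list of candidate values, and one run corresponds to one element of their cartesian product. Below is a minimal sketch of expanding the file into individual run configurations; the expansion loop is an assumption about how the sweep is consumed, not the repository's actual sweep runner.

```python
import itertools
import json

# Path matches the file added in this commit.
with open("explorations/strongermax_sweep.json") as f:
    sweep_groups = json.load(f)

runs = []
for group in sweep_groups:
    keys = list(group.keys())
    # One run per combination of the per-key value lists.
    for values in itertools.product(*(group[k] for k in keys)):
        runs.append(dict(zip(keys, values)))

# 6 strength values x 4 divisors x 2 sum_to_1 settings -> 48 runs for this file.
print(len(runs))
```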
7 changes: 6 additions & 1 deletion inspect_ckpts.py
@@ -94,7 +94,12 @@ def main():
elif args.sort == 'iter':
ckpt_data.sort(key=lambda x: x[2], reverse=args.reverse)

console = Console()
console = None
# Check if the TERM environment variable is set to a value that supports ANSI escape codes
if 'TERM' in os.environ and os.environ['TERM'] in ['xterm', 'xterm-color', 'xterm-256color', 'screen', 'screen-256color', 'tmux', 'tmux-256color']:
console = Console(color_system="standard")
else:
console = Console()

# Determine the maximum length of the checkpoint file paths
max_path_length = max(len(ckpt_file) for ckpt_file, _, _ in ckpt_data)
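The idea behind this change is to force Rich onto the basic ANSI palette when the terminal advertises one of the known TERM values, and to fall back to Rich's auto-detection otherwise. A standalone sketch of that selection, assuming the Rich library as used above:

```python
import os
from rich.console import Console

# Terminals assumed to handle ANSI escape codes; mirrors the list in inspect_ckpts.py.
ANSI_TERMS = {
    "xterm", "xterm-color", "xterm-256color",
    "screen", "screen-256color", "tmux", "tmux-256color",
}

def make_console() -> Console:
    # "standard" restricts Rich to the basic ANSI color set; the default
    # constructor lets Rich detect the color system on its own.
    if os.environ.get("TERM") in ANSI_TERMS:
        return Console(color_system="standard")
    return Console()
```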
48 changes: 5 additions & 43 deletions model.py
@@ -17,7 +17,7 @@
from torch.nn import functional as F

# Variations
from variations.softmax_variations import Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
from variations.softmax_variations import softmax_dictionary, Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
from variations.activation_variations import SquaredReLU, activation_dictionary
@@ -137,30 +137,8 @@ def __init__(self, config, fire_pos_enc=None):
else:
# Remove flash attention (only compatible with 'softmax')
self.flash = False

if self.softmax_variant_attn == "softermax":
self.softmax_layer = Softermax(config)

if self.softmax_variant_attn == "constantmax":
self.softmax_layer = Constantmax(config)

if self.softmax_variant_attn == "constantmax_quan":
self.softmax_layer = Constantmax_quan(config)

if self.softmax_variant_attn == "strongermax":
self.softmax_layer = Strongermax(config)

if self.softmax_variant_attn == "polymax":
self.softmax_layer = Polymax(config)

if self.softmax_variant_attn == "sigsoftmax":
self.softmax_layer = SigSoftmax(config)

if self.softmax_variant_attn == "saturatingconsmax":
self.softmax_layer = SaturatingConSmax(config)

if self.softmax_variant_attn == "exppolymax":
self.softmax_layer = ExpPolymax(config)
# Set softmax_layer_attn to custom softmax alternative
self.softmax_layer_attn = softmax_dictionary[config.softmax_variant_attn](config)

if self.window_size is not None:
# TODO: look into supporting sliding window attn for flash attn
@@ -245,7 +223,7 @@ def forward(self, x):

# softmax variation
if self.softmax_variant_attn != 'softmax':
att = self.softmax_layer(att)
att = self.softmax_layer_attn(att)
else:
att = F.softmax(att, dim=-1)

@@ -435,23 +413,7 @@ def __init__(self, config):
# Select softmax variant for output layer
self.softmax_variant_output = config.softmax_variant_output
if self.softmax_variant_output != "softmax":
if self.softmax_variant_output == "softermax":
self.softmax_layer_output = Softermax(config)

if self.softmax_variant_output == "constantmax":
self.softmax_layer_output = Constantmax(config)

if self.softmax_variant_output == "constantmax_quan":
self.softmax_layer_output = Constantmax_quan(config)

if self.softmax_variant_output == "strongermax":
self.softmax_layer_output = Strongermax(config)

if self.softmax_variant_output == "polymax":
self.softmax_layer_output = Polymax(config)

if self.softmax_variant_output == "sigsoftmax":
self.softmax_layer_output = SigSoftmax(config)
self.softmax_layer_output = softmax_dictionary[config.softmax_variant_output](config)

self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# with weight tying when using torch.compile() some warnings get generated:
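With the per-variant if-chains replaced by a registry lookup, adding a new variant only requires defining the class and registering it once. A minimal sketch of that extension path; the `TemperatureSoftmax` class and its config attribute are hypothetical, used only to illustrate the pattern:

```python
import torch
import torch.nn as nn

from variations.softmax_variations import softmax_dictionary

class TemperatureSoftmax(nn.Module):
    """Hypothetical variant: plain softmax with a fixed temperature."""
    def __init__(self, config):
        super().__init__()
        # Attribute name is an assumption for this sketch.
        self.temperature = getattr(config, "temperature_softmax_temp", 1.0)

    def forward(self, x):
        return torch.softmax(x / self.temperature, dim=-1)

# One-line registration; model.py then builds the layer via
# softmax_dictionary[config.softmax_variant_attn](config).
softmax_dictionary["temperaturesoftmax"] = TemperatureSoftmax
```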
17 changes: 13 additions & 4 deletions train.py
@@ -133,7 +133,16 @@ def parse_args():
"exppolymax",
])
model_group.add_argument("--softmax_variant_output", type=str,
default="softmax", choices=["constantmax_quan", "constantmax", "polymax", "strongermax", "softermax", "sigsoftmax", "softmax"])
default="softmax", choices=["constantmax_quan",
"constantmax",
"polymax",
"strongermax",
"softermax",
"sigsoftmax",
"softmax",
"saturatingconsmax",
"exppolymax",
])

## Custom Softmax Variation Options
model_group.add_argument("--constantmax_initial_beta", type=float, default=2.5)
@@ -150,15 +159,15 @@ def parse_args():
model_group.add_argument('--sigsoftmax_use_euler_base', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument("--sigsoftmax_base", type=float, default=2.0)

model_group.add_argument("--strongermax_strength", type=float, default=2.0)
model_group.add_argument("--strongermax_strength", type=float, default=4.0)
model_group.add_argument('--strongermax_sum_to_1', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument("--strongermax_divisor", type=float, default=1.0)
model_group.add_argument('--strongermax_use_xmax', default=True, action=argparse.BooleanOptionalAction)

model_group.add_argument("--exppolymax_base", type=float, default="2.719")
model_group.add_argument("--exppolymax_base", type=float, default="4")
model_group.add_argument("--exppolymax_y_intercept", type=float, default=1.0)
model_group.add_argument("--exppolymax_power", type=float, default=2.0)
model_group.add_argument("--exppolymax_divisor", type=float, default=1.0)
model_group.add_argument("--exppolymax_divisor", type=float, default=1000.0)

# Softermax Specific Options
model_group.add_argument('--softermax_use_xmax', default=True, action=argparse.BooleanOptionalAction)
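As a quick check on the updated defaults and the BooleanOptionalAction flags, a stripped-down stand-in for parse_args() behaves as follows; the argument names and defaults are taken from the diff, and the sample command-line values are arbitrary:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--strongermax_strength", type=float, default=4.0)
parser.add_argument("--strongermax_sum_to_1", default=True,
                    action=argparse.BooleanOptionalAction)
parser.add_argument("--strongermax_divisor", type=float, default=1.0)

# BooleanOptionalAction also generates a matching --no-strongermax_sum_to_1 flag.
args = parser.parse_args(["--strongermax_strength", "2.719",
                          "--no-strongermax_sum_to_1"])
print(args.strongermax_strength, args.strongermax_sum_to_1)  # 2.719 False
```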
13 changes: 12 additions & 1 deletion variations/softmax_variations.py
@@ -202,7 +202,7 @@ def forward(self, x):
# Polynomial section: 0 < x < inf

# Exponential section
exponential_piece = torch.where((x < 0), torch.pow(self.constantmax_base, x), torch.tensor(0.0, device=x.device))
exponential_piece = torch.where((x < 0), torch.pow(self.exppolymax_base, x), torch.tensor(0.0, device=x.device))

# Polynomial section
poly_piece = torch.where(x > 0, x**self.power + self.y_intercept, torch.tensor(0.0, device=x.device))
@@ -238,3 +238,14 @@ def forward(self, inputs):

return numerator / denominator

# Note: we use the built-in library for regular softmax
softmax_dictionary = {
"constantmax": Constantmax,
"constantmax_quan": Constantmax_quan,
"exppolymax": ExpPolymax,
"polymax": Polymax,
"saturatingconsmax": SaturatingConSmax,
"sigsoftmax": SigSoftmax,
"softermax": Softermax,
"strongermax": Strongermax,
}
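For context on the one-line fix above, the exponential branch of ExpPolymax now uses its own base rather than Constantmax's. A standalone sketch of the piecewise mapping as it appears in the visible lines; the handling of x == 0 and the final normalization are not shown in the diff, so the last two lines are assumptions, and the defaults mirror the updated train.py values:

```python
import torch

def exppolymax_piecewise(x, base=4.0, y_intercept=1.0, power=2.0, divisor=1000.0):
    # Exponential branch for x < 0, polynomial branch for x > 0.
    exponential_piece = torch.where(x < 0, torch.pow(base, x), torch.zeros_like(x))
    poly_piece = torch.where(x > 0, x**power + y_intercept, torch.zeros_like(x))
    # Assumed: combine the branches, scale, and normalize over the last dimension.
    combined = (exponential_piece + poly_piece) / divisor
    return combined / combined.sum(dim=-1, keepdim=True)
```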
