Commit
Merge pull request #156 from gkielian/tidy_implementation
Tidy implementation
klei22 authored Apr 20, 2024
2 parents d2e1d7a + 7cfdcc2 commit 80996c6
Showing 4 changed files with 104 additions and 89 deletions.
82 changes: 82 additions & 0 deletions gpt_conf.py
@@ -0,0 +1,82 @@
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_kv_group: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    window_size: int = 128
    gate: bool = False

    use_parallel_mlp: bool = False

    # Shared parameters
    # MLP
    shared_mlp_size: int = 1
    shared_mlp_sym: bool = False
    # ATTN
    shared_attn_size: int = 1
    shared_attn_sym: bool = False

    # Softmax Alternatives and Options
    softmax_variant_attn: str = "softmax" # Choices: "softmax", "softermax", "sigsoftmax", "polymax", "strongermax", "constantmax"
    softmax_variant_output: str = "softmax" # Choices: "softmax", "softermax", "sigsoftmax", "polymax", "strongermax", "constantmax"

    ## Constantmax Options
    constantmax_initial_beta: float = 0.0 # initial value of the learnable beta parameter
    constantmax_initial_gamma: float = 1.0 # initial value of the learnable gamma (denominator) parameter
    constantmax_use_euler_base: bool = True # use 'e' as the exponent base for Constantmax
    constantmax_base: float = 2.0 # exponent base to use when not using the Euler base

    ## Softermax options
    softermax_use_xmax: bool = True # active if softermax is selected - True: uses (x - x_max) normalization; False: removes normalization (potential overflow)

    ## Polymax options
    polymax_x_intercept: float = -100.0
    polymax_y_intercept: float = 1.0
    polymax_power: float = 2.0
    polymax_divisor: float = 1000.0

    ## SigSoftmax options
    sigsoftmax_use_euler_base: bool = True # use 'e' as the exponent base for SigSoftmax
    sigsoftmax_base: float = 2.0 # exponent base to use when not using the Euler base

    ## Strongermax options - Softermax with the option of 'stronger' (larger) bases
    strongermax_strength: float = 2.0 # exponent base ('strength')
    strongermax_sum_to_1: bool = False # normalize outputs so they sum to 1
    strongermax_divisor: float = 1.0 # divisor applied to the result
    strongermax_use_xmax: bool = True # subtract x_max before exponentiation for numerical stability

    ## ExpPolymax options
    exppolymax_base: float = 2.719 # float approximation of Euler's number e
    exppolymax_y_intercept: float = 1.0
    exppolymax_power: float = 2.0
    exppolymax_divisor: float = 1.0

    # Positional Embeddings Variations
    use_abs_pos_embeddings: bool = True # Note: one can use this AND rotary embeddings
    use_fire_embeddings: bool = False
    shared_fire_embeddings: bool = False
    use_rotary_embeddings: bool = False
    rope_variant: str = "rope" # options: "shortrope", "rope"
    shortrope_length: int = 8 # number of embeddings to use in shortrope

    # Structuring Options (remember to compile the model)
    use_post_ln: bool = True

    # Layernorm Alternatives and Options
    norm_variant_attn: str = "rmsnorm"
    norm_variant_output: str = "rmsnorm"
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    prmsnorm_pct: float = 0.0625 # fraction of elements used by pRMSNorm
    krmsnorm_num: float = 10 # number of elements used by kRMSNorm

    # Activation Alternatives
    activation_variant: str = "gelu"

    # Linear Alternatives
    linear_variant: str = "linear"
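With GPTConfig in its own module, a script can now build a config without importing model.py first. A minimal usage sketch, assuming GPT's constructor takes the config object as it does in nanoGPT (the override value is illustrative, not part of this commit):

```python
# Hypothetical usage of the relocated config (not from this diff).
from gpt_conf import GPTConfig
from model import GPT

config = GPTConfig(softmax_variant_attn="polymax")  # illustrative override
model = GPT(config)  # nanoGPT-style constructor, assumed here
```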
84 changes: 3 additions & 81 deletions model.py
@@ -10,12 +10,14 @@
import math
import inspect
import sys
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

# Config
from gpt_conf import GPTConfig

# Variations
from variations.softmax_variations import softmax_dictionary, Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax, ExpPolymax, SaturatingConSmax
from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
@@ -304,86 +306,6 @@ def forward(self, x):
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_kv_group: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    window_size: int = 128
    gate: bool = False

    use_parallel_mlp: bool = False

    # Shared parameters
    # MLP
    shared_mlp_size: int = 1
    shared_mlp_sym: bool = False
    # ATTN
    shared_attn_size: int = 1
    shared_attn_sym: bool = False

    # Softmax Alternatives and Options
    softmax_variant_attn: str = "softmax" # Choices: "softmax", "softermax", "sigsoftmax", "polymax", "strongermax", "constantmax"
    softmax_variant_output: str = "softmax" # Choices: "softmax", "softermax", "sigsoftmax", "polymax", "strongermax", "constantmax"

    ## Constantmax Options
    constantmax_initial_beta: float = 0.0 # initial value of the learnable beta parameter
    constantmax_initial_gamma: float = 1.0 # initial value of the learnable gamma (denominator) parameter
    constantmax_use_euler_base: bool = True # use 'e' as the exponent base for Constantmax
    constantmax_base: float = 2.0 # exponent base to use when not using the Euler base

    ## Softermax options
    softermax_use_xmax: bool = True # active if softermax is selected - True: uses (x - x_max) normalization; False: removes normalization (potential overflow)

    ## Polymax options
    polymax_x_intercept: float = -100.0
    polymax_y_intercept: float = 1.0
    polymax_power: float = 2.0
    polymax_divisor: float = 1000.0

    ## SigSoftmax options
    sigsoftmax_use_euler_base: bool = True # use 'e' as the exponent base for SigSoftmax
    sigsoftmax_base: float = 2.0 # exponent base to use when not using the Euler base

    ## Strongermax options - Softermax with the option of 'stronger' (larger) bases
    strongermax_strength: float = 2.0 # exponent base ('strength')
    strongermax_sum_to_1: bool = False # normalize outputs so they sum to 1
    strongermax_divisor: float = 1.0 # divisor applied to the result
    strongermax_use_xmax: bool = True # subtract x_max before exponentiation for numerical stability

    ## ExpPolymax options
    exppolymax_base: float = 2.719 # float approximation of Euler's number e
    exppolymax_y_intercept: float = 1.0
    exppolymax_power: float = 2.0
    exppolymax_divisor: float = 1.0

    # Positional Embeddings Variations
    use_abs_pos_embeddings: bool = True # Note: one can use this AND rotary embeddings
    use_fire_embeddings: bool = False
    shared_fire_embeddings: bool = False
    use_rotary_embeddings: bool = False
    rope_variant: str = "rope" # options: "shortrope", "rope"
    shortrope_length: int = 8 # number of embeddings to use in shortrope

    # Structuring Options (remember to compile the model)
    use_post_ln: bool = True

    # Layernorm Alternatives and Options
    norm_variant_attn: str = "rmsnorm"
    norm_variant_output: str = "rmsnorm"
    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    prmsnorm_pct: float = 0.0625 # fraction of elements used by pRMSNorm
    krmsnorm_num: float = 10 # number of elements used by kRMSNorm

    # Activation Alternatives
    activation_variant: str = "gelu"

    # Linear Alternatives
    linear_variant: str = "linear"

class GPT(nn.Module):

20 changes: 17 additions & 3 deletions run_experiments.py
@@ -6,6 +6,9 @@
import argparse
from datetime import datetime
from itertools import product
from rich import print
from rich.console import Console
from rich.table import Table

def parse_args():
    parser = argparse.ArgumentParser(description="Run experiments based on a json configuration file.")
@@ -97,7 +100,7 @@ def format_config_name(config, config_basename, prefix, add_names)

    return f"{prefix}{config_basename}-{'-'.join(config_items)}"

def run_command(config, config_basename, output_dir, csv_ckpt_dir, prefix, add_names,
                best_val_loss_from, override_max_iters, override_dataset, override_block_size):
    formatted_name = format_config_name(config, config_basename, prefix, add_names)
    base_command = ["python3", "train.py"]
@@ -113,6 +116,17 @@ def run_command(config, config_basename, output_dir, csv_ckpt_dir, prefix, add_n
    if override_block_size:
        config['block_size'] = str(override_block_size)

    # Print the entered arguments before each run
    console = Console()
    table = Table(title="Entered Arguments", show_header=True, header_style="bold magenta")
    table.add_column("Argument", style="cyan")
    table.add_column("Value", style="green")

    for key, value in config.items():
        table.add_row(key, str(value))

    console.print(table)

    for key, value in config.items():
        if isinstance(value, bool):
            base_command.extend([f"--{'' if value else 'no-'}{key}"])
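The boolean branch above emits --key for True and --no-key for False, which matches the argparse.BooleanOptionalAction flags train.py declares (see the train.py hunk further down). A minimal sketch of that contract, using the --csv_log flag from this diff (BooleanOptionalAction requires Python 3.9+):

```python
# Sketch of the --flag/--no-flag convention run_command relies on.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--csv_log', default=True, action=argparse.BooleanOptionalAction)

print(parser.parse_args(['--csv_log']).csv_log)     # True
print(parser.parse_args(['--no-csv_log']).csv_log)  # False
```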
@@ -142,8 +156,8 @@ def main():

    for config in original_configurations:
        for combination in generate_combinations(config):
            run_command(combination, config_basename, args.output_dir, args.csv_ckpt_dir,
                        args.prefix, args.add_names, args.use_best_val_loss_from,
                        args.override_max_iters, args.override_dataset, args.override_block_size)

if __name__ == "__main__":
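generate_combinations itself is not shown in this diff; given the itertools.product import above, it presumably expands list-valued config fields into one run per combination. A sketch of that pattern, under that assumption only:

```python
# Assumed shape of generate_combinations (the real implementation is not in this diff).
from itertools import product

def generate_combinations(config):
    keys = list(config.keys())
    value_lists = [v if isinstance(v, list) else [v] for v in config.values()]
    for values in product(*value_lists):
        yield dict(zip(keys, values))

# Example: 2 n_layer values x 2 n_head values -> 4 runs
for combo in generate_combinations({"n_layer": [4, 8], "n_head": [4, 8]}):
    print(combo)
```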
7 changes: 2 additions & 5 deletions train.py
@@ -2,7 +2,6 @@
import sys
from rich import print
import os
import sys
import time
import csv
from datetime import datetime
@@ -201,8 +200,8 @@ def parse_args():

    # CSV logging
    logging_group.add_argument('--csv_log', default=True, action=argparse.BooleanOptionalAction)
    training_group.add_argument('--csv_dir', default='csv_logs', type=str)
    training_group.add_argument('--csv_name', default='output', type=str, help="Output csv basename. Note, the .csv will be automatically appended.")
    logging_group.add_argument('--csv_dir', default='csv_logs', type=str)
    logging_group.add_argument('--csv_name', default='output', type=str, help="Output csv basename. Note, the .csv will be automatically appended.")

    # Tensorboard args
    logging_group.add_argument('--tensorboard_log', default=True, action=argparse.BooleanOptionalAction)
@@ -229,7 +228,6 @@ def setup(self):
        self.ddp = int(os.environ.get('RANK', -1)) != -1
        if self.ddp:
            init_process_group(backend=self.args.backend)
            print(self.args)
            self.ddp_rank = int(os.environ['RANK'])
            self.ddp_local_rank = int(os.environ['LOCAL_RANK'])
            self.ddp_world_size = int(os.environ['WORLD_SIZE'])
@@ -263,7 +261,6 @@ def setup(self):
        # Model
        # TODO only add if they are defined from the argparse
        self.model_args = {action.dest: getattr(self.args, action.dest) for action in self.model_group._group_actions}
        print(self.model_args)
        self.model_args['vocab_size'] = None

        if self.args.init_from == 'scratch':
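For reference, the model_args comprehension above (whose debug print this commit removes) collects one argparse group's parsed values into a dict keyed by destination name. A self-contained sketch of the pattern, with hypothetical flag names standing in for the real model arguments:

```python
# Collecting an argparse group's values into a dict, as train.py's setup does.
# The group and flags here are hypothetical stand-ins.
import argparse

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group('model')
model_group.add_argument('--n_layer', default=12, type=int)
model_group.add_argument('--n_head', default=12, type=int)

args = parser.parse_args([])  # use defaults
model_args = {action.dest: getattr(args, action.dest)
              for action in model_group._group_actions}
print(model_args)  # {'n_layer': 12, 'n_head': 12}
```

Note that _group_actions is a private argparse attribute; the pattern works but is not a stable public API.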
