Add option for Mixture of Experts #147

Closed
wants to merge 1 commit
model.py: 58 changes (47 additions, 11 deletions)
@@ -23,6 +23,43 @@
 from variations.activation_variations import SquaredReLU, activation_dictionary
 from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
 
+class MoELayer(nn.Module):
+    def __init__(self, d_model, num_experts, dropout_rate=0.1):
+        super(MoELayer, self).__init__()
+        self.num_experts = num_experts
+        self.d_model = d_model
+        self.dropout = nn.Dropout(dropout_rate)
+
+        # Expert weights
+        self.experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
+
+        # Gating network
+        self.gate = nn.Linear(d_model, num_experts)
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.size()
+
+        # Flatten the input tensor
+        x = x.view(-1, self.d_model)
+
+        # Get the expert outputs
+        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # [batch_size * seq_len, num_experts, d_model]
+
+        # Apply the gating network
+        gate_outputs = self.gate(x)  # [batch_size * seq_len, num_experts]
+        gate_outputs = F.softmax(gate_outputs, dim=-1)
+
+        # Apply dropout to the gate outputs
+        gate_outputs = self.dropout(gate_outputs)
+
+        # Combine the expert outputs based on the gate outputs
+        combined_output = torch.sum(expert_outputs * gate_outputs.unsqueeze(-1), dim=1)  # [batch_size * seq_len, d_model]
+
+        # Reshape the output to the original shape
+        combined_output = combined_output.view(batch_size, seq_len, self.d_model)
+
+        return combined_output
+
 def create_shared_param_group(layer_type, config):
     shared_size = None
     shared_sym = None # if true, output array is symmetrical
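
For orientation: the layer above is a dense ("soft") mixture, in the sense that every expert is applied to every token and the results are combined with the softmax gate weights, i.e. y = sum_i g_i(x) * expert_i(x). A minimal usage sketch follows, assuming the repository's model.py and its imports are available on the path; the sizes d_model=64 and num_experts=4 are purely illustrative, and eval() is called so the gate dropout is inactive.

    import torch
    from model import MoELayer  # assumes the repo root is importable

    moe = MoELayer(d_model=64, num_experts=4, dropout_rate=0.1)
    moe.eval()  # disable the gate dropout so the mixture weights sum to 1

    x = torch.randn(2, 10, 64)   # [batch_size, seq_len, d_model]
    y = moe(x)
    print(y.shape)               # torch.Size([2, 10, 64]), same shape as the input
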
@@ -285,7 +322,6 @@ def forward(self, x):
         return x
 
 class Block(nn.Module):
 
     def __init__(self, config, mlp=None, attn=None):
         super().__init__()
 
@@ -303,31 +339,29 @@ def __init__(self, config, mlp=None, attn=None):
         self.use_parallel_mlp = config.use_parallel_mlp
 
         # Allow for sharing attn between blocks
-        if attn == None:
+        if attn is None:
             self.attn = CausalSelfAttention(config)
         else:
             self.attn = attn
 
-        # Allow for sharing mlp between blocks
-        if mlp == None:
-            self.mlp = MLP(config)
-        else:
-            self.mlp = mlp
+        # Add the MoE layer
+        self.moe = MoELayer(config.n_embd, config.num_experts)
 
     def forward(self, x):
         if self.use_post_ln:
             if self.use_parallel_mlp:
-                x = self.ln_1(x + self.attn(x) + self.mlp(x))
+                x = self.ln_1(x + self.attn(x) + self.moe(x))
             else:
                 x = self.ln_1(x + self.attn(x))
-                x = self.ln_2(x + self.mlp(x))
+                x = self.ln_2(x + self.moe(x))
         else:
             if self.use_parallel_mlp:
                 ln_1 = self.ln_1(x)
-                x = x + self.attn(ln_1) + self.mlp(ln_1)
+                x = x + self.attn(ln_1) + self.moe(ln_1)
             else:
                 x = x + self.attn(self.ln_1(x))
-                x = x + self.mlp(self.ln_2(x))
+                x = x + self.moe(self.ln_2(x))
 
         return x
 
 @dataclass
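
Note that the gating above is a plain softmax rather than a sparse top-k selection, so every expert is evaluated for every token and every expert's weights are added to every block. A back-of-the-envelope count of the extra parameters per block, using illustrative sizes (n_embd=768 and the default of 16 experts; the real numbers depend on the training config):

    d_model, num_experts = 768, 16                          # illustrative sizes only
    experts = num_experts * (d_model * d_model + d_model)   # nn.Linear(d_model, d_model) weight + bias per expert
    gate = d_model * num_experts + num_experts              # gating nn.Linear(d_model, num_experts) weight + bias
    print(experts + gate)                                   # 9461776 additional parameters per block
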
@@ -341,8 +375,10 @@ class GPTConfig:
     dropout: float = 0.0
     window_size: int = 128
     gate: bool = False
+    num_experts: int = 16  # Default value for the number of experts
 
     use_parallel_mlp: bool = False
+    use_mixture_of_experts: bool = False
 
     # Shared parameters
     # MLP
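One detail worth flagging in the lines shown: Block always constructs a MoELayer and routes forward() through self.moe, so the use_mixture_of_experts flag added to GPTConfig above and to train.py below is never consulted here. The snippet that follows is a hedged sketch of how the flag could select between the existing MLP and the new layer inside Block.__init__; it is an assumption about the intent, not code from this diff.

        # Hypothetical wiring (not part of this PR): fall back to the plain MLP
        # unless the mixture-of-experts flag is set on the config.
        # forward() would then keep calling self.mlp as it did before this change.
        if config.use_mixture_of_experts:
            self.mlp = MoELayer(config.n_embd, config.num_experts)
        elif mlp is None:
            self.mlp = MLP(config)
        else:
            self.mlp = mlp
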
train.py: 4 changes (4 additions, 0 deletions)
@@ -108,6 +108,10 @@ def parse_args():
         ],
     )
 
+    # MOE
+    model_group.add_argument('--use_mixture_of_experts', default=False, action=argparse.BooleanOptionalAction)
+    model_group.add_argument('--num_experts', type=int, default=16, help="number of experts if mixture of experts flag is set")
+
     # POSITIONAL EMBEDDING VARIATIONS
     model_group.add_argument('--use_rotary_embeddings', default=False, action=argparse.BooleanOptionalAction)
    model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])