
Commit

rename loss_coef to balance_loss_coef, sum the balance and router z-loss and return the total auxiliary loss and add some comments in readme on what to do with it
lucidrains committed Sep 11, 2023
1 parent 240414a commit 5d5f071
Showing 3 changed files with 19 additions and 10 deletions.
10 changes: 7 additions & 3 deletions README.md
@@ -34,12 +34,12 @@ moe = MoE(
threshold_eval = 0.2,
capacity_factor_train = 1.25, # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
capacity_factor_eval = 2., # capacity_factor_* should be set to a value >=1
-    loss_coef = 1e-2,           # multiplier on the expert balancing auxiliary loss
+    balance_loss_coef = 1e-2,   # multiplier on the expert balancing auxiliary loss
router_z_loss_coef = 1e-3, # loss weight for router z-loss
)

inputs = torch.randn(4, 1024, 512)
-out, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,)
+out, total_aux_loss, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# for the entire mixture of experts block, in context of transformer

@@ -51,7 +51,11 @@ moe_block = SparseMoEBlock(
add_ff_after = True
)

-out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
+out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,), (1,)
+
+# the total auxiliary loss will need to be summed and then added to the main loss
+
+# the other two losses are the breakdown, weighted by the coefficients
```
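To make the new README comment concrete, here is a minimal sketch of a training step that continues from the snippet above. It assumes the `moe_block` and `inputs` defined there, plus an illustrative optimizer and a dummy regression target, neither of which comes from the repository.

```python
import torch
import torch.nn.functional as F
from torch.optim import Adam

optimizer = Adam(moe_block.parameters(), lr = 3e-4)   # hypothetical optimizer, any will do

targets = torch.randn(4, 1024, 512)                   # dummy target, purely for illustration

out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs)

main_loss = F.mse_loss(out, targets)                  # stand-in for your actual task loss

# total_aux_loss is already weighted by balance_loss_coef and router_z_loss_coef,
# so it can be added to the main loss directly (sum it over MoE layers if you have several)
loss = main_loss + total_aux_loss

loss.backward()
optimizer.step()
optimizer.zero_grad()

# balance_loss and router_z_loss are the weighted components of total_aux_loss,
# returned separately so they can be logged or inspected
```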

## Todo
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
-version = '0.0.30',
+version = '0.1.0',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
17 changes: 11 additions & 6 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -28,6 +28,7 @@

MixtureOfExpertsReturn = namedtuple('MixtureOfExpertsReturn', [
'outputs',
+'total_aux_loss',
'balance_loss',
'router_z_loss'
])
@@ -545,7 +546,7 @@ def __init__(self,
capacity_factor_train = 1.25,
capacity_factor_eval = 2.,
gating_top_n = 2,
-loss_coef = 1e-2,
+balance_loss_coef = 1e-2,
router_z_loss_coef = 1e-3,
experts: Optional[Module] = None,
straight_through_dispatch_tensor = True,
@@ -578,7 +579,7 @@ def __init__(self,
allow_var_seq_len = allow_var_seq_len
)

-self.loss_coef = loss_coef
+self.balance_loss_coef = balance_loss_coef
self.router_z_loss_coef = router_z_loss_coef

def forward(self, x):
@@ -598,10 +599,14 @@ def forward(self, x):

# losses

-balance_loss = loss * self.loss_coef
+balance_loss = loss * self.balance_loss_coef
router_z_loss = router_z_loss * self.router_z_loss_coef

-return MixtureOfExpertsReturn(output, balance_loss, router_z_loss)
+# combine the losses
+
+total_aux_loss = balance_loss + router_z_loss
+
+return MixtureOfExpertsReturn(output, total_aux_loss, balance_loss, router_z_loss)

# sparse moe block
# in particular, they found that adding a feedforward before or after greatly stabilized the training and improved results
@@ -636,7 +641,7 @@ def forward(self, x):

residual = x

-moe_out, balance_loss, router_z_loss = self.moe(self.moe_prenorm(x))
+moe_out, total_aux_loss, balance_loss, router_z_loss = self.moe(self.moe_prenorm(x))

x = moe_out + residual

@@ -645,4 +650,4 @@ def forward(self, x):
if exists(self.ff_after):
x = self.ff_after(x) + x

-return MixtureOfExpertsReturn(x, balance_loss, router_z_loss)
+return MixtureOfExpertsReturn(x, total_aux_loss, balance_loss, router_z_loss)
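For the whole model, a sketch of how the returned total auxiliary loss might be accumulated across several MoE layers inside a transformer. The MoETransformer wrapper and its use of nn.MultiheadAttention are illustrative assumptions, not part of this repository; only MoE, SparseMoEBlock, and the four-field return shown in this commit come from the library.

```python
import torch
from torch import nn
from st_moe_pytorch import MoE, SparseMoEBlock

class MoETransformer(nn.Module):
    # hypothetical wrapper: interleaves attention with sparse MoE blocks and
    # sums each block's total_aux_loss so it can be added to the main loss once
    def __init__(self, dim, depth, num_experts = 16):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attn = nn.MultiheadAttention(dim, 8, batch_first = True)   # stand-in attention
            moe_block = SparseMoEBlock(
                MoE(dim = dim, num_experts = num_experts),
                add_ff_before = True,
                add_ff_after = True
            )
            self.layers.append(nn.ModuleList([attn, moe_block]))

    def forward(self, x):
        total_aux_loss = 0.

        for attn, moe_block in self.layers:
            attn_out, _ = attn(x, x, x)
            x = attn_out + x

            # unpack (outputs, total_aux_loss, balance_loss, router_z_loss)
            x, layer_aux_loss, _, _ = moe_block(x)
            total_aux_loss = total_aux_loss + layer_aux_loss

        return x, total_aux_loss

# usage - add the summed auxiliary loss to whatever the main objective is
model = MoETransformer(dim = 512, depth = 4)
out, total_aux_loss = model(torch.randn(2, 1024, 512))
```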
