Commit

remove erroneous backwards for split_by_rank
lucidrains committed Feb 29, 2024
1 parent 09c14b2 commit 6b7f7fb
Showing 2 changed files with 10 additions and 21 deletions.
setup.py — 2 changes: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
st_moe_pytorch/distributed.py — 29 changes: 9 additions & 20 deletions

@@ -86,25 +86,14 @@ def __init__(self, *, dim = 0):
     def forward(self, x, sizes = None):
         return AllGatherFunction.apply(x, self.dim, sizes)
 
-class SplitByRankFunction(Function):
-    @staticmethod
-    def forward(ctx, x):
-        rank = dist.get_rank()
-        out = x[rank]
+def split_by_rank(x):
+    rank = dist.get_rank()
+    out = x[rank]
 
-        if isinstance(x, tuple):
-            sizes = tuple(map(lambda t: t.shape[0], x))
-        else:
-            sizes = (x.shape[1],) * x.shape[0]
+    if isinstance(x, tuple):
+        sizes = tuple(map(lambda t: t.shape[0], x))
+    else:
+        sizes = (x.shape[1],) * x.shape[0]
 
-        sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
-        ctx.sizes = sizes
-        return out, sizes
-
-    @staticmethod
-    def backward(ctx, grads, _):
-        grads = rearrange(grads, '... -> 1 ...')
-        grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
-        return grads
-
-split_by_rank = SplitByRankFunction.apply
+    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+    return out, sizes

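The removed SplitByRankFunction was a custom autograd Function that all-gathered gradients in its backward; after this commit, split_by_rank is a plain function, so any gradients flow through ordinary tensor indexing instead of the deleted backward. Below is a minimal usage sketch of the new function — the process-group setup and tensor shapes are illustrative assumptions, not part of the commit:

    # hypothetical usage sketch (shapes and init are assumptions)
    import torch
    import torch.distributed as dist
    from st_moe_pytorch.distributed import split_by_rank

    # assumes an initialized process group, e.g. launched under torchrun:
    # dist.init_process_group('gloo')

    world_size = dist.get_world_size()

    # one chunk per rank, stacked along dim 0
    x = torch.randn(world_size, 4, 8)

    out, sizes = split_by_rank(x)
    # out   -> this rank's chunk, shape (4, 8)
    # sizes -> LongTensor of per-rank chunk sizes, here (4,) * world_size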