any combination of number of experts and world size should not break
lucidrains committed Sep 10, 2023
1 parent 52b5c8a commit 00be346
Showing 4 changed files with 34 additions and 17 deletions.
README.md: 2 changes (1 addition & 1 deletion)

@@ -70,7 +70,7 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,

 - [ ] distributed
     - [x] handle any world size less than number of experts
-    - [ ] handle any world size greater than number of experts
+    - [x] handle any world size greater than number of experts - for now, just have remainder machines do nothing
 - [ ] optimize
     - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
     - [ ] make all distributed code pluggable, for different strategies
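The context line in the hunk header above shows the block's public API at this commit. For orientation, a minimal usage sketch that reproduces those shapes; the MoE and SparseMoEBlock names follow the repo's exports, but the constructor arguments here are illustrative assumptions, not copied from the README:

import torch
from st_moe_pytorch import MoE, SparseMoEBlock

# constructor arguments below are assumptions for illustration
moe = MoE(
    dim = 512,
    num_experts = 16,
    gating_top_n = 2
)

moe_block = SparseMoEBlock(moe)

inputs = torch.randn(4, 1024, 512)
out, balance_loss, router_z_loss = moe_block(inputs)  # (4, 1024, 512), (1,), (1,)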
assert.py: 4 changes (2 additions & 2 deletions)

@@ -76,8 +76,8 @@ def start(
     cleanup()

 if __name__ == '__main__':
-    world_size = 8
-    num_experts = 8
+    world_size = 9
+    num_experts = 4
     batch_size = 2
     batch_size_var_len = False

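The test now exercises a world size that is not a multiple of the expert count: 9 // 4 = 2 batch chunks per expert, 2 * 4 = 8 working ranks, and the 1 remainder rank idles. A standalone sketch of that arithmetic; chunk_num is re-implemented here for illustration, mirroring the repo's helper:

def chunk_num(num, chunks):
    # split num into near-equal integer parts
    num_per_chunk, remainder = divmod(num, chunks)
    return tuple(num_per_chunk + int(i < remainder) for i in range(chunks))

world_size, num_experts, total_batch_size = 9, 4, 16

num_batch_chunks = world_size // num_experts         # 2 batch shards per expert
total_ranks_in_use = num_batch_chunks * num_experts  # 8 ranks do work
remain_ranks = world_size % num_experts              # 1 rank sits idle

batch_splits = chunk_num(total_batch_size, num_batch_chunks)  # (8, 8)
num_experts_batches_across_ranks = batch_splits * num_experts + (0,) * remain_ranks
# -> (8, 8, 8, 8, 8, 8, 8, 8, 0): rank 8 receives nothing

for rank in range(world_size):
    expert_start_index = rank // num_batch_chunks        # which expert this rank serves
    num_experts_per_rank = int(rank < total_ranks_in_use)  # 0 for the idle remainder rank
    print(rank, expert_start_index, num_experts_per_rank)
    # for rank 8 the count is 0, so the out-of-range start index is never used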
setup.py: 2 changes (1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
st_moe_pytorch/st_moe_pytorch.py: 43 changes (30 additions & 13 deletions)

@@ -207,6 +207,7 @@ def forward(
             assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'

             x, batch_sizes = self.all_gather(x)
+            total_batch_size = batch_sizes.sum().item()

             world_size = dist.get_world_size()
             rank = dist.get_rank()
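The new total_batch_size is the global batch count after the all-gather, which the code below uses to split work by batch rather than by expert. A tiny sketch with made-up per-rank sizes:

import torch

# assumed: 9 ranks each contributed a batch of 2 to the all_gather
batch_sizes = torch.tensor([2] * 9)
total_batch_size = batch_sizes.sum().item()  # 18, the combined batch across ranks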
@@ -220,14 +221,28 @@ def forward(
             if world_size <= num_experts:
                 num_experts_across_ranks = chunk_num(num_experts, world_size)
                 start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
+
                 num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
+
                 expert_start_index = start_indices[rank].item()
             else:
-                # for now, make sure number of machines is right multiple
-                assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-                num_experts_per_rank = 1
-                expert_start_index = rank // (world_size // num_experts)
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts
+
+                expert_start_index = rank // num_batch_chunks
+
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts
+
+                # for now, remaining machines just process nothing
+
+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks
+
+                num_experts_per_rank = int(rank < total_ranks_in_use)
+
+            assert len(num_experts_batches_across_ranks) == world_size

             expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
         else:
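The two branches cover the two regimes: with world_size <= num_experts the experts are sharded across ranks, and with more ranks than experts each expert is replicated over a run of ranks that shard the batch, with leftover ranks given zero-sized splits. A sketch of the first branch's bookkeeping, with chunk_num and cumsum_exclusive re-implemented here for illustration:

import torch

def chunk_num(num, chunks):
    # split num into near-equal integer parts
    num_per_chunk, remainder = divmod(num, chunks)
    return tuple(num_per_chunk + int(i < remainder) for i in range(chunks))

def cumsum_exclusive(t, dim = -1):
    # exclusive prefix sum: cumulative sum not including the current element
    return t.cumsum(dim = dim) - t

world_size, num_experts, total_batch_size = 3, 8, 6
rank = 1

num_experts_across_ranks = chunk_num(num_experts, world_size)             # (3, 3, 2)
start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks))  # tensor([0, 3, 6])

num_experts_per_rank = num_experts_across_ranks[rank]                     # 3
num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
# -> (18, 18, 12): rows of the gathered (expert, batch) stack owned by each rank

expert_start_index = start_indices[rank].item()                           # rank 1 serves experts 3..5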
@@ -241,15 +256,13 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')

-            if world_size <= num_experts:
-                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
-                x = x.split(num_experts_across_ranks, dim = 0)
-                x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
-            else:
-                x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
             x, experts_per_rank_sizes = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+
+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)

         # get the experts in use

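What the simplified split path does to shapes, sketched with assumed sizes matching the 9-rank, 4-expert test, where rank 8 receives a zero-sized shard:

import torch
from einops import rearrange

num_experts, total_batch_size, n, d = 4, 16, 4, 8
num_experts_batches_across_ranks = (8,) * 8 + (0,)     # 9 ranks, last one idle

x = torch.randn(num_experts * total_batch_size, n, d)  # packed '(e b) n d'
shards = x.split(num_experts_batches_across_ranks, dim = 0)
print([s.shape[0] for s in shards])                    # [8, 8, 8, 8, 8, 8, 8, 8, 0]

# a working rank holds one expert's chunk of the batch
x0 = rearrange(shards[0], '(e b) n d -> e b n d', e = 1)
print(x0.shape)                                        # torch.Size([1, 8, 4, 8])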
@@ -260,11 +273,15 @@ def forward(
         # route tokens to appropriate experts

         outs = []
+
         for expert, expert_input in zip(experts, x):
             out = expert(expert_input)
             outs.append(out)

-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()

         # all gather across merged expert batches dimensions
         # then split the batch dimension back
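And the new fallback for a rank holding zero experts: its expert slice is empty, so the loop never runs, and outs is replaced by an empty grad-requiring tensor so the downstream all-gather and backward pass still have something to operate on. A sketch with assumed shapes:

import torch

# assumed state on an idle rank: empty expert slice, zero-batch input shard
experts = []                 # expert_slice selected no experts here
x = torch.empty(4, 0, 4, 8)  # shard reshaped to (num_experts, batch, n, d)

outs = []

for expert, expert_input in zip(experts, x):  # zip over an empty list: zero iterations
    outs.append(expert(expert_input))

if len(outs) > 0:
    outs = torch.stack(outs)
else:
    # differentiable, correctly shaped placeholder for the later all-gather
    outs = torch.empty_like(x).requires_grad_()

print(outs.shape, outs.requires_grad)  # torch.Size([4, 0, 4, 8]) True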
