reduce all-to-all communication volume when both expert and non-expert are tensor-parallel #5626

Merged: 9 commits, Jul 23, 2024
26 changes: 18 additions & 8 deletions deepspeed/moe/mappings.py
@@ -32,15 +32,23 @@ def _gather_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu

     input_ = input_.contiguous()
-    # Size and dimension.
-    rank = bwc_tensor_model_parallel_rank(mpu)
-
-    tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
-    tensor_list[rank] = input_
-    deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
+    world_size = bwc_tensor_model_parallel_world_size(mpu)
+    if world_size == 1:
+        return input_

-    # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=dim).contiguous()
+    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
+    deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
+    if dim == 0:
+        shape = list(input_.size())
+        shape[0] = shape[0] * world_size
+        output = gather_buffer.view(shape)
+    else:
+        tensor_list = [
+            gather_buffer.narrow(0,
+                                 input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
+        ]
+        # Note: torch.cat already creates a contiguous tensor.
+        output = torch.cat(tensor_list, dim=dim).contiguous()

     return output
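For reference, the equivalence the new code relies on can be checked in a single process, with the per-rank shards simulated and the flat buffer built by hand in place of all_gather_into_tensor. The sketch below is illustrative only and is not part of the PR.

import torch

world_size = 4
input_ = torch.randn(2, 3, 8)   # this rank's shard (same shape on every rank)

# Pretend every rank contributed a tensor; rank i's data is shards[i].
shards = [torch.randn_like(input_) for _ in range(world_size)]

# What all_gather_into_tensor would produce: one flat, rank-major buffer.
gather_buffer = torch.cat([s.reshape(-1) for s in shards])

# dim == 0: a view of the flat buffer is enough, no extra copy.
shape = list(input_.size())
shape[0] *= world_size
out_dim0 = gather_buffer.view(shape)
assert torch.equal(out_dim0, torch.cat(shards, dim=0))

# dim != 0: slice the flat buffer back into per-rank views, then concatenate.
views = [
    gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_)
    for i in range(world_size)
]
out_dim1 = torch.cat(views, dim=1).contiguous()
assert torch.equal(out_dim1, torch.cat(shards, dim=1))
print(out_dim0.shape, out_dim1.shape)   # torch.Size([8, 3, 8]) torch.Size([2, 12, 8])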

@@ -50,6 +58,8 @@ def _drop_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu

     total_chunks = bwc_tensor_model_parallel_world_size(mpu)
+    if total_chunks == 1:
+        return input_
     this_chunk = bwc_tensor_model_parallel_rank(mpu)
     assert input_.shape[
         dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
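The hunk above only shows the new early return and the existing divisibility check. For context, a minimal stand-alone sketch of what dropping duplicate tokens amounts to on a tensor-parallel rank is given below; the narrow-based slicing is an illustration, not code taken verbatim from mappings.py.

import torch

def drop_tokens_sketch(input_, rank, world_size, dim=0):
    if world_size == 1:   # nothing is duplicated with a single tensor-parallel rank
        return input_
    assert input_.shape[dim] % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by "
        f"tensor parallel world size ({world_size})")
    chunk = input_.shape[dim] // world_size
    return input_.narrow(dim, rank * chunk, chunk)   # keep only this rank's slice

tokens = torch.randn(8, 16, 32)                      # identical on every TP rank
kept = drop_tokens_sketch(tokens, rank=1, world_size=2, dim=1)
print(kept.shape)                                    # torch.Size([8, 8, 32])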
36 changes: 25 additions & 11 deletions deepspeed/moe/sharded_moe.py
@@ -533,13 +533,18 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(FIRST_ALLTOALL_TIMER).start()

-        if groups._get_expert_model_parallel_world_size() == 1:
-            # If the non-expert is tensor-parallel, it will create
+        tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        if tensor_model_world_size > 1:
+            # If the non-expert is tensor-parallel,
+            # whether or not the expert is tensor-parallel, it will create
             # duplicate tokens on the tensor-parallel ranks.
-            # Since our experts are not tensor-parallel, these duplicates
-            # need to be dropped to ensure correctness.
-            # this also doubles up as a communication optimization as we are
-            # reducing the all-to-all communication volume.
+            # Dropping the duplicate tokens also doubles up as a communication
+            # optimization, since it reduces the all-to-all communication volume.
+            # 1: if the expert is not tensor-parallel, drop duplicate tokens to ensure
+            #    correctness and to reduce all-to-all communication.
+            # 2: if the expert is tensor-parallel, drop duplicate tokens to reduce the
+            #    all-to-all communication volume; before expert execution an allgather
+            #    is needed to ensure correctness.
             dispatched_input = drop_tokens(dispatched_input, dim=1)

         dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
@@ -548,10 +553,22 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.timers(FIRST_ALLTOALL_TIMER).stop()
             self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the expert and the non-expert are tensor-parallel,
+            # the dropped duplicate tokens need to be gathered on each
+            # tensor-parallel rank again to ensure correctness.
+            dispatched_input = gather_tokens(dispatched_input, dim=1)
+
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

         expert_output = self.experts(dispatched_input)
-
+        # Re-shape before drop_tokens: gecm -> ecm
+        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the expert and the non-expert are tensor-parallel,
+            # drop the duplicate tokens to ensure correctness and to
+            # reduce the all-to-all communication volume.
+            expert_output = drop_tokens(expert_output, dim=1)
+
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()
@@ -562,10 +579,7 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

-        # Re-shape back: gecm -> ecm
-        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
-
-        if groups._get_expert_model_parallel_world_size() == 1:
+        if tensor_model_world_size > 1:
             # the dropped duplicate tokens need to be gathered on each
             # tensor parallel rank again for the tensor-parallel
             # non-expert of the next layer.
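Taken together, the forward path now drops the tokens duplicated by tensor parallelism before each all-to-all and gathers them back wherever full tokens are required. The shape-only sketch below tracks that bookkeeping in a single process; the fake_* helpers are illustrative stand-ins rather than DeepSpeed APIs, and the gecm reshapes around the experts are omitted.

import torch

tp = 2               # tensor-parallel world size
e, c, m = 8, 16, 32  # experts, capacity per expert, model dim

def fake_drop(x, dim, tp):     # drop_tokens: keep this TP rank's 1/tp slice
    return x.narrow(dim, 0, x.size(dim) // tp)

def fake_gather(x, dim, tp):   # gather_tokens: undo the drop across TP ranks
    return torch.cat([x] * tp, dim=dim)

def fake_all_to_all(x):        # the all-to-all keeps the overall shape here
    return x

dispatched = torch.randn(e, c, m)            # ecm, duplicated on every TP rank
dispatched = fake_drop(dispatched, 1, tp)    # (e, c/tp, m): 1/tp the a2a volume
dispatched = fake_all_to_all(dispatched)
dispatched = fake_gather(dispatched, 1, tp)  # (e, c, m): TP experts need full tokens
expert_out = dispatched                      # experts would run here
expert_out = fake_drop(expert_out, 1, tp)    # (e, c/tp, m): 1/tp the second a2a volume
expert_out = fake_all_to_all(expert_out)
expert_out = fake_gather(expert_out, 1, tp)  # (e, c, m): for the next TP non-expert layer
print(expert_out.shape)                      # torch.Size([8, 16, 32])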
2 changes: 1 addition & 1 deletion deepspeed/ops/transformer/inference/moe_inference.py
@@ -327,7 +327,7 @@ def forward(self,

            if self.expert_mp_group is not None:
                world_size = dist.get_world_size(group=self.expert_mp_group)
-               gather_buffer = torch.zeros(world_size * attention_output.numel(),
+               gather_buffer = torch.empty(world_size * attention_output.numel(),
                                            dtype=attention_output.dtype,
                                            device=attention_output.device)
                dist.all_gather_into_tensor(gather_buffer, attention_output, group=self.expert_mp_group)
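The zeros-to-empty change here (and the matching one in stage3.py below) is safe because all_gather_into_tensor writes every element of the output buffer, so zero-initializing it first is pure overhead. A stand-alone illustration, with the collective emulated by per-rank copies instead of a real process group:

import torch

world_size = 4
attention_output = torch.randn(3, 5)   # this rank's contribution
numel = attention_output.numel()

gather_buffer = torch.empty(world_size * numel, dtype=attention_output.dtype)

# Stand-in for dist.all_gather_into_tensor: each rank's data lands in its own slot.
for rank in range(world_size):
    gather_buffer.narrow(0, rank * numel, numel).copy_(attention_output.reshape(-1))

# Every element was written, so the uninitialized starting contents never matter.
assert torch.equal(gather_buffer.view(world_size, 3, 5),
                   attention_output.expand(world_size, 3, 5))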
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -2237,7 +2237,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:
         return grad_dict

     def _fp32_state_allgather(self, param, fp32_state_partition):
-        reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
+        reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(),
                                     dtype=torch.float32,
                                     device=param.device)
         my_rank = dist.get_rank(group=self.dp_process_group)