diff --git a/deepspeed/moe/mappings.py b/deepspeed/moe/mappings.py
index b8a06274343a..e57f66b85193 100644
--- a/deepspeed/moe/mappings.py
+++ b/deepspeed/moe/mappings.py
@@ -32,15 +32,23 @@ def _gather_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu
 
     input_ = input_.contiguous()
-    # Size and dimension.
-    rank = bwc_tensor_model_parallel_rank(mpu)
-
-    tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
-    tensor_list[rank] = input_
-    deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
+    world_size = bwc_tensor_model_parallel_world_size(mpu)
+    if world_size == 1:
+        return input_
 
-    # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=dim).contiguous()
+    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
+    deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
+    if dim == 0:
+        shape = list(input_.size())
+        shape[0] = shape[0] * world_size
+        output = gather_buffer.view(shape)
+    else:
+        tensor_list = [
+            gather_buffer.narrow(0,
+                                 input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
+        ]
+        # Note: torch.cat already creates a contiguous tensor.
+        output = torch.cat(tensor_list, dim=dim).contiguous()
 
     return output
 
@@ -50,6 +58,8 @@ def _drop_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu
 
     total_chunks = bwc_tensor_model_parallel_world_size(mpu)
+    if total_chunks == 1:
+        return input_
     this_chunk = bwc_tensor_model_parallel_rank(mpu)
     assert input_.shape[
         dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 96eab5e2ab17..416f01b82e3d 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -533,13 +533,18 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(FIRST_ALLTOALL_TIMER).start()
 
-        if groups._get_expert_model_parallel_world_size() == 1:
-            # If the non-expert is tensor-parallel, it will create
+        tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        if tensor_model_world_size > 1:
+            # If the non-expert part is tensor-parallel, then whether or not
+            # the experts are tensor-parallel, the forward pass will create
             # duplicate tokens on the tensor-parallel ranks.
-            # Since our experts are not tensor-parallel, these duplicates
-            # need to be dropped to ensure correctness.
-            # this also doubles up as a communication optimization as we are
-            # reducing the all-to-all communication volume.
+            # Dropping the duplicates also doubles as a communication
+            # optimization, since it reduces the all-to-all volume.
+            # 1: if the experts are not tensor-parallel, dropping the duplicates
+            # is required for correctness and reduces all-to-all communication.
+            # 2: if the experts are tensor-parallel, dropping the duplicates still
+            # reduces all-to-all communication, but an all-gather is required
+            # before expert execution to restore correctness.
             dispatched_input = drop_tokens(dispatched_input, dim=1)
 
         dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
@@ -548,10 +553,22 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.timers(FIRST_ALLTOALL_TIMER).stop()
             self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
 
+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the experts and the non-expert part are tensor-parallel,
+            # the dropped duplicate tokens need to be gathered on each
+            # tensor-parallel rank again to ensure correctness.
+            dispatched_input = gather_tokens(dispatched_input, dim=1)
+
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
-
         expert_output = self.experts(dispatched_input)
+        # Re-shape before drop_tokens: gecm -> ecm
+        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the experts and the non-expert part are tensor-parallel,
+            # drop the duplicate tokens again to ensure correctness and to
+            # reduce the volume of the second all-to-all.
+            expert_output = drop_tokens(expert_output, dim=1)
 
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()
@@ -562,10 +579,7 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
 
-        # Re-shape back: gecm -> ecm
-        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
-
-        if groups._get_expert_model_parallel_world_size() == 1:
+        if tensor_model_world_size > 1:
             # the dropped duplicate tokens need to be gathered on each
             # tensor parallel rank again for the tensor-parallel
             # non-expert of the next layer.
diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py
index 8766b65e866d..fc001a86d42e 100644
--- a/deepspeed/ops/transformer/inference/moe_inference.py
+++ b/deepspeed/ops/transformer/inference/moe_inference.py
@@ -327,7 +327,7 @@ def forward(self,
 
         if self.expert_mp_group is not None:
             world_size = dist.get_world_size(group=self.expert_mp_group)
-            gather_buffer = torch.zeros(world_size * attention_output.numel(),
+            gather_buffer = torch.empty(world_size * attention_output.numel(),
                                         dtype=attention_output.dtype,
                                         device=attention_output.device)
             dist.all_gather_into_tensor(gather_buffer, attention_output, group=self.expert_mp_group)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 37b81d42c0d6..3ac6987e9c22 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2237,7 +2237,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:
         return grad_dict
 
     def _fp32_state_allgather(self, param, fp32_state_partition):
-        reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
+        reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(),
                                     dtype=torch.float32,
                                     device=param.device)
         my_rank = dist.get_rank(group=self.dp_process_group)
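For reference, a minimal standalone sketch (not part of the patch) of the layout argument behind the new _gather_tokens fast path: all_gather_into_tensor lays every rank's flattened shard back to back in one flat buffer, so for dim == 0 a plain view already yields the concatenated result, and the copying torch.cat path is only needed for other dims. The collective is simulated with local tensors here; the shapes and values are made up for illustration.

# Simulate what _gather_tokens receives from all_gather_into_tensor and show
# that the dim == 0 view path and the narrow + torch.cat path agree.
import torch

world_size = 4
input_ = torch.arange(2 * 3, dtype=torch.float32).reshape(2, 3)  # one rank's shard

# Simulated collective: each rank's flattened shard laid out back to back.
per_rank = [input_ + 100 * r for r in range(world_size)]
gather_buffer = torch.cat([t.reshape(-1) for t in per_rank])

# dim == 0 fast path: a view over the flat buffer already has the gathered layout.
shape = list(input_.size())
shape[0] *= world_size
out_view = gather_buffer.view(shape)

# General path: slice the buffer back into per-rank tensors and concatenate.
chunks = [
    gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_)
    for i in range(world_size)
]
out_cat = torch.cat(chunks, dim=0)

assert torch.equal(out_view, out_cat)  # both paths produce the same result for dim == 0

The same reasoning motivates swapping torch.zeros for torch.empty when allocating the gather buffers in moe_inference.py and stage3.py: the collective overwrites every element of the buffer, so zero-initialization is redundant work.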