From 2ee1389e9af598f11418f66abc754270e0324a95 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 19 Sep 2024 14:44:03 -0700 Subject: [PATCH 01/19] change API for hierarchical CP Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 95 +++++++++++++++++------ transformer_engine/pytorch/transformer.py | 13 ++-- 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f8ba46b2ea..eb86b861f2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1439,10 +1439,19 @@ def forward( if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + if isinstance(cp_group, list): + cp_group_a2a = cp_group[0] + cp_size_a2a = get_distributed_world_size(cp_group_a2a) + rank_a2a = get_distributed_rank(cp_group_a2a) + cp_group = cp_group[1] + else: + cp_size_a2a = 1 + rank_a2a = 0 + cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) - send_dst = cp_global_ranks[(rank + 1) % cp_size] - recv_src = cp_global_ranks[(rank - 1) % cp_size] + send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) causal = "causal" in attn_mask_type @@ -3969,6 +3978,18 @@ def attn_forward_func_with_cp( Attention implementation with context parallelism. """ + if isinstance(cp_comm_type, list): + assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level of CP groups!" + assert len(cp_comm_type) == 2 and len(cp_group) == 2, "Current implementation only supports two-level of CP groups!" + if get_distributed_world_size(cp_group[0]) == 1: + cp_group = cp_group[1] + cp_comm_type = cp_comm_type[1] + elif get_distributed_world_size(cp_group[1]) == 1: + cp_group = cp_group[0] + cp_comm_type = cp_comm_type[0] + else: + assert isinstance(cp_group, dist_group_type), "Unsupported process group for non-hierarchical CP implementation!" + assert qkv_format in [ "bshd", "sbhd", @@ -4023,7 +4044,7 @@ def attn_forward_func_with_cp( use_fused_attention, ] - if cp_comm_type == "p2p": + if cp_comm_type in ["p2p", ["a2a", "p2p"]]: args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -4841,10 +4862,10 @@ def forward( attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, alibi_slopes: Optional[torch.Tensor] = None, - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -4861,7 +4882,12 @@ def forward( qkv_layout in QKVLayouts ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" 
- cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -6655,10 +6681,10 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -6677,7 +6703,12 @@ def forward( qkv_layout in QKVLayouts ), f"FusedAttention does not support qkv_layout = {qkv_layout}!" - cp_size = 1 if cp_group is None else get_distributed_world_size(cp_group) + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -6923,7 +6954,7 @@ class DotProductAttention(TransformerEngineBaseModule): tensor parallel world size. tp_group : ProcessGroup, default = `None` tensor parallel process group. - cp_group : ProcessGroup, default = `None` + cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None` context parallel process group. cp_global_ranks : list of global rank IDs, default = `None` global rank IDs of GPUs that are in cp_group. @@ -6932,15 +6963,18 @@ class DotProductAttention(TransformerEngineBaseModule): compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. - cp_comm_type : str + cp_comm_type : Union[str, List[str]], default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a" or ["a2a", "p2p"]. "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a + to QKV across each CP sub-group (e.g., via NVLink), then exchanging + KV with p2p between sub-groups (e.g., via IBLink). 
""" def __init__( @@ -6958,10 +6992,10 @@ def __init__( tp_group: Optional[dist_group_type] = None, layer_number: Optional[int] = None, attention_type: str = "self", - cp_group: Optional[dist_group_type] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", softmax_scale: Optional[float] = None, ) -> None: super().__init__() @@ -7113,10 +7147,10 @@ def custom_forward(*input_args, **input_kwargs): def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -7124,21 +7158,24 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : Union[str, List[str]], default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a" or ["a2a", "p2p"]. "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a + to QKV across each CP sub-group (e.g., via NVLink), then exchanging + KV with p2p between sub-groups (e.g., via IBLink). """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -7448,7 +7485,12 @@ def forward( max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) batch_size = len(cu_seqlens_q) - 1 - cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) + cp_size = 1 + if isinstance(self.cp_group, dist_group_type): + cp_size = get_distributed_world_size(self.cp_group) + elif isinstance(self.cp_group, list): + for group in self.cp_group: + cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 if qkv_format in ["sbhd", "bshd"]: @@ -8144,10 +8186,10 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -8155,21 +8197,24 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : Union[str, List[str]], default = `p2p` inter-gpu communication type for context parallelism. 
- Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a", ["a2a", "p2p"]. "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a + to QKV across each CP sub-group (e.g., via NVLink), then exchanging + KV with p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 958c7019ba..e4bff05476 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -484,10 +484,10 @@ def reset_fp8_meta_tensors(self) -> None: def set_context_parallel_group( self, - cp_group: Union[dist_group_type, None], + cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: str = "p2p", + cp_comm_type: Union[str, List[str]] = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -495,21 +495,24 @@ def set_context_parallel_group( Parameters ---------- - cp_group : ProcessGroup + cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str + cp_comm_type : Union[str, List[str]], default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a". + Can be "p2p" or "all_gather" or "a2a", or ["a2a", "p2p"]. "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. + ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a + to QKV across each CP sub-group (e.g., via NVLink), then exchanging + KV with p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. 
for index, child in enumerate(self.modules()): From 0e826e771faa52b6cfb802aa39d968e83c1f8c46 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 19 Sep 2024 19:30:43 -0700 Subject: [PATCH 02/19] move fp8 code before qkv reshape Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 84 +++++++++++++------------ 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6f79341f22..b7238d01de 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1472,6 +1472,47 @@ def forward( cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + else: + q_f16, k_f16, v_f16 = q, k, v + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + q_f16 = q + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" @@ -1529,47 +1570,6 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() - if fp8: - if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" 
- fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_f16, k_f16, v_f16 = q, k, v - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [k_f16, v_f16] - ] - fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - else: - assert False, "FP8 is only supported with Fused Attention!" - else: - q_f16 = q - if use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - p2p_comm_buffers = [None for _ in range(cp_size)] if use_fused_attention and qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) @@ -2174,6 +2174,7 @@ def forward( fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() elif fp8 and fp8_meta["recipe"].fp8_mha: + q_fp8 = q_fp8.view(q.shape) kv_fp8 = Float8Tensor( data=kv, fp8_meta=fp8_meta, @@ -2185,6 +2186,7 @@ def forward( q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 fp8_fwd_scales, fp8_fwd_scale_invs = None, None else: + q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 fp8_fwd_scales, fp8_fwd_scale_invs = None, None From 44db43b623fb891b01608b0b9a5157a3c732b8c0 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 23 Sep 2024 13:40:09 -0700 Subject: [PATCH 03/19] try to insert A2A for hierarchical CP Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 26 +++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b7238d01de..d5fc5b5533 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1488,7 +1488,8 @@ def forward( q, k, v = q_fp8._data, k_fp8._data, v_fp8._data else: q_f16, k_f16, v_f16 = q, k, v - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): k, v = [ cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) @@ -1513,6 +1514,19 @@ def forward( fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + print(f"before A2A rank_{rank*cp_size_a2a+rank_a2a}, {q.shape}, {k.shape}, {v.shape}") + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + print(f"after A2A rank_{rank*cp_size_a2a+rank_a2a}, {q.shape}, {k.shape}, {v.shape}") + if not fp8 or not 
fp8_meta["recipe"].fp8_mha: + q_f16 = q + if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + torch.cuda.synchronize() + assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" @@ -1585,6 +1599,7 @@ def forward( req.wait() if i < (cp_size - 1): + print(f"P2P global_rank_{rank*cp_size_a2a+rank_a2a}, rank_{rank}, step_{i}, send_dst_{send_dst}, recv_src_{recv_src}, {p2p_comm_buffers[i].shape}") p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i]) send_recv_reqs[i % 2] = flash_attn_p2p_communicate( rank, @@ -2174,7 +2189,14 @@ def forward( fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() elif fp8 and fp8_meta["recipe"].fp8_mha: - q_fp8 = q_fp8.view(q.shape) + q_fp8 = Float8Tensor( + data=q, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) kv_fp8 = Float8Tensor( data=kv, fp8_meta=fp8_meta, From 705da7ec81490425062afa989dd991ba7571dc5c Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 23 Sep 2024 17:03:12 -0700 Subject: [PATCH 04/19] make fwd work Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 29 +++++++++++++++++-------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d5fc5b5533..fa23c7141b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1516,11 +1516,9 @@ def forward( if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) - print(f"before A2A rank_{rank*cp_size_a2a+rank_a2a}, {q.shape}, {k.shape}, {v.shape}") q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) - print(f"after A2A rank_{rank*cp_size_a2a+rank_a2a}, {q.shape}, {k.shape}, {v.shape}") if not fp8 or not fp8_meta["recipe"].fp8_mha: q_f16 = q if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -1599,7 +1597,6 @@ def forward( req.wait() if i < (cp_size - 1): - print(f"P2P global_rank_{rank*cp_size_a2a+rank_a2a}, rank_{rank}, step_{i}, send_dst_{send_dst}, recv_src_{recv_src}, {p2p_comm_buffers[i].shape}") p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i]) send_recv_reqs[i % 2] = flash_attn_p2p_communicate( rank, @@ -2155,12 +2152,26 @@ def forward( ) kv = p2p_comm_buffers[-1] - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + batch_size = out.shape[0] + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + batch_size = out.shape[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + ) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, 
*out.shape[-2:]) + elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) if fp8 and use_fused_attention: From 99ace4673b1e7ad0d33e49ea4e8687bb8a4d0649 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 23 Sep 2024 17:23:35 -0700 Subject: [PATCH 05/19] remove a redundant sync Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fa23c7141b..e8d59f8931 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1523,7 +1523,6 @@ def forward( q_f16 = q if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - torch.cuda.synchronize() assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 From f73b2fb5c85214a38e5e1d537a76122d0984a423 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Mon, 23 Sep 2024 20:30:15 -0700 Subject: [PATCH 06/19] make bwd of hierarchical CP work Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 40 +++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e8d59f8931..c4c886eda3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1445,6 +1445,7 @@ def forward( rank_a2a = get_distributed_rank(cp_group_a2a) cp_group = cp_group[1] else: + cp_group_a2a = None cp_size_a2a = 1 rank_a2a = 0 @@ -2236,8 +2237,10 @@ def forward( *rng_states, *attn_biases, ) + ctx.cp_group_a2a = cp_group_a2a ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks + ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q @@ -2255,10 +2258,17 @@ def forward( @staticmethod def backward(ctx, dout): + if ctx.cp_group_a2a is not None: + cp_size_a2a = get_distributed_world_size(ctx.cp_group_a2a) + rank_a2a = get_distributed_rank(ctx.cp_group_a2a) + else: + cp_size_a2a = 1 + rank_a2a = 0 + cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] - recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] + send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] @@ -2271,6 +2281,7 @@ def backward(ctx, dout): causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type if ctx.qkv_format in ["bshd", "sbhd"]: + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] else: qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format @@ -2354,6 +2365,12 @@ def backward(ctx, dout): fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) + out, dout = flash_attn_a2a_communicate( + [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True + ) + out = 
out.view(*q.shape) dout = dout.view(*q.shape) send_recv_reqs = [] @@ -2949,6 +2966,21 @@ def backward(ctx, dout): cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) for x in [dq, dkv] ] + dk, dv = dkv[0], dkv[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, False + ) + if ctx.qkv_format == "bshd": + batch_size = dout.shape[0] + dq, dk, dv = [x.view(batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + batch_size = dout.shape[2] + dq, dk, dv = [x.view(-1, batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: dq, dk, dv = [ Float8Tensor( data=x, @@ -2958,10 +2990,8 @@ def backward(ctx, dout): fp8_dtype=fp8_dtype_backward, dtype=dout_dtype, ) - for x in [dq, dkv[0], dkv[1]] + for x in [dq, dk, dv] ] - else: - dk, dv = dkv[0], dkv[1] if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] From 0be5dd7d3a1b3f4c67a070b2ac4eb9fa77a4dd9a Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 11:10:04 -0700 Subject: [PATCH 07/19] fix dout a2a in bwd Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c4c886eda3..07bf877a1a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2237,6 +2237,7 @@ def forward( *rng_states, *attn_biases, ) + ctx.batch_size = batch_size ctx.cp_group_a2a = cp_group_a2a ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks @@ -2316,6 +2317,7 @@ def backward(ctx, dout): # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) + dout_dtype = dout.dtype if ctx.fp8: if ctx.use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) @@ -2326,7 +2328,6 @@ def backward(ctx, dout): dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) - dout_dtype = dout.dtype if ctx.fp8_meta["recipe"].fp8_mha: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv @@ -2350,7 +2351,13 @@ def backward(ctx, dout): assert False, "FP8 is only supported with Fused Attention!" 
else: if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: - q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]] + q, kv = [x.from_float8(x.dtype) for x in [q, kv]] + if cp_size_a2a == 1: + dout = dout.from_float8(dout_dtype) + else: + dout_fp8_dtype = dout._fp8_dtype + dout_scale_inv = dout._scale_inv + dout = dout._data dq = torch.empty_like(q) if ctx.qkv_format == "thd" and causal: dq[cu_seqlens_q_padded[-1] :].fill_(0) @@ -2362,7 +2369,7 @@ def backward(ctx, dout): if ctx.use_fused_attention: fp8_meta_kwargs = {} fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if cp_size_a2a > 1: @@ -2370,6 +2377,10 @@ def backward(ctx, dout): out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + dout = cast_from_fp8( + dout, None, None, dout_fp8_dtype, TE_DType[dout_dtype], scale_inv=dout_scale_inv + ) out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -2974,11 +2985,9 @@ def backward(ctx, dout): [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, False ) if ctx.qkv_format == "bshd": - batch_size = dout.shape[0] - dq, dk, dv = [x.view(batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": - batch_size = dout.shape[2] - dq, dk, dv = [x.view(-1, batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: dq, dk, dv = [ From a6833c7bba5a2c7f73b3d17c2907cb5474d229db Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 11:41:10 -0700 Subject: [PATCH 08/19] fix q_f16 with fp8 Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 07bf877a1a..a7ce34f9c0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1520,9 +1520,10 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) - if not fp8 or not fp8_meta["recipe"].fp8_mha: + if not fp8: + q_f16 = q + elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q - if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) assert qkv_format == "thd" or ( From caf3746d721785ec0a22a02cb08a5b9c1ab38e8b Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 13:10:59 -0700 Subject: [PATCH 09/19] assert hierarchical CP implementation does not support THD format Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index a7ce34f9c0..f90ca149d5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1440,6 +1440,7 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) if isinstance(cp_group, list): + assert qkv_format != 
"thd", f"{qkv_format} format is not supported with hierarchical CP implementation yet!" cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) From 03806c4f8a38590593e0ce7d0ffc5a320fecdc7c Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 18:19:36 -0700 Subject: [PATCH 10/19] bug fix Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f90ca149d5..45d5be2ecd 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2156,10 +2156,10 @@ def forward( kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) - batch_size = out.shape[0] + ctx.batch_size = out.shape[0] elif qkv_format == "sbhd": out = out.view(-1, *out.shape[-3:]) - batch_size = out.shape[1] + ctx.batch_size = out.shape[1] if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) @@ -2169,10 +2169,10 @@ def forward( if use_fused_attention: if qkv_format == "bshd": # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) elif qkv_format == "sbhd": # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) @@ -2239,7 +2239,6 @@ def forward( *rng_states, *attn_biases, ) - ctx.batch_size = batch_size ctx.cp_group_a2a = cp_group_a2a ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks From ea1e4a3e4ac1c4323dcbf0caf5b2ddbf89d463a0 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 19:15:03 -0700 Subject: [PATCH 11/19] assert hierarchical CP does not support attn bias Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 45d5be2ecd..ffd3fcf727 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1441,6 +1441,7 @@ def forward( if isinstance(cp_group, list): assert qkv_format != "thd", f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" 
cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) From 3cd21ab6d5f73b4cd9e6477e6517e5a1b0e329ea Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 20:02:31 -0700 Subject: [PATCH 12/19] add unit test for hierarchical CP Signed-off-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 24 ++++++++++------- .../fused_attn/test_fused_attn_with_cp.py | 26 +++++++++++-------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 6c775fb127..41e0284aad 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -17,7 +17,7 @@ def run_dpa_with_cp( - dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" + dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_size=2, cp_comm_type="p2p" ): """Test DotProductAttention module with context parallelism""" @@ -59,6 +59,19 @@ def run_dpa_with_cp( cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if isinstance(cp_size, list): + assert len(cp_size) == 2 + assert len(cp_size) == len(cp_comm_type) + assert cp_size[0] * cp_size[1] == world_size + cp_comm_sub_ranks = [range(i*cp_size[0], (i+1)*cp_size[0]) for i in range(cp_size[1])] + cp_comm_sub_ranks += [range(i, world_size, cp_size[0]) for i in range(cp_size[0])] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + else: + assert cp_size == world_size if dtype == "fp8": fp8_recipe = DelayedScaling(fp8_dpa=True) @@ -167,13 +180,6 @@ def run_dpa_with_cp( else: bias = None - # make sure all GPU ranks have same inputs - for x in [q, k, v, dout] + ([] if bias is None else [bias]): - dist.broadcast(x, 0, group=cp_comm_group) - if qkv_format == "thd": - for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]: - dist.broadcast(x, 0, group=cp_comm_group) - # run core_attn without CP for x in [q, k, v]: x.requires_grad = True @@ -239,7 +245,7 @@ def run_dpa_with_cp( bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group( - cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + cp_comm_sub_groups if isinstance(cp_size, list) else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) if dtype == "fp8": diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index d6358d1062..4463f77844 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -36,8 +36,8 @@ } -def get_bash_arguments(**kwargs): - args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"] +def get_bash_arguments(num_gpus, **kwargs): + args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node="+str(num_gpus)] te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") args.append(script_path) @@ -51,20 +51,20 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", 
model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", ["a2a", "p2p"]]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] - if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if cp_comm_type == "a2a" and qkv_format == "thd": + if "a2a" in cp_comm_type and qkv_format == "thd": pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") - if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" @@ -72,10 +72,12 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( + num_gpus=4 if isinstance(cp_comm_type, list) else 2, dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention", + cp_size=[2, 2] if isinstance(cp_comm_type, list) else 2, cp_comm_type=cp_comm_type, ), check=True, @@ -106,7 +108,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", ["a2a", "p2p"]]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") @@ -120,7 +122,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if qkv_format == "thd" and cp_comm_type == "a2a": + if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": pytest.skip( @@ -138,9 +140,9 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("FP8 attention cannot work with sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation 
with KV all-gather does not support bias yet!") - if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" @@ -148,10 +150,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( + num_gpus=4 if isinstance(cp_comm_type, list) else 2, dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention", + cp_size=[2, 2] if isinstance(cp_comm_type, list) else 2, cp_comm_type=cp_comm_type, ), check=True, From 90c0bb8bbe2a575a35961836f244e3a1262129fb Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Tue, 24 Sep 2024 20:28:08 -0700 Subject: [PATCH 13/19] fix cp_comm_type in unit test Signed-off-by: Xiaowei Ren --- tests/pytorch/fused_attn/run_fused_attn_with_cp.py | 14 ++++++-------- .../pytorch/fused_attn/test_fused_attn_with_cp.py | 10 ++++------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 41e0284aad..5042dfa85e 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -17,7 +17,7 @@ def run_dpa_with_cp( - dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_size=2, cp_comm_type="p2p" + dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" ): """Test DotProductAttention module with context parallelism""" @@ -59,10 +59,10 @@ def run_dpa_with_cp( cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if isinstance(cp_size, list): - assert len(cp_size) == 2 - assert len(cp_size) == len(cp_comm_type) - assert cp_size[0] * cp_size[1] == world_size + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0 + cp_comm_type = ["a2a", "p2p"] + cp_size = [2, world_size // 2] cp_comm_sub_ranks = [range(i*cp_size[0], (i+1)*cp_size[0]) for i in range(cp_size[1])] cp_comm_sub_ranks += [range(i, world_size, cp_size[0]) for i in range(cp_size[0])] cp_comm_sub_groups = [] @@ -70,8 +70,6 @@ def run_dpa_with_cp( sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) - else: - assert cp_size == world_size if dtype == "fp8": fp8_recipe = DelayedScaling(fp8_dpa=True) @@ -245,7 +243,7 @@ def run_dpa_with_cp( bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group( - cp_comm_sub_groups if isinstance(cp_size, list) else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + cp_comm_sub_groups if isinstance(cp_comm_type, list) else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) if dtype == "fp8": diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 4463f77844..dfb710f8ce 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ 
b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -51,7 +51,7 @@ def get_bash_arguments(num_gpus, **kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", ["a2a", "p2p"]]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): @@ -72,12 +72,11 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( - num_gpus=4 if isinstance(cp_comm_type, list) else 2, + num_gpus=4 if cp_comm_type == "a2a+p2p" else 2, dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention", - cp_size=[2, 2] if isinstance(cp_comm_type, list) else 2, cp_comm_type=cp_comm_type, ), check=True, @@ -108,7 +107,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", ["a2a", "p2p"]]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") @@ -150,12 +149,11 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( - num_gpus=4 if isinstance(cp_comm_type, list) else 2, + num_gpus=4 if cp_comm_type == "a2a+p2p" else 2, dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention", - cp_size=[2, 2] if isinstance(cp_comm_type, list) else 2, cp_comm_type=cp_comm_type, ), check=True, From 8c6139d308029d049740574874b2405ff4b025e2 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 25 Sep 2024 00:28:44 -0700 Subject: [PATCH 14/19] bug fix and code cleaning Signed-off-by: Xiaowei Ren --- .../fused_attn/run_fused_attn_with_cp.py | 8 +- transformer_engine/pytorch/attention.py | 100 ++++++++---------- transformer_engine/pytorch/transformer.py | 12 +-- 3 files changed, 56 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 5042dfa85e..f028902c62 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -61,10 +61,8 @@ def run_dpa_with_cp( cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") if cp_comm_type == "a2a+p2p": assert world_size % 2 == 0 - cp_comm_type = ["a2a", "p2p"] - cp_size = [2, world_size // 2] - cp_comm_sub_ranks = [range(i*cp_size[0], (i+1)*cp_size[0]) for i in range(cp_size[1])] - cp_comm_sub_ranks += [range(i, world_size, cp_size[0]) for i in range(cp_size[0])] + cp_comm_sub_ranks = [range(i*2, (i+1)*2) for i in range(world_size//2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] cp_comm_sub_groups = [] for sub_ranks in cp_comm_sub_ranks: sub_group = dist.new_group(sub_ranks, backend="nccl") @@ -243,7 +241,7 @@ def run_dpa_with_cp( 
bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group( - cp_comm_sub_groups if isinstance(cp_comm_type, list) else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type ) if dtype == "fp8": diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ffd3fcf727..cf8736796a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1407,6 +1407,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): Attention implementation with context parallelism. Exchange KV between CP ranks with P2P in ring topology. Split attention compute into multiple steps, and overlap current-step compute with next-step communication. + + This implementation also supports hierarchical CP, which parallelizes attention + heads in low-level CP grpups and parallizes seqeunce dimension in high-levle CP + groups. Refer details in `LongVILA `_ and + `USP `_. """ @staticmethod @@ -2167,15 +2172,12 @@ def forward( out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) - elif not use_fused_attention: - out = out.view(-1, *out.shape[-2:]) + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) @@ -3255,13 +3257,10 @@ def forward( torch.cuda.current_stream().wait_stream(cp_stream) - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: - out = out.view(-1, *out.shape[-2:]) + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) ctx.save_for_backward( q, @@ -3744,13 +3743,12 @@ def forward( out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) if fp8: if fp8_meta["recipe"].fp8_mha: @@ -3890,10 +3888,6 @@ def backward(ctx, dout): fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True @@ -4053,15 +4047,15 @@ 
def attn_forward_func_with_cp( Attention implementation with context parallelism. """ - if isinstance(cp_comm_type, list): - assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level of CP groups!" - assert len(cp_comm_type) == 2 and len(cp_group) == 2, "Current implementation only supports two-level of CP groups!" + if cp_comm_type == "a2a+p2p": + assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level CP groups!" + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] - cp_comm_type = cp_comm_type[1] + cp_comm_type = "p2p" elif get_distributed_world_size(cp_group[1]) == 1: cp_group = cp_group[0] - cp_comm_type = cp_comm_type[0] + cp_comm_type = "a2a" else: assert isinstance(cp_group, dist_group_type), "Unsupported process group for non-hierarchical CP implementation!" @@ -4119,7 +4113,7 @@ def attn_forward_func_with_cp( use_fused_attention, ] - if cp_comm_type in ["p2p", ["a2a", "p2p"]]: + if "p2p" in cp_comm_type: args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -4942,7 +4936,7 @@ def forward( cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -6756,7 +6750,7 @@ def forward( cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -7035,18 +7029,18 @@ class DotProductAttention(TransformerEngineBaseModule): compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. - cp_comm_type : Union[str, List[str]], default = `p2p` + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a" or ["a2a", "p2p"]. + Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. - ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a - to QKV across each CP sub-group (e.g., via NVLink), then exchanging - KV with p2p between sub-groups (e.g., via IBLink). + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). 
""" def __init__( @@ -7067,7 +7061,7 @@ def __init__( cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, ) -> None: super().__init__() @@ -7222,7 +7216,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -7236,18 +7230,18 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : Union[str, List[str]], default = `p2p` + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a" or ["a2a", "p2p"]. + Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. - ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a - to QKV across each CP sub-group (e.g., via NVLink), then exchanging - KV with p2p between sub-groups (e.g., via IBLink). + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -8261,7 +8255,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -8275,18 +8269,18 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : Union[str, List[str]], default = `p2p` + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a", ["a2a", "p2p"]. + Can be "p2p" or "all_gather" or "a2a", "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. - ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a - to QKV across each CP sub-group (e.g., via NVLink), then exchanging - KV with p2p between sub-groups (e.g., via IBLink). + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. 
for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index e4bff05476..3df00b66a8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -487,7 +487,7 @@ def set_context_parallel_group( cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream, - cp_comm_type: Union[str, List[str]] = "p2p", + cp_comm_type: str = "p2p", ) -> None: """ Set the context parallel attributes for the given @@ -501,18 +501,18 @@ def set_context_parallel_group( list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : Union[str, List[str]], default = `p2p` + cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a", or ["a2a", "p2p"]. + Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p". "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be overlapped with attention compute. "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not async, and cannot be overlapped. "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get full sequence of QKV. - ["a2a", "p2p"]: hierarchical CP implementation. First applying a2a - to QKV across each CP sub-group (e.g., via NVLink), then exchanging - KV with p2p between sub-groups (e.g., via IBLink). + "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): From 1db4e8b755713e82ee3ba4cba8e6327f4b406a00 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 25 Sep 2024 01:44:35 -0700 Subject: [PATCH 15/19] minor change Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index cf8736796a..36bddde370 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4113,7 +4113,7 @@ def attn_forward_func_with_cp( use_fused_attention, ] - if "p2p" in cp_comm_type: + if cp_comm_type in ["p2p", "a2a+p2p"]: args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": From 1655edceb3badfa287979b841eef8ed208d9c3e3 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Wed, 25 Sep 2024 10:00:36 -0700 Subject: [PATCH 16/19] an assert info change Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 36bddde370..49ecc4fc21 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4057,7 +4057,7 @@ def attn_forward_func_with_cp( cp_group = cp_group[0] cp_comm_type = "a2a" else: - assert isinstance(cp_group, dist_group_type), "Unsupported process group for non-hierarchical CP implementation!" + assert isinstance(cp_group, dist_group_type), f"Unsupported process group for CP communication type {cp_comm_type}!" 
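
Note (not part of the patches): the "a2a+p2p" path above expects cp_group to be a two-element list, with the a2a (e.g., intra-node/NVLink) group first and the p2p (e.g., inter-node/IB) group second. A minimal caller-side sketch of how such a list could be built, loosely following the test setup later in this series; the helper name build_hierarchical_cp_groups and the a2a_size argument are hypothetical, and group sizes are assumptions for illustration only.

    import torch.distributed as dist

    def build_hierarchical_cp_groups(world_size, rank, a2a_size):
        """Hypothetical helper: return [a2a_group, p2p_group] for this rank."""
        assert world_size % a2a_size == 0
        # consecutive ranks share the a2a sub-group (e.g., one node over NVLink)
        sub_ranks = [range(i * a2a_size, (i + 1) * a2a_size) for i in range(world_size // a2a_size)]
        # same-offset ranks across nodes share the p2p sub-group (e.g., over IB)
        sub_ranks += [range(i, world_size, a2a_size) for i in range(a2a_size)]
        groups = []
        for ranks in sub_ranks:
            # new_group must be called collectively on every rank for every group
            group = dist.new_group(list(ranks), backend="nccl")
            if rank in ranks:
                groups.append(group)
        # groups[0] is this rank's a2a group, groups[1] its p2p group
        return groups

    # Hypothetical usage with the API changed in this series:
    # core_attn.set_context_parallel_group(
    #     build_hierarchical_cp_groups(world_size, rank, a2a_size=2),
    #     cp_comm_ranks, torch.cuda.Stream(), cp_comm_type="a2a+p2p",
    # )
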
assert qkv_format in [ "bshd", From edbe8983ab38d0944be104d08fa20547210febcf Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Thu, 26 Sep 2024 01:14:15 -0700 Subject: [PATCH 17/19] dout shape fix Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 46 ++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 49ecc4fc21..e9d1213b72 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2172,12 +2172,15 @@ def forward( out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) - if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + elif not use_fused_attention: + out = out.view(-1, *out.shape[-2:]) if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) @@ -2377,6 +2380,9 @@ def backward(ctx, dout): fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if cp_size_a2a > 1: + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True @@ -3257,10 +3263,13 @@ def forward( torch.cuda.current_stream().wait_stream(cp_stream) - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) + if use_fused_attention: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + else: + out = out.view(-1, *out.shape[-2:]) ctx.save_for_backward( q, @@ -3743,12 +3752,13 @@ def forward( out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) - if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) if fp8: if fp8_meta["recipe"].fp8_mha: @@ -3888,6 +3898,10 @@ def backward(ctx, dout): fused_attn_dqkv_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True From b096051efff87341a181e0d985c1bb5cb81cba48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 07:47:22 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn/run_fused_attn_with_cp.py | 7 ++-- .../fused_attn/test_fused_attn_with_cp.py | 2 +- transformer_engine/pytorch/attention.py | 33 +++++++++++++++---- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index f028902c62..ce185041ca 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -61,7 +61,7 @@ def run_dpa_with_cp( cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") if cp_comm_type == "a2a+p2p": assert world_size % 2 == 0 - cp_comm_sub_ranks = [range(i*2, (i+1)*2) for i in range(world_size//2)] + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] cp_comm_sub_groups = [] for sub_ranks in cp_comm_sub_ranks: @@ -241,7 +241,10 @@ def run_dpa_with_cp( bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, ) if dtype == "fp8": diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index dfb710f8ce..572a671b71 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -37,7 +37,7 @@ def get_bash_arguments(num_gpus, **kwargs): - args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node="+str(num_gpus)] + args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus)] te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") args.append(script_path) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e9d1213b72..3366114025 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1445,8 +1445,13 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) if isinstance(cp_group, list): - assert qkv_format != "thd", f"{qkv_format} format is not supported with hierarchical CP implementation yet!" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert attn_bias_type == "no_bias", ( + f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" + " yet!" 
+ ) cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) @@ -2385,7 +2390,13 @@ def backward(ctx, dout): dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) out, dout = flash_attn_a2a_communicate( - [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True + [out, dout], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: dout = cast_from_fp8( @@ -2992,7 +3003,13 @@ def backward(ctx, dout): if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) dq, dk, dv = flash_attn_a2a_communicate( - [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, False + [dq, dk, dv], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + False, ) if ctx.qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] @@ -4062,7 +4079,9 @@ def attn_forward_func_with_cp( """ if cp_comm_type == "a2a+p2p": - assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level CP groups!" + assert isinstance( + cp_group, list + ), "Hierarchical CP implementation needs multi-level CP groups!" assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] @@ -4071,7 +4090,9 @@ def attn_forward_func_with_cp( cp_group = cp_group[0] cp_comm_type = "a2a" else: - assert isinstance(cp_group, dist_group_type), f"Unsupported process group for CP communication type {cp_comm_type}!" + assert isinstance( + cp_group, dist_group_type + ), f"Unsupported process group for CP communication type {cp_comm_type}!" assert qkv_format in [ "bshd", From 67c590d81c2ea33465c550240e9eefb3ce0955c9 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren Date: Fri, 27 Sep 2024 17:54:13 -0700 Subject: [PATCH 19/19] move function definitions to the front of the first call Signed-off-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 204 ++++++++++++------------ 1 file changed, 102 insertions(+), 102 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 86c0bb71db..aef6a2018e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1402,6 +1402,108 @@ def get_cu_seqlens_on_cp_rank( return cu_seqlens_on_cp_rank +@torch.compile +def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks + before or after CP communications (e.g., all-gather, all-to-all). This function is to compute + sequence chunk ids for reordering. 
+ """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + if to_contiguous: + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + else: + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +@torch.compile +def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): + """Reorder sequence chunk for A2A communication.""" + if before_attn: + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + else: + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a( + 
x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks @@ -3060,26 +3162,6 @@ def backward(ctx, dout): ) -@torch.compile -def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): - """ - Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. - To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks - before or after CP communications (e.g., all-gather, all-to-all). This function is to compute - sequence chunk ids for reordering. - """ - chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) - if to_contiguous: - for rank in range(cp_size): - chunk_ids[rank] = 2 * rank - chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 - else: - for rank in range(cp_size): - chunk_ids[2 * rank] = rank - chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 - return chunk_ids - - def get_kv_seq_info_after_all_gather( local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): @@ -3505,88 +3587,6 @@ def backward(ctx, dout): ) -@torch.compile -def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): - """Reorder sequence chunk for A2A communication.""" - if before_attn: - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] - x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) - # reorder the sequence chunks - x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) - else: - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.movedim(seq_dim, 0).contiguous() - # reorder the sequence chunks - x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] - x = x.view(cp_size, 2, *x.shape[1:]) - return x - - -def flash_attn_a2a_communicate( - a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], - chunk_ids_for_a2a: torch.Tensor, - seq_dim: int, - cp_size: int, - cp_group: dist_group_type, - cp_stream: torch.cuda.Stream, - before_attn: bool, -) -> Union[torch.Tensor, List[torch.Tensor]]: - """A2A communication for context parallelism.""" - a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs - a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) - if before_attn: - for i in range(len(a2a_inputs) + 2): - if 0 < i < len(a2a_inputs) + 1: - a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) - a2a_reqs[i - 1] = torch.distributed.all_to_all_single( - a2a_outputs[i - 1], a2a_inputs[i - 
1], group=cp_group, async_op=True - ) - if i > 1: - with torch.cuda.stream(cp_stream): - a2a_reqs[i - 2].wait() - x = a2a_outputs[i - 2] - # reorder the sequence chunks - x = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn - ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] - a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - if i < len(a2a_inputs): - x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() - else: - for i in range(len(a2a_inputs) + 2): - if 0 < i < len(a2a_inputs) + 1: - a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) - a2a_reqs[i - 1] = torch.distributed.all_to_all_single( - a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True - ) - if i < len(a2a_inputs): - x = a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) - # reorder the sequence chunks - a2a_inputs[i] = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn - ) - if i > 1: - with torch.cuda.stream(cp_stream): - a2a_reqs[i - 2].wait() - x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) - torch.cuda.current_stream().wait_stream(cp_stream) - return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs - - class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): """ Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO.
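
Note (not part of the patches): to make the chunk reordering above concrete, the following standalone sketch reproduces the permutations that get_seq_chunk_ids_for_reordering computes for cp_size = 2; the local helper mirrors the function body shown in the diff, minus @torch.compile and the device argument.

    import torch

    def seq_chunk_ids(cp_size, to_contiguous):
        chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32)
        if to_contiguous:
            for rank in range(cp_size):
                chunk_ids[rank] = 2 * rank
                chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
        else:
            for rank in range(cp_size):
                chunk_ids[2 * rank] = rank
                chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
        return chunk_ids

    ids_to_lb = seq_chunk_ids(2, to_contiguous=False)      # tensor([0, 3, 1, 2])
    # gathering contiguous chunks [0, 1, 2, 3] with these ids gives the
    # load-balanced layout [0, 3, 1, 2]: rank 0 holds chunks (0, 3), rank 1 holds (1, 2)
    ids_to_contig = seq_chunk_ids(2, to_contiguous=True)    # tensor([0, 2, 3, 1])
    # gathering the load-balanced layout with these ids restores [0, 1, 2, 3];
    # the two orderings are inverse permutations of each other
    assert torch.equal(ids_to_lb[ids_to_contig.long()], torch.arange(4, dtype=torch.int32))
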