From e36273a43f1e4fc2e8775c3802dac9b53106a5d5 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:56:45 -0700 Subject: [PATCH 1/5] WIP: fix get_swa_mask for padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 24 +-- transformer_engine/pytorch/attention.py | 161 +++++++++++--------- 2 files changed, 100 insertions(+), 85 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4b4eecbf39..72694d05a2 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -531,18 +531,18 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "swa_6_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"), + "swa_6_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5f8357a01b..7d394a6970 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1030,9 +1030,24 @@ def get_swa_mask( attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, ) -> torch.Tensor: """ - Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. - For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner, - and for other mask types, the bottom right corner. 
+    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for
+    the shapes of `attention_mask` given an `attn_mask_type` are the same as in
+    DotProductAttention. For "`causal`" and "`padding_causal`" mask types, the sliding window
+    diagonal is aligned to the top left corner of the softmax matrix; for other mask types, to
+    the bottom right corner. Note that when padding is applied, the bottom right corner is that
+    of each sequence i's [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix, not of the
+    [max_seqlen_q, max_seqlen_kv] matrix.::
+
+        attn_mask_type              output shape                                  diagonal alignment
+        ---------------------------------------------------------------------------------------------
+        no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
+        causal                      [1, 1, max_seqlen_q, max_seqlen_kv]           top left
+        causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
+        padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
+                                                                                  actual sequence lengths
+        padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv]  top left
+        padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
+                                                                                  actual sequence lengths
+        arbitrary                   same as attention_mask                        bottom right
 
     Parameters
     ----------
@@ -1056,27 +1071,73 @@ def get_swa_mask(
     Returns
     ----------
+    attn_mask_type: str
+        The new attention mask type, "arbitrary".
     attention_mask: torch.Tensor
-        Combined `attention_mask` (input) and sliding window attention mask.
-        The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
-        else, the same shape as input `attention_mask`.
+        Result of combining the input mask and the sliding window mask.
+    actual_seqlens_q: Optional[torch.Tensor]
+        For padding masks, the actual sequence lengths for queries, in shape [batch_size];
+        `None` for other masks.
+    actual_seqlens_kv: Optional[torch.Tensor]
+        For padding masks, the actual sequence lengths for keys and values, in shape [batch_size];
+        `None` for other masks.
""" - mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda") - if attn_mask_type in ["causal"]: - left = window_size[0] if window_size[0] != -1 else max_seqlen_q - right = window_size[1] if window_size[1] != -1 else max_seqlen_q - mask_upper = torch.triu(mask, diagonal=-left) - mask_lower = torch.tril(mask_upper, diagonal=right) + # perform basic checks + if window_size is None: + window_size = (-1, -1) + if "causal" in attn_mask_type: + window_size = (window_size[0], 0) + window_size = ( + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1] + ) + + # apply padding mask + actual_seqlens_q = None + actual_seqlens_kv = None + if "padding" in attn_mask_type: + if max_seqlen_q == max_seqlen_kv: + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + batch_size = attention_mask.shape[0] + m = attention_mask.logical_not() + actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) + actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + + # apply SWA mask + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]: + swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] + swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] + elif attn_mask_type in ["causal", "padding_causal"]: + swa_left = mask - window_size[0] + swa_right = mask + window_size[1] + elif attn_mask_type in ["padding", "padding_causal_bottom_right"]: + swa_left = mask.expand( + batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0]).view(batch_size, 1, 1, 1) + swa_right = mask.expand( + batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1]).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not(torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right <0, 1, 0)) + if attention_mask is not None: + attention_mask = torch.logical_or(swa_mask, attention_mask) else: - left = window_size[0] if window_size[0] != -1 else max_seqlen_kv - right = window_size[1] if window_size[1] != -1 else max_seqlen_kv - mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left) - mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right) + attention_mask = swa_mask + + # change type to arbitrary attn_mask_type = "arbitrary" - mask = mask_lower.logical_not() - if attention_mask is not None: - mask = torch.logical_and(attention_mask, mask) - return attn_mask_type, mask + + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv @torch.no_grad() @@ -4731,6 +4792,7 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + window_size: Optional[Tuple[int, int]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -4750,53 +4812,10 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type: - if self.attention_type == "self": - assert attention_mask.shape == ( - 
batch_size, - 1, - 1, - max_seqlen_q, - ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - assert ( - len(attention_mask) == 2 - and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) - and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) - ), ( - "attention_mask should be a tuple of two tensors with shapes " - "[b, 1, 1, sq] and [b, 1, 1, skv]!" - ) - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - mask = attention_mask.squeeze(1).logical_not() - actual_seqlens_q = mask[:, :, 0].sum(dim=1) - actual_seqlens_kv = mask[:, 0, :].sum(dim=1) - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if attn_mask_type == "padding_causal": - attention_mask = torch.logical_or( - torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), - attention_mask, - ) - if attn_mask_type == "padding_causal_bottom_right": - attention_mask = torch.logical_or( - torch.where( - mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - < 0, - 1, - 0, - ), - attention_mask, - ) + + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask( + window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -8205,12 +8224,6 @@ def forward( ) if use_unfused_attention: - if window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ): - attn_mask_type, attention_mask = get_swa_mask( - window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask - ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -8222,6 +8235,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -8235,6 +8249,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, From 4b1999698af7c72dcea18a2f1f814b42716c49af Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 18 Oct 2024 23:02:03 -0700 Subject: [PATCH 2/5] fix mask type setting Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7d394a6970..b67257b258 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1082,7 +1082,9 @@ def get_swa_mask( For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. For other masks, `None`. 
""" - # perform basic checks + # perform basic checks and change mask type + if window_size is not None and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + attn_mask_type = "arbitrary" if window_size is None: window_size = (-1, -1) if "causal" in attn_mask_type: @@ -1134,9 +1136,6 @@ def get_swa_mask( else: attention_mask = swa_mask - # change type to arbitrary - attn_mask_type = "arbitrary" - return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv From 7f08d477e82fed1b3ba477568bfe52a45c12a3fb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 18 Oct 2024 23:16:07 -0700 Subject: [PATCH 3/5] fix the order of checking valid swa and changing mask type Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b67257b258..04cb4bd220 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1082,9 +1082,8 @@ def get_swa_mask( For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. For other masks, `None`. """ - # perform basic checks and change mask type - if window_size is not None and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - attn_mask_type = "arbitrary" + # perform basic checks + change_type = window_size is not None and (window_size[0] != -1 or window_size[1] not in [-1, 0]) if window_size is None: window_size = (-1, -1) if "causal" in attn_mask_type: @@ -1136,6 +1135,10 @@ def get_swa_mask( else: attention_mask = swa_mask + # change mask type + if change_type: + attn_mask_type = "arbitrary" + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv From 3a22f934ca20a08cbe621f3c7c2bc1054d975b74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 20:33:41 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/fused_attn/test_fused_attn.py | 8 ++++-- transformer_engine/pytorch/attention.py | 28 +++++++++++---------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 72694d05a2..ea136cfb76 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -541,8 +541,12 @@ def test_dpa_bias_shapes(dtype, model_configs, model): "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"), - "swa_6_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), + "swa_6_0": ModelConfig( + 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_1": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 04cb4bd220..9a66e534bd 100644 --- a/transformer_engine/pytorch/attention.py +++ 
b/transformer_engine/pytorch/attention.py @@ -1083,14 +1083,16 @@ def get_swa_mask( For other masks, `None`. """ # perform basic checks - change_type = window_size is not None and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + change_type = window_size is not None and ( + window_size[0] != -1 or window_size[1] not in [-1, 0] + ) if window_size is None: window_size = (-1, -1) if "causal" in attn_mask_type: window_size = (window_size[0], 0) window_size = ( - max_seqlen_kv if window_size[0] == -1 else window_size[0], - max_seqlen_q if window_size[1] == -1 else window_size[1] + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1], ) # apply padding mask @@ -1113,9 +1115,7 @@ def get_swa_mask( # apply SWA mask mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]: swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] @@ -1123,13 +1123,15 @@ def get_swa_mask( swa_left = mask - window_size[0] swa_right = mask + window_size[1] elif attn_mask_type in ["padding", "padding_causal_bottom_right"]: - swa_left = mask.expand( - batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( - actual_seqlens_kv - actual_seqlens_q - window_size[0]).view(batch_size, 1, 1, 1) - swa_right = mask.expand( - batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( - actual_seqlens_kv - actual_seqlens_q + window_size[1]).view(batch_size, 1, 1, 1) - swa_mask = torch.logical_not(torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right <0, 1, 0)) + swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0] + ).view(batch_size, 1, 1, 1) + swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1] + ).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not( + torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) + ) if attention_mask is not None: attention_mask = torch.logical_or(swa_mask, attention_mask) else: From 5f5c5c35e09bed4c3e4fcbb69e48823df53b89ef Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:50:47 -0700 Subject: [PATCH 5/5] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9a66e534bd..3dba22ae24 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1107,7 +1107,6 @@ def get_swa_mask( attention_mask = torch.logical_or( attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] ) - batch_size = attention_mask.shape[0] m = attention_mask.logical_not() actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) @@ -1116,6 +1115,8 @@ def get_swa_mask( mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + swa_left = None + 
swa_right = None if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]: swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] @@ -1123,6 +1124,7 @@ def get_swa_mask( swa_left = mask - window_size[0] swa_right = mask + window_size[1] elif attn_mask_type in ["padding", "padding_causal_bottom_right"]: + batch_size = attention_mask.shape[0] swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( actual_seqlens_kv - actual_seqlens_q - window_size[0] ).view(batch_size, 1, 1, 1)
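
To make the diagonal alignment and window bookkeeping in these patches concrete, below is a
minimal standalone sketch, not part of the patch series itself: it re-derives the
swa_left/swa_right construction from PATCH 1/5 for the non-padding mask types on CPU (the patched
function allocates on "cuda"), collapses the broadcast [1, 1, max_seqlen_q, max_seqlen_kv] shape
to [max_seqlen_q, max_seqlen_kv], and cross-checks the result against a brute-force window
definition. The padding variants differ only in replacing the fixed
max_seqlen_kv - max_seqlen_q diagonal shift with the per-sample
actual_seqlens_kv[i] - actual_seqlens_q[i]. The name swa_mask_sketch is invented for this
illustration.

import torch

def swa_mask_sketch(window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type):
    """Sliding-window exclusion mask; True means masked out (TE's convention)."""
    if window_size is None:
        window_size = (-1, -1)
    if "causal" in attn_mask_type:
        window_size = (window_size[0], 0)  # causal masks attend to no future keys
    window_size = (
        max_seqlen_kv if window_size[0] == -1 else window_size[0],
        max_seqlen_q if window_size[1] == -1 else window_size[1],
    )
    # mask[i, j] = i - j: signed distance of entry (i, j) from the top-left diagonal.
    mask = torch.arange(max_seqlen_q, dtype=torch.int32).view(max_seqlen_q, 1) - torch.arange(
        max_seqlen_kv, dtype=torch.int32
    ).view(1, max_seqlen_kv)
    if attn_mask_type in ("causal", "padding_causal"):
        shift = 0  # window diagonal anchored at the top left corner
    else:
        shift = max_seqlen_kv - max_seqlen_q  # anchored at the bottom right corner
    swa_left = mask + shift - window_size[0]
    swa_right = mask + shift + window_size[1]
    # (i, j) lies inside the window iff swa_left <= 0 <= swa_right; this is
    # equivalent to the patch's logical_not(where(...) - where(...)) expression.
    return ~((swa_left <= 0) & (swa_right >= 0))

# Brute-force cross-check for a bottom-right-aligned causal window.
sq, skv, win = 4, 6, (2, 0)
m = swa_mask_sketch(win, sq, skv, "causal_bottom_right")
for i in range(sq):
    for j in range(skv):
        lo = i + (skv - sq) - win[0]  # leftmost key index row i may attend to
        hi = i + (skv - sq) + win[1]  # rightmost key index row i may attend to
        assert bool(m[i, j]) == (not lo <= j <= hi)

Keeping True = "masked out" is what lets the patched get_swa_mask combine the window mask with the
padding mask through a plain torch.logical_or: the final exclusion set is the union of the two.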