[PyTorch] Fix get_swa_mask() for padding masks #1281

Open · wants to merge 7 commits into base: main

Changes from all commits
28 changes: 16 additions & 12 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_0": ModelConfig(
4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_1": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}


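The new swa_4_* / swa_5_* / swa_6_* entries add padding-mask coverage to the SWA tests. As a minimal sketch (not part of this diff) of the masks such configs feed in, assuming the shape requirements asserted in attention.py ([b, 1, 1, sq] for self-attention, a tuple of [b, 1, 1, sq] and [b, 1, 1, skv] for cross-attention) and placeholder sequence lengths:

import torch

# Placeholder padding masks for a config like "swa_4_1" (b=2, sq=2048, skv=4096).
# Transformer Engine convention: True marks positions that are masked out.
b, sq, skv = 2, 2048, 4096
seqlens_q = torch.randint(1, sq + 1, (b,))      # hypothetical actual lengths per sequence
seqlens_kv = torch.randint(1, skv + 1, (b,))

pad_q = torch.arange(sq).view(1, 1, 1, sq) >= seqlens_q.view(b, 1, 1, 1)       # [b, 1, 1, sq]
pad_kv = torch.arange(skv).view(1, 1, 1, skv) >= seqlens_kv.view(b, 1, 1, 1)   # [b, 1, 1, skv]

# Self-attention (sq == skv) takes a single [b, 1, 1, sq] tensor;
# cross-attention (sq != skv) takes the tuple form.
attention_mask = (pad_q, pad_kv)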
169 changes: 95 additions & 74 deletions transformer_engine/pytorch/attention.py
@@ -1030,9 +1030,24 @@ def get_swa_mask(
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> torch.Tensor:
"""
Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
and for other mask types, the bottom right corner.
Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for the
shapes of `attention_mask` given an `attn_mask_type` are the same as in DotProductAttention.
For "`causal`" and "`padding_causal`" mask types, the sliding window diagonal is aligned to the
top left corner of the softmax matrix; for others, the bottom right corner. Note that when padding
is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix,
for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.::

attn_mask_type              output shape                                  diagonal alignment
---------------------------------------------------------------------------------------------
no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
causal                      [1, 1, max_seqlen_q, max_seqlen_kv]           top left
causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
                                                                          actual sequence lengths
padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv]  top left
padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
                                                                          actual sequence lengths
arbitrary                   same as attention_mask                        bottom right

Parameters
----------
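To make the "based on actual sequence lengths" rows in the table above concrete: with padding, the window diagonal for batch element i is anchored at column offset actual_seqlen_kv[i] - actual_seqlen_q[i], so row i may attend to columns in [i + offset - left, i + offset + right]. A small standalone illustration (assumed values, not code from this PR):

import torch

max_sq, max_skv = 4, 6
left, right = 1, 0                       # sliding window (left, right)
actual_sq, actual_skv = 3, 5             # hypothetical actual lengths for one batch element

rel = torch.arange(max_sq).view(max_sq, 1) - torch.arange(max_skv).view(1, max_skv)
offset = actual_skv - actual_sq          # bottom-right alignment follows the actual lengths

# True where attention is allowed: columns j in [i + offset - left, i + offset + right]
allowed = (rel + offset - left <= 0) & (rel + offset + right >= 0)
print(allowed.int())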
@@ -1056,27 +1071,79 @@ def get_swa_mask(

Returns
----------
attn_mask_type: str, default = `no_mask`
New attention mask type "arbitrary".
attention_mask: torch.Tensor
Combined `attention_mask` (input) and sliding window attention mask.
The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
else, the same shape as input `attention_mask`.
Result after combining input mask and sliding window mask.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
For other masks, `None`.
"""
mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
if attn_mask_type in ["causal"]:
left = window_size[0] if window_size[0] != -1 else max_seqlen_q
right = window_size[1] if window_size[1] != -1 else max_seqlen_q
mask_upper = torch.triu(mask, diagonal=-left)
mask_lower = torch.tril(mask_upper, diagonal=right)
else:
left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
attn_mask_type = "arbitrary"
mask = mask_lower.logical_not()
# perform basic checks
change_type = window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
)
if window_size is None:
window_size = (-1, -1)
if "causal" in attn_mask_type:
window_size = (window_size[0], 0)
window_size = (
max_seqlen_kv if window_size[0] == -1 else window_size[0],
max_seqlen_q if window_size[1] == -1 else window_size[1],
)

# apply padding mask
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
if max_seqlen_q == max_seqlen_kv:
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
m = attention_mask.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)

# apply SWA mask
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
swa_left = None
swa_right = None
if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]:
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
elif attn_mask_type in ["causal", "padding_causal"]:
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
elif attn_mask_type in ["padding", "padding_causal_bottom_right"]:
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
).view(batch_size, 1, 1, 1)
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
if attention_mask is not None:
mask = torch.logical_and(attention_mask, mask)
return attn_mask_type, mask
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
attention_mask = swa_mask

# change mask type
if change_type:
attn_mask_type = "arbitrary"

return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
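With this change the helper returns four values instead of two. A hedged usage sketch of the new signature (the argument order mirrors the in-tree call in forward() further down; padding_mask is a placeholder of shape [batch_size, 1, 1, max_seqlen_q]):

from transformer_engine.pytorch.attention import get_swa_mask

attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
    (128, 0),            # window_size: sliding window (left, right)
    2048,                # max_seqlen_q
    2048,                # max_seqlen_kv
    "padding_causal",    # attn_mask_type
    padding_mask,        # [batch_size, 1, 1, max_seqlen_q] for self-attention
)
# attn_mask_type comes back as "arbitrary"; actual_seqlens_* are [batch_size] tensors
# for padding masks and None otherwise.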


@torch.no_grad()
@@ -4731,6 +4798,7 @@ def forward(
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
@@ -4750,53 +4818,10 @@ def forward(
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type:
if self.attention_type == "self":
assert attention_mask.shape == (
batch_size,
1,
1,
max_seqlen_q,
), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
assert (
len(attention_mask) == 2
and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
), (
"attention_mask should be a tuple of two tensors with shapes "
"[b, 1, 1, sq] and [b, 1, 1, skv]!"
)
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
mask = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = mask[:, :, 0].sum(dim=1)
actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if attn_mask_type == "padding_causal":
attention_mask = torch.logical_or(
torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
attention_mask,
)
if attn_mask_type == "padding_causal_bottom_right":
attention_mask = torch.logical_or(
torch.where(
mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+ (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
< 0,
1,
0,
),
attention_mask,
)

attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
)

batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
@@ -8205,12 +8230,6 @@ def forward(
)

if use_unfused_attention:
if window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
):
attn_mask_type, attention_mask = get_swa_mask(
window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
@@ -8222,6 +8241,7 @@
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
@@ -8235,6 +8255,7 @@
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
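For completeness, a hedged end-to-end sketch at the module level, loosely mirroring the swa_4_0 test config (b=4, h=16, d=64, sq=2048). The backend-selection environment variables and exact constructor/forward signatures are assumptions and may differ across Transformer Engine versions:

import os
import torch
from transformer_engine.pytorch import DotProductAttention

# Assumption: force the unfused backend so the get_swa_mask() path handles the window.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

b, h, d, sq = 4, 16, 64, 2048
dpa = DotProductAttention(num_attention_heads=h, kv_channels=d, attn_mask_type="padding")

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")

seqlens = torch.randint(1, sq + 1, (b,), device="cuda")
attention_mask = (
    torch.arange(sq, device="cuda").view(1, 1, 1, sq) >= seqlens.view(b, 1, 1, 1)
)  # [b, 1, 1, sq], True = masked out

out = dpa(q, k, v, attention_mask=attention_mask, window_size=(128, 0))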