Skip to content

Commit

Permalink
improve get_attention_backend logic
Browse files Browse the repository at this point in the history
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
  • Loading branch information
cyanguwa committed Sep 27, 2024
1 parent 209b8e5 commit 4b0ce23
Showing 1 changed file with 55 additions and 34 deletions.
89 changes: 55 additions & 34 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4669,6 +4669,12 @@ def get_qkv_layout(
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
q: torch.Tensor
Query tensor, which may be different from input `q`.
k: torch.Tensor
Key tensor, which may be different from input `k`.
v: torch.Tensor
Value tensor, which may be different from input `v`.
"""

check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
Expand All @@ -4677,66 +4683,81 @@ def get_qkv_layout(
def run_iteratively(q, k, v):
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]

stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
sv / v.shape[-1] for sv in v.stride()[:-1]
)

shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]

last_two_dims_size = q.shape[-1] * q.shape[-2]
check_3hd_offsets = all(
x.storage_offset() == i * last_two_dims_size for i, x in enumerate([q, k, v])
)
last_dim_size = q.shape[-1]
check_last_dim_offsets_qkv = all(
i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
check_h3d_offsets = all(
x.storage_offset() == i * last_dim_size for i, x in enumerate([q, k, v])
)

all_dims_size = [np.prod(x.shape) for x in [q, k]]
offset = all_dims_size[0] if check_ptrs_qkv else 0
last_two_dims_size = k.shape[-1] * k.shape[-2]
check_2hd_offsets = all(
x.storage_offset() == (offset + i * last_two_dims_size) for i, x in enumerate([k, v])
)
last_dim_size = k.shape[-1]
check_last_dim_offsets_kv = all(
i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
check_h2d_offsets = all(
x.storage_offset() == (offset + i * last_dim_size) for i, x in enumerate([k, v])
)

last_two_dims_size = q.shape[-1] * q.shape[-2]
check_last_two_dims_offsets_qkv = all(
i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
check_hd_offsets_qkv = all(
x.storage_offset() == sum(all_dims_size[:i]) for i, x in enumerate([q, k, v])
) if check_ptrs_qkv else all(
x.storage_offset() == 0 for i, x in enumerate([q, k, v])
)
last_two_dims_size = k.shape[-1] * k.shape[-2]
check_last_two_dims_offsets_kv = all(
i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
check_hd_offsets_qk = all(
x.storage_offset() == sum(all_dims_size[:i]) for i, x in enumerate([q, k])
) if not check_ptrs_qkv and check_ptrs_qk else all(
x.storage_offset() == 0 for i, x in enumerate([q, k])
)
check_hd_offsets_kv = all(
x.storage_offset() == sum(all_dims_size[1:i+1]) for i, x in enumerate([k, v])
) if not check_ptrs_qkv and check_ptrs_kv else all(
x.storage_offset() == 0 for i, x in enumerate([k, v])
)

if (
check_ptrs_qkv
and check_strides_qkv
and check_shapes_qkv
and check_last_two_dims_offsets_qkv
and not check_last_dim_offsets_qkv
):
if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
# sb3hd, bs3hd, t3hd
# one chunk of memory with q, k, v interleaved at dim=-3
qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
elif (
check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
):
elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
# sbh3d, bsh3d, th3d
# one chunk of memory with q, k, v interleaved at dim=-2
qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
elif (
check_ptrs_kv
and check_strides_kv
and check_shapes_kv
and check_last_two_dims_offsets_kv
and not check_last_dim_offsets_kv
):
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
# sbhd_sb2hd, bshd_bs2hd, thd_t2hd
# two chunks of memory (q and kv) with k, v interleaved at dim=-3 in kv
# q and kv may be disjoint or consecutive in memory
# when consecutive, they may share the same data_ptr()
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
# sbhd_sbh2d, bshd_bsh2d, thd_th2d
# two chunks of memory (q and kv) with k, v interleaved at dim=-2 in kv
# q and kv may be disjoint or consecutive in memory
# when consecutive, they may share the same data_ptr()
qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
elif check_strides_kv and check_shapes_kv:
elif check_strides_kv and check_shapes_kv and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk):
# sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
# three chunks of memory (q, k and v) which can be disjoint or consecutive
# when consecutive, they may share the same data_ptr()
qkv_layout = "_".join(list([qkv_format]) * 3)
else:
qkv_layout = "not_supported"
Expand Down

0 comments on commit 4b0ce23

Please sign in to comment.