Skip to content

Commit

Permalink
[PyTorch] Fix detection of 3 in 3hd/h3d layouts (#1187)
Browse files: browse the repository at this point in the history
* fix detection of 3 in 3hd/h3d layouts

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* error out when an invalid layout group is provided

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cyanguwa and pre-commit-ci[bot] committed Sep 27, 2024
1 parent c4a5cb8 commit 8a1b7ee
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,21 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
Expand Down Expand Up @@ -252,9 +264,21 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
auto h = q_shape[q_shape.size() - 2];
Expand Down

0 comments on commit 8a1b7ee

Please sign in to comment.