diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index fb1fc97a33..b2968a688d 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -95,9 +95,21 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
   auto qkv_sizes = QKV.sizes().vec();
   std::vector<int64_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
   std::vector<int64_t> q_shape;
-  for (auto i : qkv_shape) {
-    if (i != 3) {
-      q_shape.push_back(i);
+  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  int loc_3 = 0;
+  switch (layout_group) {
+    case NVTE_3HD:
+      loc_3 = qkv_sizes.size() - 3;
+      break;
+    case NVTE_H3D:
+      loc_3 = qkv_sizes.size() - 2;
+      break;
+    default:
+      NVTE_ERROR("Invalid QKV layout group.");
+  }
+  for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
+    if (it - qkv_shape.begin() != loc_3) {
+      q_shape.push_back(*it);
     }
   }
   std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
@@ -252,9 +264,21 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   auto qkv_sizes = QKV.sizes().vec();
   std::vector<int64_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
   std::vector<int64_t> q_shape;
-  for (auto i : qkv_shape) {
-    if (i != 3) {
-      q_shape.push_back(i);
+  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  int loc_3 = 0;
+  switch (layout_group) {
+    case NVTE_3HD:
+      loc_3 = qkv_sizes.size() - 3;
+      break;
+    case NVTE_H3D:
+      loc_3 = qkv_sizes.size() - 2;
+      break;
+    default:
+      NVTE_ERROR("Invalid QKV layout group.");
+  }
+  for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
+    if (it - qkv_shape.begin() != loc_3) {
+      q_shape.push_back(*it);
     }
   }
   auto h = q_shape[q_shape.size() - 2];
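The old loop derived q_shape by dropping every dimension whose size happened to equal 3, which misidentifies the packing dimension whenever num_heads or head_dim is also 3, and it finds the 3 by value regardless of layout. The patch instead locates the packing dimension positionally from the layout group: third from the end for 3HD layouts ([..., 3, h, d]) and second from the end for H3D layouts ([..., h, 3, d]). Below is a minimal standalone sketch of that shape derivation, with a local stand-in enum in place of NVTE_QKV_Layout_Group / nvte_get_qkv_layout_group (which live in Transformer Engine's headers); the shapes in main are illustrative only.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Stand-in for the NVTE layout-group enum; the real definition lives in
// Transformer Engine's headers.
enum class LayoutGroup { NVTE_3HD, NVTE_H3D };

// Derive the per-tensor Q shape from a packed QKV shape by dropping the
// dimension that holds the factor of 3, located positionally per layout:
// 3HD packs as [..., 3, h, d], H3D packs as [..., h, 3, d].
std::vector<int64_t> q_shape_from_qkv(const std::vector<int64_t>& qkv_shape,
                                      LayoutGroup layout_group) {
  size_t loc_3 = 0;
  switch (layout_group) {
    case LayoutGroup::NVTE_3HD:
      loc_3 = qkv_shape.size() - 3;
      break;
    case LayoutGroup::NVTE_H3D:
      loc_3 = qkv_shape.size() - 2;
      break;
    default:
      throw std::runtime_error("Invalid QKV layout group.");
  }
  std::vector<int64_t> q_shape;
  for (size_t i = 0; i < qkv_shape.size(); ++i) {
    if (i != loc_3) q_shape.push_back(qkv_shape[i]);
  }
  return q_shape;
}

int main() {
  // [total_seqs, 3, num_heads, head_dim] -> [total_seqs, num_heads, head_dim]
  for (int64_t d : q_shape_from_qkv({4096, 3, 16, 64}, LayoutGroup::NVTE_3HD))
    std::cout << d << ' ';
  std::cout << '\n';
  // [total_seqs, num_heads, 3, head_dim] -> [total_seqs, num_heads, head_dim]
  for (int64_t d : q_shape_from_qkv({4096, 16, 3, 64}, LayoutGroup::NVTE_H3D))
    std::cout << d << ' ';
  std::cout << '\n';
}

Both calls print 4096 16 64, i.e. the per-tensor shape with the packing dimension removed, regardless of where the 3 sits; with the old value-based check, a hypothetical shape like {4096, 3, 3, 64} would have lost both size-3 dimensions.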