Use fused implementation of RoPE in MultiHeadAttention (#658)
* Use fused implementation of RoPE in MultiHeadAttention

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix freqs dtype

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ptrendx and ksivaman authored Feb 15, 2024
1 parent 1e78094 commit 8d62d5c
Showing 2 changed files with 3 additions and 3 deletions.

tests/pytorch/fused_attn/test_fused_attn.py (1 addition, 1 deletion)

@@ -811,7 +811,7 @@ def _run_transformer_layer(
     rotary_pos_emb = None
     if RoPE:
         PE = RotaryPositionEmbedding(dim=config.head_dim)
-        rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")
+        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

     # Set up model
     block = (
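
The test change above corresponds to the "Fix freqs dtype" part of the commit: the rotary frequencies produced by RotaryPositionEmbedding are no longer cast to the activation dtype and keep their default precision when moved to the GPU. A minimal sketch of that pattern, assuming Transformer Engine is installed, a CUDA device is available, and using illustrative sizes in place of the test's config values:

    from transformer_engine.pytorch.attention import RotaryPositionEmbedding

    head_dim, max_seqlen = 64, 128                     # illustrative sizes, not the test config
    pe = RotaryPositionEmbedding(dim=head_dim)
    rotary_pos_emb = pe(max_seqlen).to(device="cuda")  # moved to GPU, dtype left unchanged
    print(rotary_pos_emb.dtype)                        # expected to remain torch.float32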

transformer_engine/pytorch/attention.py (2 additions, 2 deletions)

@@ -3625,8 +3625,8 @@ def forward(
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
             q_pos_emb, k_pos_emb = rotary_pos_emb
-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)
+            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
+            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

         context_layer = self.core_attention(
             query_layer,
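
In attention.py, the call itself is unchanged except for the new fused=True argument, which selects the fused RoPE kernel in place of the unfused PyTorch implementation. A minimal standalone sketch of the same call outside MultiHeadAttention, assuming Transformer Engine is installed, a CUDA device is available, illustrative tensor shapes, and the "sbhd" layout standing in for self.qkv_format:

    import torch
    from transformer_engine.pytorch.attention import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    seqlen, batch, heads, head_dim = 128, 2, 16, 64    # illustrative sizes
    query = torch.randn(seqlen, batch, heads, head_dim,
                        device="cuda", dtype=torch.bfloat16)
    freqs = RotaryPositionEmbedding(dim=head_dim)(seqlen).to(device="cuda")

    # fused=True routes to the fused kernel; omitting it falls back to the
    # unfused path that this commit replaces inside MultiHeadAttention.
    query_rotated = apply_rotary_pos_emb(query, freqs, "sbhd", fused=True)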
