diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index c7b0f1b04d..0e8a7f1e94 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -811,7 +811,7 @@ def _run_transformer_layer(
     rotary_pos_emb = None
     if RoPE:
         PE = RotaryPositionEmbedding(dim=config.head_dim)
-        rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")
+        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

     # Set up model
     block = (
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3926fec3de..a67ed8c0b4 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -3625,8 +3625,8 @@ def forward(
        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
-           query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
-           key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)
+           query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
+           key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

        context_layer = self.core_attention(
            query_layer,
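
Below is a minimal sketch, not part of the patch, of how the fused RoPE path enabled above might be exercised on its own. It assumes RotaryPositionEmbedding and apply_rotary_pos_emb are importable from transformer_engine.pytorch.attention, that apply_rotary_pos_emb accepts the fused keyword shown in the diff, and that a CUDA device is available; the tensor shapes are illustrative only.

import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq_len, batch, heads, head_dim = 128, 2, 16, 64

# Rotary frequencies are generated once per sequence length and left in
# float32 on the GPU, mirroring the test change above that drops the dtype cast.
pe = RotaryPositionEmbedding(dim=head_dim)
freqs = pe(seq_len).to(device="cuda")

# Query tensor in "sbhd" layout: [seq, batch, heads, head_dim].
q = torch.randn(seq_len, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")

# Apply the rotary embedding through the fused kernel path.
q_rot = apply_rotary_pos_emb(q, freqs, "sbhd", fused=True)
print(q_rot.shape)  # torch.Size([128, 2, 16, 64])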