Skip to content

Commit

Permalink
flash attention 3
Browse files Browse the repository at this point in the history
  • Loading branch information
jquesnelle committed Aug 16, 2024
1 parent 62f824b commit c19a3cd
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
from flash_attn_interface import _flash_attn_backward, _flash_attn_forward

from .utils import RingComm, update_out_and_lse

Expand Down Expand Up @@ -29,17 +29,12 @@ def zigzag_ring_flash_attn_forward(
next_k, next_v = None, None

def forward(q, k, v, causal):
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
block_out, _, _, _, _, block_lse, _ = _flash_attn_forward(
q,
k,
v,
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
causal
)
return block_out, block_lse

Expand Down Expand Up @@ -125,14 +120,9 @@ def backward(dout, q, k, v, out, softmax_lse, causal):
dq_buffer[:, :seqlen_q],
dk_buffer[:, :seqlen_kv],
dv_buffer[:, :seqlen_kv],
dropout_p,
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
rng_state=None,
deterministic
)

for step in range(kv_comm.world_size):
Expand Down

0 comments on commit c19a3cd

Please sign in to comment.