diff --git a/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py b/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py index 40a22088..6283700b 100644 --- a/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py +++ b/src/nanotron/parallel/ring_flash_attn/zigzag_ring_flash_attn.py @@ -1,5 +1,5 @@ import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward +from flash_attn_interface import _flash_attn_backward, _flash_attn_forward from .utils import RingComm, update_out_and_lse @@ -29,17 +29,12 @@ def zigzag_ring_flash_attn_forward( next_k, next_v = None, None def forward(q, k, v, causal): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + block_out, _, _, _, _, block_lse, _ = _flash_attn_forward( q, k, v, - dropout_p, softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, + causal ) return block_out, block_lse @@ -125,14 +120,9 @@ def backward(doubt, q, k, v, out, softmax_lse, causal): dq_buffer[:, :seqlen_q], dk_buffer[:, :seqlen_kv], dv_buffer[:, :seqlen_kv], - dropout_p, softmax_scale, causal, - window_size, - softcap, - alibi_slopes, - deterministic, - rng_state=None, + deterministic ) for step in range(kv_comm.world_size):