fixing the bug of flash_attn import and the wrong gather index when using flash_attn_cuda in sequence parallel (#406)

Co-authored-by: Jinghan Yao <jyao@orthus.nic.uoregon.edu>
YJHMITWEB and Jinghan Yao authored Aug 1, 2024
1 parent 8822a5c commit 1bfc35c
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions megatron/model/transformer.py
@@ -37,9 +37,12 @@
 try:
     # FlashAttention (1.x)
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    from flash_attn.flash_attn_triton import flash_attn_func
 except ImportError:
     flash_attn_unpadded_func = None
 
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except ImportError:
+    flash_attn_func = None
 
 try:
@@ -599,7 +602,11 @@ def __init__(self, config, layer_number,
         if self.enable_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
-            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            self.dist_attn = DistributedAttention(
+                local_attn,
+                parallel_state.get_sequence_parallel_group(),
+                gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0)
+            # flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension.
         else:
             if self.use_flash_attn:
                 self.core_attention_flash = local_attn
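
The import half of the fix is easiest to see in isolation. Below is a minimal, self-contained sketch of the failure mode the split try/except avoids, using hypothetical module names (pkg_a, pkg_b) in place of the real flash-attn modules: with a single guard, an ImportError on the second import also discards the symbol that was already imported successfully.

import sys
import types

# Fake an "installed" module; pkg_b is deliberately left missing.
pkg_a = types.ModuleType("pkg_a")
pkg_a.func_a = lambda: "a"
sys.modules["pkg_a"] = pkg_a

# Old pattern: one guard for both imports.
try:
    from pkg_a import func_a
    from pkg_b import func_b   # raises ImportError ...
except ImportError:
    func_a = None              # ... so the working func_a is reported as unavailable

assert func_a is None

# New pattern: independent guards keep whatever did import.
try:
    from pkg_a import func_a
except ImportError:
    func_a = None

try:
    from pkg_b import func_b
except ImportError:
    func_b = None

assert func_a is not None and func_b is None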
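The gather_idx change is about which tensor dimension DeepSpeed's all-to-all must re-assemble the full sequence into. As the new comment notes, flash_attn_cuda assumes [b, s, nh, hd] tensors, where the sequence dimension is dim 1, while the non-flash core-attention path works on sequence-first [s, b, nh, hd] tensors, where it is dim 0. A small layout sketch (not DeepSpeed code; the helper name is made up) of which dimension the gather index should point at in each case:

import torch

def sequence_dim(use_flash_attn_v1_or_v2: bool) -> int:
    # Hypothetical helper mirroring the ternary in the commit:
    #   flash-attn v1/v2 kernels expect [b, s, nh, hd] -> sequence is dim 1
    #   default core attention uses    [s, b, nh, hd]  -> sequence is dim 0
    return 1 if use_flash_attn_v1_or_v2 else 0

# Toy shapes: batch=2, sequence=8, heads=4, head_dim=16
flash_layout   = torch.zeros(2, 8, 4, 16)   # [b, s, nh, hd]
default_layout = torch.zeros(8, 2, 4, 16)   # [s, b, nh, hd]

assert flash_layout.shape[sequence_dim(True)] == 8
assert default_layout.shape[sequence_dim(False)] == 8

Before this commit no gather_idx was passed at all, so the flash path gathered along the same dimension as the non-flash path even though its sequence dimension sits elsewhere; that is the wrong gather index the commit title refers to.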
