From 1bfc35c4e279b4edcd8291cb99dd3992e02bc8ba Mon Sep 17 00:00:00 2001
From: Jinghan Yao
Date: Thu, 1 Aug 2024 15:35:11 -0700
Subject: [PATCH] fixing the bug of flash_attn import and the wrong gather
 index when using flash_attn_cuda in sequence parallel (#406)

Co-authored-by: Jinghan Yao
---
 megatron/model/transformer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e79abea3cf..7467190582 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -37,9 +37,12 @@
 try:
     # FlashAttention (1.x)
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    from flash_attn.flash_attn_triton import flash_attn_func
 except ImportError:
     flash_attn_unpadded_func = None
+
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except ImportError:
     flash_attn_func = None
 
 try:
@@ -599,7 +602,11 @@ def __init__(self, config, layer_number,
         if self.enable_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
-            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            self.dist_attn = DistributedAttention(
+                local_attn,
+                parallel_state.get_sequence_parallel_group(),
+                gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0)
+            # flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension.
         else:
             if self.use_flash_attn:
                 self.core_attention_flash = local_attn
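
Note on the gather index (editorial sketch, not part of the patch): DeepSpeed's DistributedAttention re-shards q/k/v with an all-to-all and gathers the full sequence back along gather_idx before calling the wrapped local attention. The FlashAttention CUDA path feeds tensors as [b, s, nh, hd], so the sequence dimension is 1, while the default core-attention path uses [s, b, nh, hd], where it is 0. The snippet below only illustrates that layout difference under those assumptions; pick_gather_idx and the toy shapes are hypothetical and merely mirror the flag logic added by this patch.

# Illustrative sketch only -- not Megatron-DeepSpeed code. It mirrors the
# `gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0`
# expression from the patch; `pick_gather_idx` is a hypothetical helper.
import torch

def pick_gather_idx(use_flash_attn_v1: bool, use_flash_attn_v2: bool) -> int:
    # flash_attn_cuda kernels consume [b, s, nh, hd]: sequence lives on dim 1.
    # The default core-attention path uses [s, b, nh, hd]: sequence is dim 0.
    return 1 if (use_flash_attn_v1 or use_flash_attn_v2) else 0

# Toy tensors: local sequence shard of length 4, batch 2, 8 heads, head dim 16.
sbhd = torch.randn(4, 2, 8, 16)    # [s, b, nh, hd] -- default layout
bshd = sbhd.transpose(0, 1)        # [b, s, nh, hd] -- flash-attn layout

assert sbhd.shape[pick_gather_idx(False, False)] == 4  # gather along dim 0
assert bshd.shape[pick_gather_idx(True, False)] == 4   # gather along dim 1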