Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support split qkv linear and sp overlap comm #415

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

inkcherry
Copy link

@inkcherry inkcherry commented Jul 5, 2024

work with microsoft/DeepSpeed#5691
when use ds_sequence_parallel, open the following 2 flags to enable overlap comm.
--split-qkv-linear
--ds-sequence-parallel-overlap-comm

loadams pushed a commit to microsoft/DeepSpeed that referenced this pull request Aug 1, 2024
SP is a fantastic piece of work, it is very elegant and concise, at the
current stage, a transformer layer's forward and backward passes involve
8 all-to-all operations, with 5 opportunities for overlapping
communication:

Forward pass: The QKV matrix operations can be pipelined alongside some
of the all-to-all communications.
Backward pass: DQ, DK, DV all-to-all communications can be pipelined
alongside matrix operations.
Backward pass: DO_w can be parallel with DO_input, involving matrix
operations and all-to-all communications. Similar overlap-comm
strategies are used in Megatron for TP/TP-sp parallelism.
I tested under conditions of 1N8C zero1, disabled activation
checkpointing, ds-sp=8, and gbs=16:
1B 64K
7B 16K
They showed over 10% improvement (where I found that for mega-ds, using
split QKV itself can also enhance performance due to reducing slice +
cat operations in fwd/bwd), despite some TFLOPs already performing at a
relatively good level.
co-work with microsoft/Megatron-DeepSpeed#415

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
@delock
Copy link

delock commented Aug 29, 2024

microsoft/DeepSpeed#5691 is merged. @inkcherry do you still need this PR be reviewed? Can you resolve conflict on this branch?

@inkcherry
Copy link
Author

@tohtana , @loadams notice microsoft/DeepSpeed#5691 is merged, could you merge this one ? thanks!

@yingtongxiong
Copy link

Hello,When I run the pretrain_gpt.py,I met the following bugs,
image
@inkcherry

@inkcherry
Copy link
Author

inkcherry commented Nov 5, 2024

@yingtongxiong If using this branch,
could you try to update the DeepSpeed version (to 8.30 or later) and enable flash-v2,disable activation_checkpoint to test it out refer(https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh)?
Would this issue occur if we don’t enable two overlap options?

@loadams
Copy link

loadams commented Nov 7, 2024

Hi @inkcherry - could you take a look at resolving the merge conflicts on this?

@inkcherry
Copy link
Author

inkcherry commented Nov 14, 2024

Hi @inkcherry - could you take a look at resolving the merge conflicts on this?

Hi, @loadams ,
I resolved the conflict and noticed that in the latest version of DeepSpeed, a view operation was missing in some updates compared to the original version https://github.com/microsoft/DeepSpeed/blob/17ed7c77c58611a923a6c8d2a3d21d359cd046e8/deepspeed/sequence/layer.py#L56 , which caused the issue. I added it back microsoft/DeepSpeed#6750 and validated it with a loss check.

Currently master mds + master ds (197~200 steps):

lm loss: 8.855590E+00
lm loss: 8.892502E+00
lm loss: 8.766361E+00
lm loss: 8.618977E+00 

this branch + ds fix patch + enable overlap(197~200 steps):

lm loss: 8.855516E+00
lm loss: 8.890095E+00 
lm loss: 8.765872E+00
lm loss: 8.620874E+00

@yingtongxiong
Copy link

yingtongxiong commented Nov 25, 2024

Hello,When I run the pretrain_gpt.py,I met the following bugs, image @inkcherry

img_v3_02gv_cd713df8-ae8e-4b17-88ee-9d06812f246g

hello, and now I met this problem, the run python file is the pretrain_gpt.py

@yingtongxiong
Copy link

@yingtongxiong If using this branch, could you try to update the DeepSpeed version (to 8.30 or later) and enable flash-v2,disable activation_checkpoint to test it out refer(https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples_deepspeed/sequence_parallel/ds_pretrain_gpt_1.3B_seq_parallel_32k.sh)? Would this issue occur if we don’t enable two overlap options?

I can run this shell (where I enable flash-v2 and disable activation-checkpoint) if I don't enable two overlap options.

@inkcherry
Copy link
Author

@yingtongxiong
Yes, please run together with this fix. microsoft/DeepSpeed#6750

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants