
[Bug] grad_weight can't be NoneType when running with DeepSpeed on Zero3. #428

Merged 1 commit into microsoft:main on Sep 20, 2024

Conversation

ys950902

When running Megatron-DeepSpeed with DeepSpeed ZeRO stage 3, grad_weight is set to None by default, which causes the error below:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1253
AttributeError: 'NoneType' object has no attribute 'numel'
This PR fixes the error so that weight.grad equals grad_weight.
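A minimal, hypothetical sketch of the backward control flow this PR targets (scalar stand-ins for tensors; the function and parameter names are illustrative, not the actual Megatron-DeepSpeed code): zero-bubble pipelining defers the weight gradient as None, while ZeRO-2/3 need it materialized, because DeepSpeed's gradient partitioning calls grad.numel() and fails on None.

```python
# Hypothetical sketch of a linear layer's backward pass, with scalars
# standing in for tensors. When the weight gradient is deferred it is
# returned as None (to be computed later by the pipeline scheduler);
# under ZeRO-2/3 it must be materialized immediately, since DeepSpeed's
# gradient partitioning calls grad.numel() and raises AttributeError
# on a None gradient.

def linear_backward(grad_output, inp, weight, defer_weight_grad):
    grad_input = grad_output * weight        # dL/dx
    if defer_weight_grad:
        grad_weight = None                   # deferred: scheduler computes it later
    else:
        grad_weight = grad_output * inp      # dL/dW, materialized now for ZeRO-2/3
    return grad_input, grad_weight
```

For example, `linear_backward(2.0, 3.0, 4.0, defer_weight_grad=False)` returns both gradients, while passing `True` leaves the weight gradient as None.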

@ys950902
Author

@tjruwase, could you please take a look at this PR? This issue blocks users from using the ZeRO-3 stage.

@ys950902
Author

ys950902 commented Sep 3, 2024

Any concerns about this PR?

@tjruwase

tjruwase commented Sep 3, 2024

@ys950902, apologies for the delay on this PR. I am confused about this issue since Megatron-DeepSpeed and ZeRO3 have always been compatible. Can you share some repro details to help my understanding? Thanks!

@ys950902
Author

ys950902 commented Sep 19, 2024

> @ys950902, apologies for the delay on this PR. I am confused about this issue since Megatron-DeepSpeed and ZeRO3 have always been compatible. Can you share some repro details to help my understanding? Thanks!

Hi @tjruwase, sorry for the late response. With the zero-bubble schedule added in the latest Megatron-DeepSpeed, the backward computation is split into two parts (the weight gradient and the output gradient): the weight gradient is not computed immediately and is set to None. Under ZeRO-2/3, which partition gradients, this can cause unexpected errors. This PR adds a flag so that pipeline parallelism takes the current path, while ZeRO-2/3 take the former path.

@ys950902
Author

Hi @tjruwase, if you still have any concerns about this PR, please let me know.

@tjruwase

> Hi @tjruwase, sorry for the late response. With the zero-bubble schedule added in the latest Megatron-DeepSpeed, the backward computation is split into two parts (the weight gradient and the output gradient),

Thanks for the explanation. In that case, shouldn't this behavior depend on whether zero-bubble is enabled? In other words, check for args.enable_zbh1_pipeline. Can you please clarify?

@ys950902
Author

> > Hi @tjruwase, sorry for the late response. With the zero-bubble schedule added in the latest Megatron-DeepSpeed, the backward computation is split into two parts (the weight gradient and the output gradient),
>
> Thanks for the explanation. In that case, shouldn't this behavior depend on whether zero-bubble is enabled? In other words, check for args.enable_zbh1_pipeline. Can you please clarify?

Thanks for your reply. Yes, checking args.enable_zbh1_pipeline is more reasonable: pipeline parallelism is only supported with ZeRO-0/1 in DeepSpeed, which do not partition the gradient, so it is safe to leave the weight gradient as None there. I will change the check to args.enable_zbh1_pipeline to avoid confusion.
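The guard agreed on above could be sketched like this (hypothetical helper name; the real change lives inside Megatron-DeepSpeed's linear backward, which is not shown in this thread):

```python
# Hypothetical sketch of the guard discussed above: only when the
# zero-bubble (ZB-H1) pipeline schedule is enabled may grad_weight be
# left as None, since pipeline parallelism is only supported with
# ZeRO-0/1, which do not partition gradients. Under ZeRO-2/3 the
# gradient must be a real tensor.

def may_defer_weight_grad(args):
    # args is assumed to carry the parsed training flags.
    return getattr(args, "enable_zbh1_pipeline", False)
```

The backward pass would then return None for the weight gradient only when this helper returns True, and compute it immediately otherwise.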

@ys950902
Author

ys950902 commented Sep 20, 2024

Hi @tjruwase, I noticed bug #442 in zero-bubble, so I made a small modification for better clarity, and this PR is the solution for that bug: only when args.enable_zbh1_pipeline is enabled does this path compute the weight gradient when the schedule is popped; otherwise the former path computes the weight gradient. I think this is clearer. If approved, could you please also merge this PR, so we can upgrade Megatron-DeepSpeed to the latest for the next release?

@tjruwase tjruwase merged commit 598c092 into microsoft:main Sep 20, 2024
5 checks passed
@ys950902 ys950902 deleted the sy/bug_fix branch September 23, 2024 05:40