[fix auto-microbatch] FSDP reshard and cleanup after OOM to fix the cuda memory leak #3030
Conversation
This is great! A few minor nits.
I'd like @cli99's thoughts too.
Nice test. Looks good to me.
Thanks for calling this out and great catch! I think that calling the …
Thank you @awgu for the quick response!
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
This is a really nice PR. Super great catch.
nit
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
What does this PR do?
There are two issues with the current auto-microbatching.
This PR targets the second issue; PR 3016 (WIP) targets the first.
The second issue is a memory leak: if an OOM happens during the FSDP forward/backward after parameters have been unsharded, we need to manually reshard and clean up. Otherwise, the unsharded parameters (and their associated buffers) stay resident and memory is leaked on every retry.
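To make the failure mode concrete, here is a minimal, hypothetical sketch of an auto-microbatching retry loop (the function and variable names are illustrative, not Composer's actual API). Without a reshard step in the `except` branch, parameters that FSDP unsharded before the OOM remain unsharded across retries:

```python
import torch

def train_batch_with_retries(run_microbatches, batch, microbatch_size: int):
    """Hypothetical auto-microbatching loop: halve the microbatch size on CUDA OOM."""
    while True:
        try:
            return run_microbatches(batch, microbatch_size)
        except torch.cuda.OutOfMemoryError:
            # The OOM can fire mid forward/backward, after FSDP has unsharded
            # (all-gathered) some parameters. Without an explicit reshard and
            # cleanup here, those full parameters and any pending gradient /
            # prefetch buffers stay allocated, so every retry starts with less
            # free GPU memory than the last.
            torch.cuda.empty_cache()
            microbatch_size = max(1, microbatch_size // 2)
```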
FSDP already has a backward callback, _post_backward_final_callback, which reshards and cleans up at the end of the FSDP backward pass. This PR simply calls that callback after an OOM occurs.
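A minimal sketch of that cleanup, assuming PyTorch's private `_post_backward_final_callback(state, module)` helper from `torch.distributed.fsdp._runtime_utils` (a private API whose import path and signature can change between PyTorch versions, so treat this as illustrative rather than Composer's exact implementation):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback

def reshard_fsdp_after_oom(model: torch.nn.Module) -> None:
    """Best-effort reshard/cleanup after a CUDA OOM interrupts an FSDP fwd/bwd."""
    for module in model.modules():
        # Only the root FSDP instance needs the callback: it walks the FSDP
        # states it owns, reshards any still-unsharded flat parameters, and
        # clears the per-backward bookkeeping, i.e. the same work that
        # normally runs at the very end of the backward pass.
        if isinstance(module, FSDP) and module._is_root:
            _post_backward_final_callback(module, module)
```

In the retry loop sketched above, this would be called inside the `except torch.cuda.OutOfMemoryError` branch, before shrinking the microbatch size and retrying.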
Tests
Unit tests
python -m composer.cli.launcher -n 2 -m pytest -m gpu tests/trainer/test_fsdp.py -k test_fsdp_reshard_after_oom
python -m composer.cli.launcher -n 2 -m pytest -m gpu tests/trainer/test_fsdp.py -k test_fsdp_same_state_after_oom_reshard
End-to-end tests
llama2 13B finetuning on 16 A100-40G GPUs with global train batch size = 1024
llama2 70B finetuning on 32 H100-80G GPUs with global train batch size = 1024