
[fix auto-microbatch] FSDP reshard and cleanup after OOM to fix the cuda memory leak #3030

Merged — 22 commits into dev from reshard-after-oom on Feb 22, 2024

Conversation


@bigning bigning commented Feb 18, 2024

What does this PR do?

There are two issues with the current auto-microbatching:

  1. NCCL timeout or hang.
  2. It settles on a lower device train microbatch size than necessary, or still OOMs even with device_train_microbatch_size = 1.

This PR targets the 2nd issue; PR 3016 (WIP) targets the 1st.

The 2nd issue is caused by a memory leak: if an OOM happens during FSDP forward/backward after parameters have been unsharded, we need to manually reshard and clean up. Otherwise, memory leaks like this:
[Figure: GPU memory usage plot showing the leaked memory after OOM]

There is an existing FSDP backward callback, `_post_backward_final_callback`, which reshards and cleans up at the end of the FSDP backward pass. This PR simply calls that API after an OOM happens.
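
For illustration, here is a minimal sketch of that cleanup path. Assumptions: it relies on PyTorch's private FSDP internals (`torch.distributed.fsdp._runtime_utils._post_backward_final_callback`), which can change between releases, and the helper name `fsdp_reshard_and_cleanup` and the retry loop in the comments are illustrative rather than Composer's exact code:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel
# Private PyTorch FSDP internals (PyTorch 2.x); subject to change across releases.
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback


def fsdp_reshard_and_cleanup(model: torch.nn.Module) -> None:
    """Reshard FSDP modules left unsharded by an OOM mid-forward/backward.

    When an OOM interrupts forward/backward, the normal post-backward hooks
    never run, so unsharded parameters leak memory. Invoking FSDP's final
    backward callback on each root FSDP module reshards them and resets the
    backward/prefetch bookkeeping.
    """
    for module in model.modules():
        if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
            # Calling this on a root module traverses and reshards all of its
            # child FSDP modules as well.
            _post_backward_final_callback(module, module)
```

In an auto-microbatching loop, the call would slot in right after the OOM is caught, roughly like:

```python
try:
    loss = model(batch).sum()
    loss.backward()
except torch.cuda.OutOfMemoryError:
    fsdp_reshard_and_cleanup(model)
    torch.cuda.empty_cache()
    # ...then retry with a smaller device_train_microbatch_size
```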

test

unit test

python -m composer.cli.launcher -n 2 -m pytest -m gpu tests/trainer/test_fsdp.py -k test_fsdp_reshard_after_oom
python -m composer.cli.launcher -n 2 -m pytest -m gpu tests/trainer/test_fsdp.py -k test_fsdp_same_state_after_oom_reshard
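The gist of these tests, as a hypothetical sketch (the module, test name, and simulated-OOM trick below are illustrative, not the actual code in tests/trainer/test_fsdp.py; it assumes the launcher has already initialized a 2-rank process group):

```python
import pytest
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class OOMModule(torch.nn.Module):
    """Tiny module that can simulate a CUDA OOM mid-forward."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.raise_oom = True

    def forward(self, x):
        out = self.fc(x)
        if self.raise_oom:
            # Parameters are unsharded at this point, so without manual
            # cleanup they would stay unsharded and leak memory.
            raise torch.cuda.OutOfMemoryError('CUDA out of memory (simulated)')
        return out


@pytest.mark.gpu
def test_reshard_after_simulated_oom():
    model = FSDP(OOMModule().cuda())
    x = torch.randn(4, 8, device='cuda')
    with pytest.raises(torch.cuda.OutOfMemoryError):
        model(x).sum().backward()
    fsdp_reshard_and_cleanup(model)  # helper sketched above
    # After resharding, a normal train step should succeed.
    model.module.raise_oom = False
    model(x).sum().backward()
```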

end-to-end test

llama2 13B finetuning on 16 A100-40G GPUs with global train batch size = 1024

  • baseline: llama2-13b-baseline-new-Ln2FYB OOMed even with device-train-microbatch-size = 1
  • test: llama2-13b-test-16-40g-r7z22-IBHkBe ran well with device-train-microbatch-size = 2

llama2 70B finetuning on 32 H100-80G GPUs with global train batch size = 1024

  • baseline: llama2-70b-8192-32-h100-80g-baseline-npn6Qj OOMed even with device-train-microbatch-size = 1
  • test: llama2-70b-8192-32-h100-80g-test-vbLmLL chose device-train-microbatch-size = 2 and ran well until evaluation. After evaluation it hit an NCCL timeout error, which is a different issue; we'll debug and fix the timeout in a separate PR.
  • fixed device-train-microbatch-size = 4: llama2-70b-32-h100-80g-fixed-4-KKDQ7I OOMed.

@bigning bigning marked this pull request as ready for review February 20, 2024 17:30

@mvpatel2000 mvpatel2000 left a comment


This is great! A few minor nits.

Would like @cli99's thoughts too.

Review threads on composer/trainer/trainer.py and tests/trainer/test_fsdp.py (resolved).

@cli99 cli99 left a comment


Nice test. Looks good to me.


awgu commented Feb 20, 2024

Thanks for calling this out and great catch! I think that calling the _post_backward_final_callback sounds good to me.


bigning commented Feb 20, 2024

> Thanks for calling this out and great catch! I think that calling the _post_backward_final_callback sounds good to me.

Thank you @awgu for the quick response!

Review threads on composer/trainer/trainer.py and tests/trainer/test_fsdp.py (resolved).
bigning and others added 2 commits February 21, 2024 13:09
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
@bigning bigning requested a review from mvpatel2000 February 21, 2024 21:20

@mvpatel2000 mvpatel2000 left a comment


This is a really nice PR. Super great catch.

@bigning bigning enabled auto-merge (squash) February 21, 2024 22:46

@mvpatel2000 mvpatel2000 left a comment


nit

Review thread on tests/trainer/test_fsdp.py (resolved).
@bigning bigning merged commit 2133c17 into dev Feb 22, 2024
14 checks passed
@bigning bigning deleted the reshard-after-oom branch February 22, 2024 02:27