
[PyTorch] Fix FP8 activation recompute #1254

Merged
merged 3 commits into NVIDIA:main
Oct 16, 2024

Conversation

@ksivaman (Member) commented Oct 15, 2024

Description

The amax reduction for all backward tensors happens in the first module (one of the base modules) within a given fp8_autocast. The ctx.reduce_and_update_bwd_fp8_tensors flag is saved by querying FP8GlobalStateManager.is_first_fp8_module(), which returns True only for the first module in the fp8_autocast. However, this introduces a bug during activation recompute: the recompute phase runs outside the fp8 context, so the first-module flag is never set, and the amaxes for gradients are never reduced.
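The lifecycle of the first-module flag can be pictured with a simplified sketch. This is not the actual Transformer Engine implementation; the class and attribute names below are made up purely to illustrate why a recompute pass that runs outside fp8_autocast never observes the flag:

```python
# Minimal, hypothetical sketch (not the real FP8GlobalStateManager) of how a
# "first FP8 module" flag can be missed during activation recompute.
class _FP8State:
    autocast_entered = False   # set on fp8_autocast entry
    first_module_seen = False  # cleared on entry, set by the first module

    @classmethod
    def is_first_fp8_module(cls):
        # Returns True exactly once per fp8_autocast region.
        if cls.autocast_entered and not cls.first_module_seen:
            cls.first_module_seen = True
            return True
        return False

# Original forward (inside fp8_autocast): the first module sees True and takes
# ownership of the backward amax reduction.
# Recompute forward (outside fp8_autocast): autocast_entered is False, so every
# module sees False, ctx.reduce_and_update_bwd_fp8_tensors is never set, and the
# gradient amaxes are left unreduced.
```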

Fixes #1190

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

activation_recompute_forward now maintains a queue to pass the value of the IS_FIRST_FP8_MODULE flag from the forward phase to the recompute phase. After the recompute phase, the flag is reset so that any nested autocasts are not disturbed.
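A minimal sketch of this queue mechanism, with hypothetical names (the actual change lives inside Transformer Engine's activation_recompute_forward), might look like:

```python
from collections import deque

# Hypothetical stand-ins for the global first-module flag and the queue that
# carries its value from the original forward pass to the recompute pass.
_IS_FIRST_FP8_MODULE = False
_fp8_first_module_queue = deque()

def forward_phase(observed_flag: bool) -> None:
    # During the original forward (inside fp8_autocast), record whether this
    # module was the first FP8 module in the region.
    _fp8_first_module_queue.append(observed_flag)

def recompute_phase(run_recompute) -> None:
    global _IS_FIRST_FP8_MODULE
    # Replay the recorded flag so the recomputed forward (which runs outside
    # the fp8_autocast) makes the same decision as the original forward.
    saved = _IS_FIRST_FP8_MODULE
    _IS_FIRST_FP8_MODULE = _fp8_first_module_queue.popleft()
    try:
        run_recompute()
    finally:
        # Reset the flag afterwards so nested autocasts are not disturbed.
        _IS_FIRST_FP8_MODULE = saved
```

The forward pass pushes the flag it observed inside the autocast; the recompute pass pops it, temporarily restores it for the recomputed forward, and restores the previous value afterwards.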

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman ksivaman added the bug Something isn't working label Oct 15, 2024
@ksivaman ksivaman requested a review from denera October 15, 2024 13:41
@ksivaman ksivaman self-assigned this Oct 15, 2024
@ksivaman ksivaman marked this pull request as draft October 15, 2024 13:41
@ksivaman (Member Author)

/te-ci pytorch

@denera (Collaborator) left a comment

LGTM!

@ptrendx (Member) commented Oct 15, 2024

@ksivaman Could you put some more information about the bug and the fix in the description?

@ksivaman (Member Author)

/te-ci pytorch

@ksivaman ksivaman merged commit a518151 into NVIDIA:main Oct 16, 2024
14 of 15 checks passed
timmoon10 pushed a commit to timmoon10/TransformerEngine that referenced this pull request Nov 7, 2024
Fix FP8 activation recompute

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Labels
bug Something isn't working
Development

Successfully merging this pull request may close these issues.

[PyTorch] FP8 and activation checkpointing causes training instabilities