
[PyTorch] FP8 and activation checkpointing causes training instabilities #1190

Closed
Marks101 opened this issue Sep 18, 2024 · 2 comments · Fixed by #1254

@Marks101
Contributor

Hello team,

We noticed training instabilities when combining FP8 with activation checkpointing via transformer_engine.pytorch.checkpoint. Taking a closer look, we suspect that the FP8 scales for the backward pass are not updated properly. The following code snippet demonstrates the behavior:

import torch
import transformer_engine as te

hidden_size = 768

def run_iterations(checkpoint: bool):
    torch.manual_seed(12345)
    model = te.pytorch.Linear(hidden_size, hidden_size)

    result = []
    for it in range(2):
        x = torch.randn((hidden_size, hidden_size), requires_grad=True, device="cuda")
        y_grad = torch.randn_like(x)

        with te.pytorch.fp8_autocast():
            if checkpoint:
                y = te.pytorch.checkpoint(model, x)
            else:
                y = model(x)

        y.backward(y_grad)
        result.append(dict(it=it, y=y, x_grad=x.grad, scaling_bwd=model.fp8_meta["scaling_bwd"].scale[0].item()))

    return result

result_ref = run_iterations(checkpoint=False)
result_cp = run_iterations(checkpoint=True)

for r_ref, r_cp in zip(result_ref, result_cp):
    max_diff_y = torch.max(torch.abs(r_ref["y"] - r_cp["y"])).item()
    max_diff_x_grad = torch.max(torch.abs(r_ref["x_grad"] - r_cp["x_grad"])).item()
    print(f"it={r_ref['it']}: {max_diff_y=}, {max_diff_x_grad=:.02}")
    print(f"      scale bwd: ref={r_ref['scaling_bwd']:.2}, cp={r_cp['scaling_bwd']:.2}")

The snippet creates a linear layer and runs two forward and backward passes on it with FP8, once with and once without activation checkpointing. The results are recorded and compared at the end.

When executing this code I get the following output:

it=0: max_diff_y=0.0, max_diff_x_grad=0.0
      scale bwd: ref=1.1e+04, cp=1.0
it=1: max_diff_y=0.0, max_diff_x_grad=0.21
      scale bwd: ref=1.1e+04, cp=1.0

We can see that the forward passes give the same results with and without checkpointing, but the gradient of x differs in the second iteration. Additionally, the scaling factor fp8_meta["scaling_bwd"].scale[0] is not updated in the run with activation checkpointing. My guess is that reduce_and_update_bwd_fp8_tensors is not triggered because FP8GlobalStateManager.is_first_fp8_module() returns False during the second forward pass that activation recompute performs outside of the fp8_autocast context, see here.
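
For intuition, here is an illustrative sketch of roughly what the delayed-scaling update behind reduce_and_update_bwd_fp8_tensors is expected to compute for the gradient tensor. The helper name and the exact formula are assumptions based on the generic FP8 delayed-scaling recipe, not TE internals copied verbatim:

import torch

E5M2_MAX = 57344.0  # largest finite value of the E5M2 format used for gradients

def delayed_scaling_update(amax_history: torch.Tensor, margin: int = 0) -> float:
    # Hypothetical helper: reduce the recorded amax history and derive a scale
    # that maps the observed maximum near the top of the FP8 range.
    amax = amax_history.max()
    return (E5M2_MAX / amax).item() / 2 ** margin

# With the observed gradient amax of roughly 5.2 this yields ~1.1e4, matching the
# reference run above. If the update never runs, the checkpointed run keeps
# quantizing gradients with the initial scale of 1.0, so the two runs quantize the
# gradient differently, which shows up as the difference in x.grad.
print(delayed_scaling_update(torch.tensor([5.2])))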

Other modules such as TransformerLayer show the same issue; a sketch of the same reproduction with TransformerLayer follows below. Could you please take a look and check whether you can reproduce these findings? Are we using checkpointing incorrectly?
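
For completeness, a hedged sketch of how the reproducer above can be pointed at TransformerLayer instead of Linear. The FFN size, head count, and sequence/batch shape are arbitrary illustrative choices, not values from our training setup:

import torch
import transformer_engine as te

hidden_size, ffn_hidden_size, num_heads = 768, 3072, 12

# Replaces `model = te.pytorch.Linear(hidden_size, hidden_size)` in run_iterations above
model = te.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_heads)

# [sequence, batch, hidden] input, with dimensions that satisfy the FP8 GEMM shape requirements
x = torch.randn((128, 4, hidden_size), requires_grad=True, device="cuda")
y_grad = torch.randn_like(x)

with te.pytorch.fp8_autocast():
    y = te.pytorch.checkpoint(model, x)  # same pattern as with te.pytorch.Linear

y.backward(y_grad)
# Comparing y and x.grad over two iterations against a run without te.pytorch.checkpoint
# shows the same divergence as with the linear layer.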

Thanks

@ksivaman
Member

@Marks101 Could you try #1254?

@Marks101
Contributor Author

Hi @ksivaman, thanks for taking a look at this. Your fix looks good: I installed it in our training environment and was able to confirm that it fixes the issue with FP8 and activation recompute 🥳
