Hello team,

we noticed training instabilities when combining FP8 and activation checkpointing with `transformer_engine.pytorch.checkpoint`. Taking a closer look, we got the impression that the FP8 scales for the backward pass are not updated properly. Here is a code snippet that is meant to show this behavior:
The snippet creates a linear layer and runs two forward/backward iterations on it with FP8. This is executed once with and once without activation checkpointing, and the results are recorded and compared at the end.
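Roughly along these lines (a minimal sketch rather than the exact snippet: the `make_layer`/`run` helpers, tensor shapes, and recipe settings are illustrative, and the `te.checkpoint` argument handling may differ between Transformer Engine versions):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)


def make_layer():
    # Re-seed so both runs start from identical weights.
    torch.manual_seed(0)
    return te.Linear(128, 128, bias=True).cuda()


def run(layer, use_checkpoint):
    # Re-seed so both runs see identical inputs.
    torch.manual_seed(1)
    outs, grads, bwd_scales = [], [], []
    for _ in range(2):  # two iterations, so the second one relies on updated FP8 scales
        x = torch.randn(32, 128, device="cuda", requires_grad=True)
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            # NOTE: the exact te.checkpoint signature differs between TE versions.
            y = te.checkpoint(layer, x) if use_checkpoint else layer(x)
        y.sum().backward()
        outs.append(y.detach().clone())
        grads.append(x.grad.detach().clone())
        bwd_scales.append(layer.fp8_meta["scaling_bwd"].scale[0].item())
    return outs, grads, bwd_scales


out_ref, grad_ref, scale_ref = run(make_layer(), use_checkpoint=False)
out_ckpt, grad_ckpt, scale_ckpt = run(make_layer(), use_checkpoint=True)

for i in range(2):
    print(
        f"iter {i}: fwd equal={torch.equal(out_ref[i], out_ckpt[i])}, "
        f"dgrad equal={torch.equal(grad_ref[i], grad_ckpt[i])}, "
        f"bwd scale (no ckpt / ckpt)={scale_ref[i]:.4f} / {scale_ckpt[i]:.4f}"
    )
```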
When executing this code I get the following output:
We can see that the forward passes give the same results for the runs with and without checkpointing, but the gradient of x differs in the second iteration. Additionally, we see that the scale `fp8_meta["scaling_bwd"].scale[0]` is not updated in the run with activation checkpointing. My guess is that `reduce_and_update_bwd_fp8_tensors` is never triggered because `FP8GlobalStateManager.is_first_fp8_module() == False` during the second forward pass for activation recompute, which runs outside of the `fp8_autocast` context, see here.
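A more direct way to see this (again only a sketch, reusing the illustrative `make_layer()` helper and `fp8_recipe` from the snippet above) is to check whether the backward scale moves away from its initial value after the very first backward of a checkpointed layer:

```python
# If reduce_and_update_bwd_fp8_tensors ran at the end of the backward pass,
# the bwd scale should change after the first backward; with te.checkpoint
# this is where we see it stay at its initial value.
layer = make_layer()
x = torch.randn(32, 128, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = te.checkpoint(layer, x)

scale_before = layer.fp8_meta["scaling_bwd"].scale[0].item()
y.sum().backward()
scale_after = layer.fp8_meta["scaling_bwd"].scale[0].item()
print(f"bwd scale before/after first backward: {scale_before} / {scale_after}")
```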
Other modules like `TransformerLayer` have the same issue. Could you please have a look and check if you can reproduce these findings? Are we using checkpointing improperly?
Thanks
Hi @ksivaman, thanks for taking a look at this. Your fix looks good: I installed it in our training environment and was able to confirm that it fixes the issue with FP8 and activation recompute 🥳