
Commit 67c7e7b
assert amax reduction is needed for FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
xrennvidia committed Aug 15, 2024
1 parent 1e53357 commit 67c7e7b
Showing 1 changed file with 3 additions and 0 deletions.
transformer_engine/pytorch/attention.py: 3 additions & 0 deletions

@@ -5910,6 +5910,9 @@ def forward(
     assert (
         fp8_meta is not None
     ), "FP8 metadata fp8_meta is required for FP8 attention!"
+    assert (
+        not context_parallel or fp8_meta["recipe"].reduce_amax
+    ), "Amax reduction across TP+CP group is necessary when using context parallelism with FP8!"

     if context_parallel:
         assert (
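In effect, the new assertion makes FP8 attention refuse to run under context parallelism unless the FP8 recipe performs amax reduction, since amax statistics must stay synchronized across the TP+CP group. A minimal sketch of a configuration that satisfies the check, assuming Transformer Engine's DelayedScaling recipe (whose reduce_amax flag is what fp8_meta["recipe"].reduce_amax reflects); the model and input here are hypothetical stand-ins:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# reduce_amax=True asks TE to reduce amax values across the
# amax-reduction group, which the new assertion requires
# whenever context_parallel is enabled with FP8.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, reduce_amax=True)

model = te.Linear(1024, 1024).cuda()            # hypothetical stand-in module
inp = torch.randn(16, 1024, device="cuda")      # hypothetical input

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)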
